diff --git a/ext/stats/boolean.go b/ext/stats/boolean.go index ba0ed69..f30fa07 100644 --- a/ext/stats/boolean.go +++ b/ext/stats/boolean.go @@ -26,21 +26,21 @@ func (b *boolean) Value(ctx sqlite3.Context) { } func (b *boolean) Step(ctx sqlite3.Context, arg ...sqlite3.Value) { - if arg[0].Type() == sqlite3.NULL { - return - } - if arg[0].Bool() { + a := arg[0] + if a.Bool() { b.count++ } - b.total++ + if a.Type() != sqlite3.NULL { + b.total++ + } } func (b *boolean) Inverse(ctx sqlite3.Context, arg ...sqlite3.Value) { - if arg[0].Type() == sqlite3.NULL { - return - } - if arg[0].Bool() { + a := arg[0] + if a.Bool() { b.count-- } - b.total-- + if a.Type() != sqlite3.NULL { + b.total-- + } } diff --git a/ext/stats/percentile.go b/ext/stats/percentile.go index 68646f8..2999b56 100644 --- a/ext/stats/percentile.go +++ b/ext/stats/percentile.go @@ -40,8 +40,14 @@ func (q *percentile) Step(ctx sqlite3.Context, arg ...sqlite3.Value) { } func (q *percentile) Inverse(ctx sqlite3.Context, arg ...sqlite3.Value) { - // Implementing inverse allows certain queries that don't really need it to succeed. - ctx.ResultError(util.ErrorString("percentile: may not be used as a window function")) + a := arg[0] + f := a.Float() + if f != 0.0 || a.NumericType() != sqlite3.NULL { + i := slices.Index(q.nums, f) + l := len(q.nums) - 1 + q.nums[i] = q.nums[l] + q.nums = q.nums[:l] + } } func (q *percentile) Value(ctx sqlite3.Context) { @@ -55,13 +61,13 @@ func (q *percentile) Value(ctx sqlite3.Context) { floats []float64 ) if q.kind == median { - float, err = getPercentile(q.nums, 0.5, q.kind) + float, err = q.at(0.5) ctx.ResultFloat(float) } else if err = json.Unmarshal(q.arg1, &float); err == nil { - float, err = getPercentile(q.nums, float, q.kind) + float, err = q.at(float) ctx.ResultFloat(float) } else if err = json.Unmarshal(q.arg1, &floats); err == nil { - err = getPercentiles(q.nums, floats, q.kind) + err = q.atMore(floats) ctx.ResultJSON(floats) } if err != nil { @@ -69,28 +75,28 @@ func (q *percentile) Value(ctx sqlite3.Context) { } } -func getPercentile(nums []float64, pos float64, kind int) (float64, error) { - if kind == percentile_100 { +func (q *percentile) at(pos float64) (float64, error) { + if q.kind == percentile_100 { pos = pos / 100 } if pos < 0 || pos > 1 { return 0, util.ErrorString("invalid pos") } - i, f := math.Modf(pos * float64(len(nums)-1)) - m0 := quick.Select(nums, int(i)) + i, f := math.Modf(pos * float64(len(q.nums)-1)) + m0 := quick.Select(q.nums, int(i)) - if f == 0 || kind == percentile_disc { + if f == 0 || q.kind == percentile_disc { return m0, nil } - m1 := slices.Min(nums[int(i)+1:]) + m1 := slices.Min(q.nums[int(i)+1:]) return math.FMA(f, m1, math.FMA(-f, m0, m0)), nil } -func getPercentiles(nums []float64, pos []float64, kind int) error { +func (q *percentile) atMore(pos []float64) error { for i := range pos { - v, err := getPercentile(nums, pos[i], kind) + v, err := q.at(pos[i]) if err != nil { return err } diff --git a/ext/stats/percentile_test.go b/ext/stats/percentile_test.go index 41f2e8b..985f10e 100644 --- a/ext/stats/percentile_test.go +++ b/ext/stats/percentile_test.go @@ -58,6 +58,40 @@ func TestRegister_percentile(t *testing.T) { } stmt.Close() + stmt, _, err = db.Prepare(` + SELECT + median(x) OVER (ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) + FROM data`) + if err != nil { + t.Fatal(err) + } + if stmt.Step() { + if got := stmt.ColumnFloat(0); got != 5.5 { + t.Errorf("got %v, want 5.5", got) + } + } + if stmt.Step() { + if got := stmt.ColumnFloat(0); got != 7 { + t.Errorf("got %v, want 7", got) + } + } + if stmt.Step() { + if got := stmt.ColumnFloat(0); got != 10 { + t.Errorf("got %v, want 10", got) + } + } + if stmt.Step() { + if got := stmt.ColumnFloat(0); got != 14.5 { + t.Errorf("got %v, want 14.5", got) + } + } + if stmt.Step() { + if got := stmt.ColumnFloat(0); got != 16 { + t.Errorf("got %v, want 16", got) + } + } + stmt.Close() + stmt, _, err = db.Prepare(` SELECT median(x),