From c78d00dca0f0075d83b83e558c4122051999cf95 Mon Sep 17 00:00:00 2001 From: Nuno Cruces Date: Wed, 16 Oct 2024 14:00:22 +0100 Subject: [PATCH] Better percentile compatibility. --- const.go | 1 + ext/stats/percentile.go | 18 +++++++++++------- ext/stats/percentile_test.go | 18 +++++++++++++++--- ext/stats/stats.go | 8 +++++--- 4 files changed, 32 insertions(+), 13 deletions(-) diff --git a/const.go b/const.go index e4c7d72..6b9665f 100644 --- a/const.go +++ b/const.go @@ -177,6 +177,7 @@ const ( DETERMINISTIC FunctionFlag = 0x000000800 DIRECTONLY FunctionFlag = 0x000080000 INNOCUOUS FunctionFlag = 0x000200000 + SELFORDER1 FunctionFlag = 0x002000000 // SUBTYPE FunctionFlag = 0x000100000 // RESULT_SUBTYPE FunctionFlag = 0x001000000 ) diff --git a/ext/stats/percentile.go b/ext/stats/percentile.go index c60ca22..68646f8 100644 --- a/ext/stats/percentile.go +++ b/ext/stats/percentile.go @@ -13,6 +13,7 @@ import ( const ( median = iota + percentile_100 percentile_cont percentile_disc ) @@ -54,13 +55,13 @@ func (q *percentile) Value(ctx sqlite3.Context) { floats []float64 ) if q.kind == median { - float, err = getPercentile(q.nums, 0.5, false) + float, err = getPercentile(q.nums, 0.5, q.kind) ctx.ResultFloat(float) } else if err = json.Unmarshal(q.arg1, &float); err == nil { - float, err = getPercentile(q.nums, float, q.kind == percentile_disc) + float, err = getPercentile(q.nums, float, q.kind) ctx.ResultFloat(float) } else if err = json.Unmarshal(q.arg1, &floats); err == nil { - err = getPercentiles(q.nums, floats, q.kind == percentile_disc) + err = getPercentiles(q.nums, floats, q.kind) ctx.ResultJSON(floats) } if err != nil { @@ -68,7 +69,10 @@ func (q *percentile) Value(ctx sqlite3.Context) { } } -func getPercentile(nums []float64, pos float64, disc bool) (float64, error) { +func getPercentile(nums []float64, pos float64, kind int) (float64, error) { + if kind == percentile_100 { + pos = pos / 100 + } if pos < 0 || pos > 1 { return 0, util.ErrorString("invalid pos") } @@ -76,7 +80,7 @@ func getPercentile(nums []float64, pos float64, disc bool) (float64, error) { i, f := math.Modf(pos * float64(len(nums)-1)) m0 := quick.Select(nums, int(i)) - if f == 0 || disc { + if f == 0 || kind == percentile_disc { return m0, nil } @@ -84,9 +88,9 @@ func getPercentile(nums []float64, pos float64, disc bool) (float64, error) { return math.FMA(f, m1, math.FMA(-f, m0, m0)), nil } -func getPercentiles(nums []float64, pos []float64, disc bool) error { +func getPercentiles(nums []float64, pos []float64, kind int) error { for i := range pos { - v, err := getPercentile(nums, pos[i], disc) + v, err := getPercentile(nums, pos[i], kind) if err != nil { return err } diff --git a/ext/stats/percentile_test.go b/ext/stats/percentile_test.go index 7abf40c..41f2e8b 100644 --- a/ext/stats/percentile_test.go +++ b/ext/stats/percentile_test.go @@ -31,6 +31,7 @@ func TestRegister_percentile(t *testing.T) { stmt, _, err := db.Prepare(` SELECT median(x), + percentile(x, 50), percentile_disc(x, 0.5), percentile_cont(x, '[0.25, 0.5, 0.75]') FROM data`) @@ -41,11 +42,14 @@ func TestRegister_percentile(t *testing.T) { if got := stmt.ColumnFloat(0); got != 10 { t.Errorf("got %v, want 10", got) } - if got := stmt.ColumnFloat(1); got != 7 { + if got := stmt.ColumnFloat(1); got != 10 { + t.Errorf("got %v, want 10", got) + } + if got := stmt.ColumnFloat(2); got != 7 { t.Errorf("got %v, want 7", got) } var got []float64 - if err := stmt.ColumnJSON(2, &got); err != nil { + if err := stmt.ColumnJSON(3, &got); err != nil { t.Error(err) } if !slices.Equal(got, []float64{6.25, 10, 13.75}) { @@ -57,6 +61,7 @@ func TestRegister_percentile(t *testing.T) { stmt, _, err = db.Prepare(` SELECT median(x), + percentile(x, 50), percentile_disc(x, 0.5), percentile_cont(x, '[0.25, 0.5, 0.75]') FROM data @@ -71,8 +76,11 @@ func TestRegister_percentile(t *testing.T) { if got := stmt.ColumnFloat(1); got != 4 { t.Errorf("got %v, want 4", got) } + if got := stmt.ColumnFloat(2); got != 4 { + t.Errorf("got %v, want 4", got) + } var got []float64 - if err := stmt.ColumnJSON(2, &got); err != nil { + if err := stmt.ColumnJSON(3, &got); err != nil { t.Error(err) } if !slices.Equal(got, []float64{4, 4, 4}) { @@ -84,6 +92,7 @@ func TestRegister_percentile(t *testing.T) { stmt, _, err = db.Prepare(` SELECT median(x), + percentile(x, 50), percentile_disc(x, 0.5), percentile_cont(x, '[0.25, 0.5, 0.75]') FROM data @@ -101,6 +110,9 @@ func TestRegister_percentile(t *testing.T) { if got := stmt.ColumnType(2); got != sqlite3.NULL { t.Error("want NULL") } + if got := stmt.ColumnType(3); got != sqlite3.NULL { + t.Error("want NULL") + } } stmt.Close() diff --git a/ext/stats/stats.go b/ext/stats/stats.go index cddbffa..4edfee3 100644 --- a/ext/stats/stats.go +++ b/ext/stats/stats.go @@ -53,6 +53,7 @@ import ( // Register registers statistics functions. func Register(db *sqlite3.Conn) error { const flags = sqlite3.DETERMINISTIC | sqlite3.INNOCUOUS + const order = sqlite3.SELFORDER1 | flags return errors.Join( db.CreateWindowFunction("var_pop", 1, flags, newVariance(var_pop)), db.CreateWindowFunction("var_samp", 1, flags, newVariance(var_samp)), @@ -71,9 +72,10 @@ func Register(db *sqlite3.Conn) error { db.CreateWindowFunction("regr_intercept", 2, flags, newCovariance(regr_intercept)), db.CreateWindowFunction("regr_count", 2, flags, newCovariance(regr_count)), db.CreateWindowFunction("regr_json", 2, flags, newCovariance(regr_json)), - db.CreateWindowFunction("median", 1, flags, newPercentile(median)), - db.CreateWindowFunction("percentile_cont", 2, flags, newPercentile(percentile_cont)), - db.CreateWindowFunction("percentile_disc", 2, flags, newPercentile(percentile_disc)), + db.CreateWindowFunction("median", 1, order, newPercentile(median)), + db.CreateWindowFunction("percentile", 2, order, newPercentile(percentile_100)), + db.CreateWindowFunction("percentile_cont", 2, order, newPercentile(percentile_cont)), + db.CreateWindowFunction("percentile_disc", 2, order, newPercentile(percentile_disc)), db.CreateWindowFunction("every", 1, flags, newBoolean(every)), db.CreateWindowFunction("some", 1, flags, newBoolean(some))) }