From fa7516ce300e4b1e472442392aa4783bf9528cbe Mon Sep 17 00:00:00 2001 From: Nuno Cruces Date: Fri, 31 May 2024 17:36:16 +0100 Subject: [PATCH] Quantiles. --- ext/stats/quantile.go | 59 ++++++++++++++++++++++++++++++++++++++ ext/stats/quantile_test.go | 55 +++++++++++++++++++++++++++++++++++ ext/stats/stats.go | 8 ++++++ go.mod | 2 ++ go.sum | 2 ++ go.work.sum | 4 +++ 6 files changed, 130 insertions(+) create mode 100644 ext/stats/quantile.go create mode 100644 ext/stats/quantile_test.go create mode 100644 go.work.sum diff --git a/ext/stats/quantile.go b/ext/stats/quantile.go new file mode 100644 index 0000000..4b7481c --- /dev/null +++ b/ext/stats/quantile.go @@ -0,0 +1,59 @@ +package stats + +import ( + "math" + "slices" + + "github.com/ncruces/go-sqlite3" + "github.com/ncruces/go-sqlite3/internal/util" + "github.com/ncruces/sort/quick" +) + +const ( + median = iota + quant_cont + quant_disc +) + +func newQuantile(kind int) func() sqlite3.AggregateFunction { + return func() sqlite3.AggregateFunction { return &quantile{kind: kind} } +} + +type quantile struct { + kind int + pos float64 + list []float64 +} + +func (q *quantile) Step(ctx sqlite3.Context, arg ...sqlite3.Value) { + if a := arg[0]; a.NumericType() != sqlite3.NULL { + q.list = append(q.list, a.Float()) + } + if q.kind != median { + q.pos = arg[1].Float() + } +} + +func (q *quantile) Value(ctx sqlite3.Context) { + if q.list == nil { + return + } + if q.kind == median { + q.pos = 0.5 + } + if q.pos < 0 || q.pos > 1 { + ctx.ResultError(util.ErrorString("quantile: invalid pos")) + return + } + + i, f := math.Modf(q.pos * float64(len(q.list)-1)) + m0 := quick.Select(q.list, int(i)) + + if q.kind == quant_disc { + ctx.ResultFloat(m0) + return + } + + m1 := slices.Min(q.list[int(i)+1:]) + ctx.ResultFloat(math.FMA(f, m1, -math.FMA(f, m0, -m0))) +} diff --git a/ext/stats/quantile_test.go b/ext/stats/quantile_test.go new file mode 100644 index 0000000..ded99aa --- /dev/null +++ b/ext/stats/quantile_test.go @@ -0,0 +1,55 @@ +package stats_test + +import ( + "testing" + + "github.com/ncruces/go-sqlite3" + _ "github.com/ncruces/go-sqlite3/embed" + "github.com/ncruces/go-sqlite3/ext/stats" + _ "github.com/ncruces/go-sqlite3/tests/testcfg" +) + +func TestRegister_quantile(t *testing.T) { + t.Parallel() + + db, err := sqlite3.Open(":memory:") + if err != nil { + t.Fatal(err) + } + defer db.Close() + + stats.Register(db) + + err = db.Exec(`CREATE TABLE data (x)`) + if err != nil { + t.Fatal(err) + } + + err = db.Exec(`INSERT INTO data (x) VALUES (4), (7.0), ('13'), (NULL), (16)`) + if err != nil { + t.Fatal(err) + } + + stmt, _, err := db.Prepare(` + SELECT + median(x), + quantile_disc(x, 0.5), + quantile_cont(x, 0.3) + FROM data`) + if err != nil { + t.Fatal(err) + } + defer stmt.Close() + + if stmt.Step() { + if got := stmt.ColumnFloat(0); got != 10 { + t.Errorf("got %v, want 10", got) + } + if got := stmt.ColumnFloat(1); got != 7 { + t.Errorf("got %v, want 7", got) + } + if got := stmt.ColumnFloat(2); got != 6.699999999999999 { + t.Errorf("got %v, want 6.7", got) + } + } +} diff --git a/ext/stats/stats.go b/ext/stats/stats.go index c1d0e5a..9096fc1 100644 --- a/ext/stats/stats.go +++ b/ext/stats/stats.go @@ -18,6 +18,9 @@ // - regr_slope: slope of the least-squares-fit linear equation // - regr_intercept: y-intercept of the least-squares-fit linear equation // - regr_json: all regr stats in a JSON object +// - median: median value +// - quantile_cont: continuous quantile +// - quantile_disc: discrete quantile // // These join the [Built-in Aggregate Functions]: // - count: count rows/values @@ -27,9 +30,11 @@ // - max: maximum value // // See: [ANSI SQL Aggregate Functions] +// See: [DuckDB Aggregate Functions] // // [Built-in Aggregate Functions]: https://sqlite.org/lang_aggfunc.html // [ANSI SQL Aggregate Functions]: https://www.oreilly.com/library/view/sql-in-a/9780596155322/ch04s02.html +// [DuckDB Aggregate Functions]: https://duckdb.org/docs/sql/aggregates.html package stats import "github.com/ncruces/go-sqlite3" @@ -54,6 +59,9 @@ func Register(db *sqlite3.Conn) { db.CreateWindowFunction("regr_intercept", 2, flags, newCovariance(regr_intercept)) db.CreateWindowFunction("regr_count", 2, flags, newCovariance(regr_count)) db.CreateWindowFunction("regr_json", 2, flags, newCovariance(regr_json)) + db.CreateWindowFunction("median", 1, flags, newQuantile(median)) + db.CreateWindowFunction("quantile_cont", 2, flags, newQuantile(quant_cont)) + db.CreateWindowFunction("quantile_disc", 2, flags, newQuantile(quant_disc)) } const ( diff --git a/go.mod b/go.mod index 4628622..622aa4d 100644 --- a/go.mod +++ b/go.mod @@ -13,4 +13,6 @@ require ( lukechampine.com/adiantum v1.1.1 ) +require github.com/ncruces/sort v0.1.2 + retract v0.4.0 // tagged from the wrong branch diff --git a/go.sum b/go.sum index a40d156..e4df547 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,7 @@ github.com/ncruces/julianday v1.0.0 h1:fH0OKwa7NWvniGQtxdJRxAgkBMolni2BjDHaWTxqt7M= github.com/ncruces/julianday v1.0.0/go.mod h1:Dusn2KvZrrovOMJuOt0TNXL6tB7U2E8kvza5fFc9G7g= +github.com/ncruces/sort v0.1.2 h1:zKQ9CA4fpHPF6xsUhRTfi5EEryspuBpe/QA4VWQOV1U= +github.com/ncruces/sort v0.1.2/go.mod h1:vEJUTBJtebIuCMmXD18GKo5GJGhsay+xZFOoBEIXFmE= github.com/psanford/httpreadat v0.1.0 h1:VleW1HS2zO7/4c7c7zNl33fO6oYACSagjJIyMIwZLUE= github.com/psanford/httpreadat v0.1.0/go.mod h1:Zg7P+TlBm3bYbyHTKv/EdtSJZn3qwbPwpfZ/I9GKCRE= github.com/tetratelabs/wazero v1.7.2 h1:1+z5nXJNwMLPAWaTePFi49SSTL0IMx/i3Fg8Yc25GDc= diff --git a/go.work.sum b/go.work.sum new file mode 100644 index 0000000..0349376 --- /dev/null +++ b/go.work.sum @@ -0,0 +1,4 @@ +golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= +golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44= +golang.org/x/term v0.20.0/go.mod h1:8UkIAJTvZgivsXaD6/pH6U9ecQzZ45awqEOzuCvwpFY= +golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU=