From 8a3d45493562c57327d140768a3a371323825242 Mon Sep 17 00:00:00 2001 From: Nuno Cruces Date: Sun, 2 Jun 2024 10:33:06 +0100 Subject: [PATCH] More tests. --- ext/stats/TODO.md | 10 +++++- ext/stats/quantile.go | 4 +-- ext/stats/quantile_test.go | 68 +++++++++++++++++++++++++++++++++++--- ext/stats/stats.go | 7 ++-- 4 files changed, 77 insertions(+), 12 deletions(-) diff --git a/ext/stats/TODO.md b/ext/stats/TODO.md index b8c80f9..f48827b 100644 --- a/ext/stats/TODO.md +++ b/ext/stats/TODO.md @@ -44,4 +44,12 @@ https://sqlite.org/lang_aggfunc.html - [ ] `PERCENTILE_CONT(percentile) OVER window` - [ ] `PERCENTILE_DISC(percentile) OVER window` -https://sqlite.org/windowfunctions.html#builtins \ No newline at end of file +https://sqlite.org/windowfunctions.html#builtins + +## Additional aggregates + +- [X] `MEDIAN(expression)` +- [X] `QUANTILE_CONT(expression, quantile)` +- [X] `QUANTILE_DISC(expression, quantile)` + +https://duckdb.org/docs/sql/aggregates.html \ No newline at end of file diff --git a/ext/stats/quantile.go b/ext/stats/quantile.go index 4b7481c..77d4217 100644 --- a/ext/stats/quantile.go +++ b/ext/stats/quantile.go @@ -35,7 +35,7 @@ func (q *quantile) Step(ctx sqlite3.Context, arg ...sqlite3.Value) { } func (q *quantile) Value(ctx sqlite3.Context) { - if q.list == nil { + if len(q.list) == 0 { return } if q.kind == median { @@ -49,7 +49,7 @@ func (q *quantile) Value(ctx sqlite3.Context) { i, f := math.Modf(q.pos * float64(len(q.list)-1)) m0 := quick.Select(q.list, int(i)) - if q.kind == quant_disc { + if f == 0 || q.kind == quant_disc { ctx.ResultFloat(m0) return } diff --git a/ext/stats/quantile_test.go b/ext/stats/quantile_test.go index ded99aa..8870b5f 100644 --- a/ext/stats/quantile_test.go +++ b/ext/stats/quantile_test.go @@ -34,13 +34,11 @@ func TestRegister_quantile(t *testing.T) { SELECT median(x), quantile_disc(x, 0.5), - quantile_cont(x, 0.3) + quantile_cont(x, 0.25) FROM data`) if err != nil { t.Fatal(err) } - defer stmt.Close() - if stmt.Step() { if got := stmt.ColumnFloat(0); got != 10 { t.Errorf("got %v, want 10", got) @@ -48,8 +46,68 @@ func TestRegister_quantile(t *testing.T) { if got := stmt.ColumnFloat(1); got != 7 { t.Errorf("got %v, want 7", got) } - if got := stmt.ColumnFloat(2); got != 6.699999999999999 { - t.Errorf("got %v, want 6.7", got) + if got := stmt.ColumnFloat(2); got != 6.25 { + t.Errorf("got %v, want 6.25", got) } } + stmt.Close() + + stmt, _, err = db.Prepare(` + SELECT + median(x), + quantile_disc(x, 0.5), + quantile_cont(x, 0.25) + FROM data + WHERE x < 5`) + if err != nil { + t.Fatal(err) + } + if stmt.Step() { + if got := stmt.ColumnFloat(0); got != 4 { + t.Errorf("got %v, want 4", got) + } + if got := stmt.ColumnFloat(1); got != 4 { + t.Errorf("got %v, want 4", got) + } + if got := stmt.ColumnFloat(2); got != 4 { + t.Errorf("got %v, want 4", got) + } + } + stmt.Close() + + stmt, _, err = db.Prepare(` + SELECT + median(x), + quantile_disc(x, 0.5), + quantile_cont(x, 0.25) + FROM data + WHERE x < 0`) + if err != nil { + t.Fatal(err) + } + if stmt.Step() { + if got := stmt.ColumnType(0); got != sqlite3.NULL { + t.Error("want NULL") + } + if got := stmt.ColumnType(1); got != sqlite3.NULL { + t.Error("want NULL") + } + if got := stmt.ColumnType(2); got != sqlite3.NULL { + t.Error("want NULL") + } + } + stmt.Close() + + stmt, _, err = db.Prepare(` + SELECT + quantile_disc(x, -2), + quantile_cont(x, +2) + FROM data`) + if err != nil { + t.Fatal(err) + } + if stmt.Step() { + t.Fatal("want error") + } + stmt.Close() } diff --git a/ext/stats/stats.go b/ext/stats/stats.go index 9096fc1..d86684f 100644 --- a/ext/stats/stats.go +++ b/ext/stats/stats.go @@ -18,9 +18,9 @@ // - regr_slope: slope of the least-squares-fit linear equation // - regr_intercept: y-intercept of the least-squares-fit linear equation // - regr_json: all regr stats in a JSON object -// - median: median value -// - quantile_cont: continuous quantile // - quantile_disc: discrete quantile +// - quantile_cont: continuous quantile +// - median: median value // // These join the [Built-in Aggregate Functions]: // - count: count rows/values @@ -29,8 +29,7 @@ // - min: minimum value // - max: maximum value // -// See: [ANSI SQL Aggregate Functions] -// See: [DuckDB Aggregate Functions] +// See: [ANSI SQL Aggregate Functions], [DuckDB Aggregate Functions] // // [Built-in Aggregate Functions]: https://sqlite.org/lang_aggfunc.html // [ANSI SQL Aggregate Functions]: https://www.oreilly.com/library/view/sql-in-a/9780596155322/ch04s02.html