From dbf764aaf4875c988f9b49c210d525ee57363379 Mon Sep 17 00:00:00 2001 From: Nuno Cruces Date: Thu, 6 Jun 2024 16:41:20 +0100 Subject: [PATCH] Boolean aggregates. --- ext/stats/TODO.md | 6 ++-- ext/stats/boolean.go | 46 ++++++++++++++++++++++++ ext/stats/boolean_test.go | 74 +++++++++++++++++++++++++++++++++++++++ ext/stats/stats.go | 11 ++++++ ext/stats/stats_test.go | 58 ++++++++++++++---------------- stmt.go | 4 +-- value.go | 4 +-- 7 files changed, 163 insertions(+), 40 deletions(-) create mode 100644 ext/stats/boolean.go create mode 100644 ext/stats/boolean_test.go diff --git a/ext/stats/TODO.md b/ext/stats/TODO.md index 9c9c503..f7e2a3c 100644 --- a/ext/stats/TODO.md +++ b/ext/stats/TODO.md @@ -48,10 +48,8 @@ https://sqlite.org/windowfunctions.html#builtins ## Boolean aggregates -- [ ] `ALL(boolean)` -- [ ] `ANY(boolean)` -- [ ] `EVERY(boolean)` -- [ ] `SOME(boolean)` +- [X] `EVERY(boolean)` +- [X] `SOME(boolean)` ## Additional aggregates diff --git a/ext/stats/boolean.go b/ext/stats/boolean.go new file mode 100644 index 0000000..ba0ed69 --- /dev/null +++ b/ext/stats/boolean.go @@ -0,0 +1,46 @@ +package stats + +import "github.com/ncruces/go-sqlite3" + +const ( + every = iota + some +) + +func newBoolean(kind int) func() sqlite3.AggregateFunction { + return func() sqlite3.AggregateFunction { return &boolean{kind: kind} } +} + +type boolean struct { + count int + total int + kind int +} + +func (b *boolean) Value(ctx sqlite3.Context) { + if b.kind == every { + ctx.ResultBool(b.count == b.total) + } else { + ctx.ResultBool(b.count > 0) + } +} + +func (b *boolean) Step(ctx sqlite3.Context, arg ...sqlite3.Value) { + if arg[0].Type() == sqlite3.NULL { + return + } + if arg[0].Bool() { + b.count++ + } + b.total++ +} + +func (b *boolean) Inverse(ctx sqlite3.Context, arg ...sqlite3.Value) { + if arg[0].Type() == sqlite3.NULL { + return + } + if arg[0].Bool() { + b.count-- + } + b.total-- +} diff --git a/ext/stats/boolean_test.go b/ext/stats/boolean_test.go new file mode 100644 index 0000000..31959a1 --- /dev/null +++ b/ext/stats/boolean_test.go @@ -0,0 +1,74 @@ +package stats_test + +import ( + "testing" + + "github.com/ncruces/go-sqlite3" + _ "github.com/ncruces/go-sqlite3/embed" + "github.com/ncruces/go-sqlite3/ext/stats" + _ "github.com/ncruces/go-sqlite3/internal/testcfg" +) + +func TestRegister_boolean(t *testing.T) { + t.Parallel() + + db, err := sqlite3.Open(":memory:") + if err != nil { + t.Fatal(err) + } + defer db.Close() + + stats.Register(db) + + err = db.Exec(`CREATE TABLE data (x)`) + if err != nil { + t.Fatal(err) + } + + err = db.Exec(`INSERT INTO data (x) VALUES (4), (7.0), (13), (NULL), (16), (3.14)`) + if err != nil { + t.Fatal(err) + } + + stmt, _, err := db.Prepare(` + SELECT + every(x > 0), + every(x > 10), + some(x > 10), + some(x > 20) + FROM data`) + if err != nil { + t.Fatal(err) + } + if stmt.Step() { + if got := stmt.ColumnBool(0); got != true { + t.Errorf("got %v, want true", got) + } + if got := stmt.ColumnBool(1); got != false { + t.Errorf("got %v, want false", got) + } + if got := stmt.ColumnBool(2); got != true { + t.Errorf("got %v, want true", got) + } + if got := stmt.ColumnBool(3); got != false { + t.Errorf("got %v, want false", got) + } + } + stmt.Close() + + stmt, _, err = db.Prepare(`SELECT every(x > 10) OVER (ROWS 1 PRECEDING) FROM data`) + if err != nil { + t.Fatal(err) + } + + want := [...]bool{false, false, false, true, true, false} + for i := 0; stmt.Step(); i++ { + if got := stmt.ColumnBool(0); got != want[i] { + t.Errorf("got %v, want %v", got, want[i]) + } + if got := stmt.ColumnType(0); got != sqlite3.INTEGER { + t.Errorf("got %v, want INTEGER", got) + } + } + stmt.Close() +} diff --git a/ext/stats/stats.go b/ext/stats/stats.go index d86684f..d32b323 100644 --- a/ext/stats/stats.go +++ b/ext/stats/stats.go @@ -21,6 +21,8 @@ // - quantile_disc: discrete quantile // - quantile_cont: continuous quantile // - median: median value +// - every: boolean and +// - some: boolean or // // These join the [Built-in Aggregate Functions]: // - count: count rows/values @@ -29,9 +31,16 @@ // - min: minimum value // - max: maximum value // +// And the [Built-in Window Functions]: +// - rank: rank of the current row with gaps +// - dense_rank: rank of the current row without gaps +// - percent_rank: relative rank of the row +// - cume_dist: cumulative distribution +// // See: [ANSI SQL Aggregate Functions], [DuckDB Aggregate Functions] // // [Built-in Aggregate Functions]: https://sqlite.org/lang_aggfunc.html +// [Built-in Window Functions]: https://sqlite.org/windowfunctions.html#builtins // [ANSI SQL Aggregate Functions]: https://www.oreilly.com/library/view/sql-in-a/9780596155322/ch04s02.html // [DuckDB Aggregate Functions]: https://duckdb.org/docs/sql/aggregates.html package stats @@ -61,6 +70,8 @@ func Register(db *sqlite3.Conn) { db.CreateWindowFunction("median", 1, flags, newQuantile(median)) db.CreateWindowFunction("quantile_cont", 2, flags, newQuantile(quant_cont)) db.CreateWindowFunction("quantile_disc", 2, flags, newQuantile(quant_disc)) + db.CreateWindowFunction("every", 1, flags, newBoolean(every)) + db.CreateWindowFunction("some", 1, flags, newBoolean(some)) } const ( diff --git a/ext/stats/stats_test.go b/ext/stats/stats_test.go index 4a5343a..33bdd41 100644 --- a/ext/stats/stats_test.go +++ b/ext/stats/stats_test.go @@ -40,8 +40,6 @@ func TestRegister_variance(t *testing.T) { if err != nil { t.Fatal(err) } - defer stmt.Close() - if stmt.Step() { if got := stmt.ColumnFloat(0); got != 40 { t.Errorf("got %v, want 40", got) @@ -62,24 +60,23 @@ func TestRegister_variance(t *testing.T) { t.Errorf("got %v, want √22.5", got) } } + stmt.Close() - { - stmt, _, err := db.Prepare(`SELECT var_samp(x) OVER (ROWS 1 PRECEDING) FROM data`) - if err != nil { - t.Fatal(err) + stmt, _, err = db.Prepare(`SELECT var_samp(x) OVER (ROWS 1 PRECEDING) FROM data`) + if err != nil { + t.Fatal(err) + } + + want := [...]float64{0, 4.5, 18, 0, 0} + for i := 0; stmt.Step(); i++ { + if got := stmt.ColumnFloat(0); got != want[i] { + t.Errorf("got %v, want %v", got, want[i]) } - defer stmt.Close() - - want := [...]float64{0, 4.5, 18, 0, 0} - for i := 0; stmt.Step(); i++ { - if got := stmt.ColumnFloat(0); got != want[i] { - t.Errorf("got %v, want %v", got, want[i]) - } - if got := stmt.ColumnType(0); (got == sqlite3.FLOAT) != (want[i] != 0) { - t.Errorf("got %v, want %v", got, want[i]) - } + if got := stmt.ColumnType(0); (got == sqlite3.FLOAT) != (want[i] != 0) { + t.Errorf("got %v, want %v", got, want[i]) } } + stmt.Close() } func TestRegister_covariance(t *testing.T) { @@ -113,8 +110,6 @@ func TestRegister_covariance(t *testing.T) { if err != nil { t.Fatal(err) } - defer stmt.Close() - if stmt.Step() { if got := stmt.ColumnFloat(0); got != 0.9881049293224639 { t.Errorf("got %v, want 0.9881049293224639", got) @@ -159,24 +154,23 @@ func TestRegister_covariance(t *testing.T) { t.Errorf("got %v, want 5", got) } } + stmt.Close() - { - stmt, _, err := db.Prepare(`SELECT covar_samp(y, x) OVER (ROWS 1 PRECEDING) FROM data`) - if err != nil { - t.Fatal(err) + stmt, _, err = db.Prepare(`SELECT covar_samp(y, x) OVER (ROWS 1 PRECEDING) FROM data`) + if err != nil { + t.Fatal(err) + } + + want := [...]float64{0, 10, 30, 75, 22.5} + for i := 0; stmt.Step(); i++ { + if got := stmt.ColumnFloat(0); got != want[i] { + t.Errorf("got %v, want %v", got, want[i]) } - defer stmt.Close() - - want := [...]float64{0, 10, 30, 75, 22.5} - for i := 0; stmt.Step(); i++ { - if got := stmt.ColumnFloat(0); got != want[i] { - t.Errorf("got %v, want %v", got, want[i]) - } - if got := stmt.ColumnType(0); (got == sqlite3.FLOAT) != (want[i] != 0) { - t.Errorf("got %v, want %v", got, want[i]) - } + if got := stmt.ColumnType(0); (got == sqlite3.FLOAT) != (want[i] != 0) { + t.Errorf("got %v, want %v", got, want[i]) } } + stmt.Close() } func Benchmark_average(b *testing.B) { diff --git a/stmt.go b/stmt.go index 63c2085..ac40e38 100644 --- a/stmt.go +++ b/stmt.go @@ -441,12 +441,12 @@ func (s *Stmt) ColumnOriginName(col int) string { // ColumnBool returns the value of the result column as a bool. // The leftmost column of the result set has the index 0. // SQLite does not have a separate boolean storage class. -// Instead, boolean values are retrieved as integers, +// Instead, boolean values are retrieved as numbers, // with 0 converted to false and any other value to true. // // https://sqlite.org/c3ref/column_blob.html func (s *Stmt) ColumnBool(col int) bool { - return s.ColumnInt64(col) != 0 + return s.ColumnFloat(col) != 0 } // ColumnInt returns the value of the result column as an int. diff --git a/value.go b/value.go index 61d3cbf..d0edf21 100644 --- a/value.go +++ b/value.go @@ -68,12 +68,12 @@ func (v Value) NumericType() Datatype { // Bool returns the value as a bool. // SQLite does not have a separate boolean storage class. -// Instead, boolean values are retrieved as integers, +// Instead, boolean values are retrieved as numbers, // with 0 converted to false and any other value to true. // // https://sqlite.org/c3ref/value_blob.html func (v Value) Bool() bool { - return v.Int64() != 0 + return v.Float() != 0 } // Int returns the value as an int.