mirror of
https://github.com/ncruces/go-sqlite3.git
synced 2026-01-11 21:49:13 +00:00
Boolean aggregates.
This commit is contained in:
@@ -48,10 +48,8 @@ https://sqlite.org/windowfunctions.html#builtins
|
||||
|
||||
## Boolean aggregates
|
||||
|
||||
- [ ] `ALL(boolean)`
|
||||
- [ ] `ANY(boolean)`
|
||||
- [ ] `EVERY(boolean)`
|
||||
- [ ] `SOME(boolean)`
|
||||
- [X] `EVERY(boolean)`
|
||||
- [X] `SOME(boolean)`
|
||||
|
||||
## Additional aggregates
|
||||
|
||||
|
||||
46
ext/stats/boolean.go
Normal file
46
ext/stats/boolean.go
Normal file
@@ -0,0 +1,46 @@
|
||||
package stats
|
||||
|
||||
import "github.com/ncruces/go-sqlite3"
|
||||
|
||||
const (
|
||||
every = iota
|
||||
some
|
||||
)
|
||||
|
||||
func newBoolean(kind int) func() sqlite3.AggregateFunction {
|
||||
return func() sqlite3.AggregateFunction { return &boolean{kind: kind} }
|
||||
}
|
||||
|
||||
type boolean struct {
|
||||
count int
|
||||
total int
|
||||
kind int
|
||||
}
|
||||
|
||||
func (b *boolean) Value(ctx sqlite3.Context) {
|
||||
if b.kind == every {
|
||||
ctx.ResultBool(b.count == b.total)
|
||||
} else {
|
||||
ctx.ResultBool(b.count > 0)
|
||||
}
|
||||
}
|
||||
|
||||
func (b *boolean) Step(ctx sqlite3.Context, arg ...sqlite3.Value) {
|
||||
if arg[0].Type() == sqlite3.NULL {
|
||||
return
|
||||
}
|
||||
if arg[0].Bool() {
|
||||
b.count++
|
||||
}
|
||||
b.total++
|
||||
}
|
||||
|
||||
func (b *boolean) Inverse(ctx sqlite3.Context, arg ...sqlite3.Value) {
|
||||
if arg[0].Type() == sqlite3.NULL {
|
||||
return
|
||||
}
|
||||
if arg[0].Bool() {
|
||||
b.count--
|
||||
}
|
||||
b.total--
|
||||
}
|
||||
74
ext/stats/boolean_test.go
Normal file
74
ext/stats/boolean_test.go
Normal file
@@ -0,0 +1,74 @@
|
||||
package stats_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/ncruces/go-sqlite3"
|
||||
_ "github.com/ncruces/go-sqlite3/embed"
|
||||
"github.com/ncruces/go-sqlite3/ext/stats"
|
||||
_ "github.com/ncruces/go-sqlite3/internal/testcfg"
|
||||
)
|
||||
|
||||
func TestRegister_boolean(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, err := sqlite3.Open(":memory:")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
stats.Register(db)
|
||||
|
||||
err = db.Exec(`CREATE TABLE data (x)`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = db.Exec(`INSERT INTO data (x) VALUES (4), (7.0), (13), (NULL), (16), (3.14)`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
stmt, _, err := db.Prepare(`
|
||||
SELECT
|
||||
every(x > 0),
|
||||
every(x > 10),
|
||||
some(x > 10),
|
||||
some(x > 20)
|
||||
FROM data`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if stmt.Step() {
|
||||
if got := stmt.ColumnBool(0); got != true {
|
||||
t.Errorf("got %v, want true", got)
|
||||
}
|
||||
if got := stmt.ColumnBool(1); got != false {
|
||||
t.Errorf("got %v, want false", got)
|
||||
}
|
||||
if got := stmt.ColumnBool(2); got != true {
|
||||
t.Errorf("got %v, want true", got)
|
||||
}
|
||||
if got := stmt.ColumnBool(3); got != false {
|
||||
t.Errorf("got %v, want false", got)
|
||||
}
|
||||
}
|
||||
stmt.Close()
|
||||
|
||||
stmt, _, err = db.Prepare(`SELECT every(x > 10) OVER (ROWS 1 PRECEDING) FROM data`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
want := [...]bool{false, false, false, true, true, false}
|
||||
for i := 0; stmt.Step(); i++ {
|
||||
if got := stmt.ColumnBool(0); got != want[i] {
|
||||
t.Errorf("got %v, want %v", got, want[i])
|
||||
}
|
||||
if got := stmt.ColumnType(0); got != sqlite3.INTEGER {
|
||||
t.Errorf("got %v, want INTEGER", got)
|
||||
}
|
||||
}
|
||||
stmt.Close()
|
||||
}
|
||||
@@ -21,6 +21,8 @@
|
||||
// - quantile_disc: discrete quantile
|
||||
// - quantile_cont: continuous quantile
|
||||
// - median: median value
|
||||
// - every: boolean and
|
||||
// - some: boolean or
|
||||
//
|
||||
// These join the [Built-in Aggregate Functions]:
|
||||
// - count: count rows/values
|
||||
@@ -29,9 +31,16 @@
|
||||
// - min: minimum value
|
||||
// - max: maximum value
|
||||
//
|
||||
// And the [Built-in Window Functions]:
|
||||
// - rank: rank of the current row with gaps
|
||||
// - dense_rank: rank of the current row without gaps
|
||||
// - percent_rank: relative rank of the row
|
||||
// - cume_dist: cumulative distribution
|
||||
//
|
||||
// See: [ANSI SQL Aggregate Functions], [DuckDB Aggregate Functions]
|
||||
//
|
||||
// [Built-in Aggregate Functions]: https://sqlite.org/lang_aggfunc.html
|
||||
// [Built-in Window Functions]: https://sqlite.org/windowfunctions.html#builtins
|
||||
// [ANSI SQL Aggregate Functions]: https://www.oreilly.com/library/view/sql-in-a/9780596155322/ch04s02.html
|
||||
// [DuckDB Aggregate Functions]: https://duckdb.org/docs/sql/aggregates.html
|
||||
package stats
|
||||
@@ -61,6 +70,8 @@ func Register(db *sqlite3.Conn) {
|
||||
db.CreateWindowFunction("median", 1, flags, newQuantile(median))
|
||||
db.CreateWindowFunction("quantile_cont", 2, flags, newQuantile(quant_cont))
|
||||
db.CreateWindowFunction("quantile_disc", 2, flags, newQuantile(quant_disc))
|
||||
db.CreateWindowFunction("every", 1, flags, newBoolean(every))
|
||||
db.CreateWindowFunction("some", 1, flags, newBoolean(some))
|
||||
}
|
||||
|
||||
const (
|
||||
|
||||
@@ -40,8 +40,6 @@ func TestRegister_variance(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
if stmt.Step() {
|
||||
if got := stmt.ColumnFloat(0); got != 40 {
|
||||
t.Errorf("got %v, want 40", got)
|
||||
@@ -62,24 +60,23 @@ func TestRegister_variance(t *testing.T) {
|
||||
t.Errorf("got %v, want √22.5", got)
|
||||
}
|
||||
}
|
||||
stmt.Close()
|
||||
|
||||
{
|
||||
stmt, _, err := db.Prepare(`SELECT var_samp(x) OVER (ROWS 1 PRECEDING) FROM data`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
stmt, _, err = db.Prepare(`SELECT var_samp(x) OVER (ROWS 1 PRECEDING) FROM data`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
want := [...]float64{0, 4.5, 18, 0, 0}
|
||||
for i := 0; stmt.Step(); i++ {
|
||||
if got := stmt.ColumnFloat(0); got != want[i] {
|
||||
t.Errorf("got %v, want %v", got, want[i])
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
want := [...]float64{0, 4.5, 18, 0, 0}
|
||||
for i := 0; stmt.Step(); i++ {
|
||||
if got := stmt.ColumnFloat(0); got != want[i] {
|
||||
t.Errorf("got %v, want %v", got, want[i])
|
||||
}
|
||||
if got := stmt.ColumnType(0); (got == sqlite3.FLOAT) != (want[i] != 0) {
|
||||
t.Errorf("got %v, want %v", got, want[i])
|
||||
}
|
||||
if got := stmt.ColumnType(0); (got == sqlite3.FLOAT) != (want[i] != 0) {
|
||||
t.Errorf("got %v, want %v", got, want[i])
|
||||
}
|
||||
}
|
||||
stmt.Close()
|
||||
}
|
||||
|
||||
func TestRegister_covariance(t *testing.T) {
|
||||
@@ -113,8 +110,6 @@ func TestRegister_covariance(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
if stmt.Step() {
|
||||
if got := stmt.ColumnFloat(0); got != 0.9881049293224639 {
|
||||
t.Errorf("got %v, want 0.9881049293224639", got)
|
||||
@@ -159,24 +154,23 @@ func TestRegister_covariance(t *testing.T) {
|
||||
t.Errorf("got %v, want 5", got)
|
||||
}
|
||||
}
|
||||
stmt.Close()
|
||||
|
||||
{
|
||||
stmt, _, err := db.Prepare(`SELECT covar_samp(y, x) OVER (ROWS 1 PRECEDING) FROM data`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
stmt, _, err = db.Prepare(`SELECT covar_samp(y, x) OVER (ROWS 1 PRECEDING) FROM data`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
want := [...]float64{0, 10, 30, 75, 22.5}
|
||||
for i := 0; stmt.Step(); i++ {
|
||||
if got := stmt.ColumnFloat(0); got != want[i] {
|
||||
t.Errorf("got %v, want %v", got, want[i])
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
want := [...]float64{0, 10, 30, 75, 22.5}
|
||||
for i := 0; stmt.Step(); i++ {
|
||||
if got := stmt.ColumnFloat(0); got != want[i] {
|
||||
t.Errorf("got %v, want %v", got, want[i])
|
||||
}
|
||||
if got := stmt.ColumnType(0); (got == sqlite3.FLOAT) != (want[i] != 0) {
|
||||
t.Errorf("got %v, want %v", got, want[i])
|
||||
}
|
||||
if got := stmt.ColumnType(0); (got == sqlite3.FLOAT) != (want[i] != 0) {
|
||||
t.Errorf("got %v, want %v", got, want[i])
|
||||
}
|
||||
}
|
||||
stmt.Close()
|
||||
}
|
||||
|
||||
func Benchmark_average(b *testing.B) {
|
||||
|
||||
4
stmt.go
4
stmt.go
@@ -441,12 +441,12 @@ func (s *Stmt) ColumnOriginName(col int) string {
|
||||
// ColumnBool returns the value of the result column as a bool.
|
||||
// The leftmost column of the result set has the index 0.
|
||||
// SQLite does not have a separate boolean storage class.
|
||||
// Instead, boolean values are retrieved as integers,
|
||||
// Instead, boolean values are retrieved as numbers,
|
||||
// with 0 converted to false and any other value to true.
|
||||
//
|
||||
// https://sqlite.org/c3ref/column_blob.html
|
||||
func (s *Stmt) ColumnBool(col int) bool {
|
||||
return s.ColumnInt64(col) != 0
|
||||
return s.ColumnFloat(col) != 0
|
||||
}
|
||||
|
||||
// ColumnInt returns the value of the result column as an int.
|
||||
|
||||
4
value.go
4
value.go
@@ -68,12 +68,12 @@ func (v Value) NumericType() Datatype {
|
||||
|
||||
// Bool returns the value as a bool.
|
||||
// SQLite does not have a separate boolean storage class.
|
||||
// Instead, boolean values are retrieved as integers,
|
||||
// Instead, boolean values are retrieved as numbers,
|
||||
// with 0 converted to false and any other value to true.
|
||||
//
|
||||
// https://sqlite.org/c3ref/value_blob.html
|
||||
func (v Value) Bool() bool {
|
||||
return v.Int64() != 0
|
||||
return v.Float() != 0
|
||||
}
|
||||
|
||||
// Int returns the value as an int.
|
||||
|
||||
Reference in New Issue
Block a user