From 312d3b58f21c9a97c55f009b3b5ae342f900c009 Mon Sep 17 00:00:00 2001
From: Nuno Cruces
Date: Thu, 31 Aug 2023 16:30:52 +0100
Subject: [PATCH] Statistics functions.

---
 ext/stats/stats.go        | 72 ++++++++++++++++++++++++++++++++
 ext/stats/stats_test.go   | 79 +++++++++++++++++++++++++++++++++++
 ext/stats/welford.go      | 87 +++++++++++++++++++++++++++++++++++++++
 ext/stats/welford_test.go | 58 ++++++++++++++++++++++++++
 ext/unicode/unicode.go    | 17 ++++----
 func_win_test.go          |  2 +-
 6 files changed, 307 insertions(+), 8 deletions(-)
 create mode 100644 ext/stats/stats.go
 create mode 100644 ext/stats/stats_test.go
 create mode 100644 ext/stats/welford.go
 create mode 100644 ext/stats/welford_test.go

diff --git a/ext/stats/stats.go b/ext/stats/stats.go
new file mode 100644
index 0000000..6d76b09
--- /dev/null
+++ b/ext/stats/stats.go
@@ -0,0 +1,72 @@
+// Package stats provides aggregate functions for statistics.
+//
+// Functions:
+//   - var_samp: sample variance
+//   - var_pop: population variance
+//   - stddev_samp: sample standard deviation
+//   - stddev_pop: population standard deviation
+//
+// See: [ANSI SQL Aggregate Functions]
+//
+// [ANSI SQL Aggregate Functions]: https://www.oreilly.com/library/view/sql-in-a/9780596155322/ch04s02.html
+package stats
+
+import "github.com/ncruces/go-sqlite3"
+
+// Register registers statistics functions.
+func Register(db *sqlite3.Conn) {
+	flags := sqlite3.DETERMINISTIC | sqlite3.INNOCUOUS
+	db.CreateWindowFunction("var_pop", 1, flags, create(var_pop))
+	db.CreateWindowFunction("var_samp", 1, flags, create(var_samp))
+	db.CreateWindowFunction("stddev_pop", 1, flags, create(stddev_pop))
+	db.CreateWindowFunction("stddev_samp", 1, flags, create(stddev_samp))
+}
+
+// Kinds of statistic that a state can compute.
+const (
+	var_pop = iota
+	var_samp
+	stddev_pop
+	stddev_samp
+)
+
+// create returns a constructor for aggregate states that compute kind.
+func create(kind int) func() sqlite3.AggregateFunction {
+	return func() sqlite3.AggregateFunction { return &state{kind: kind} }
+}
+
+// state is an aggregate that computes one kind of statistic from a welford accumulator.
+type state struct {
+	kind int
+	welford
+}
+
+// Value reports the statistic selected by f.kind as a float result.
+func (f *state) Value(ctx sqlite3.Context) {
+	var r float64
+	switch f.kind {
+	case var_pop:
+		r = f.var_pop()
+	case var_samp:
+		r = f.var_samp()
+	case stddev_pop:
+		r = f.stddev_pop()
+	case stddev_samp:
+		r = f.stddev_samp()
+	}
+	ctx.ResultFloat(r)
+}
+
+// Step adds the first argument to the accumulator, ignoring NULLs.
+func (f *state) Step(ctx sqlite3.Context, arg ...sqlite3.Value) {
+	if a := arg[0]; a.Type() != sqlite3.NULL {
+		f.enqueue(a.Float())
+	}
+}
+
+// Inverse removes the first argument from the accumulator, ignoring NULLs.
+func (f *state) Inverse(ctx sqlite3.Context, arg ...sqlite3.Value) {
+	if a := arg[0]; a.Type() != sqlite3.NULL {
+		f.dequeue(a.Float())
+	}
+}
diff --git a/ext/stats/stats_test.go b/ext/stats/stats_test.go
new file mode 100644
index 0000000..35779ed
--- /dev/null
+++ b/ext/stats/stats_test.go
@@ -0,0 +1,79 @@
+package stats
+
+import (
+	"math"
+	"testing"
+
+	"github.com/ncruces/go-sqlite3"
+	_ "github.com/ncruces/go-sqlite3/embed"
+)
+
+func TestRegister(t *testing.T) {
+	t.Parallel()
+
+	db, err := sqlite3.Open(":memory:")
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer db.Close()
+
+	Register(db)
+
+	err = db.Exec(`CREATE TABLE IF NOT EXISTS data (col)`)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	err = db.Exec(`INSERT INTO data (col) VALUES (4), (7.0), ('13'), (NULL), (16)`)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	stmt, _, err := db.Prepare(`SELECT
+		sum(col), avg(col),
+		var_samp(col), var_pop(col),
+		stddev_samp(col), stddev_pop(col) FROM data`)
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer stmt.Close()
+
+	if stmt.Step() {
+		if got := stmt.ColumnFloat(0); got != 40 {
+			t.Errorf("got %v, want 40", got)
+		}
+		if got := stmt.ColumnFloat(1); got != 10 {
+			t.Errorf("got %v, want 10", got)
+		}
+		if got := stmt.ColumnFloat(2); got != 30 {
+			t.Errorf("got %v, want 30", got)
+		}
+		if got := stmt.ColumnFloat(3); got != 22.5 {
+			t.Errorf("got %v, want 22.5", got)
+		}
+		if got := stmt.ColumnFloat(4); got != math.Sqrt(30) {
+			t.Errorf("got %v, want √30", got)
+		}
+		if got := stmt.ColumnFloat(5); got != math.Sqrt(22.5) {
+			t.Errorf("got %v, want √22.5", got)
+		}
+	}
+
+	{
+		stmt, _, err := db.Prepare(`SELECT var_samp(col) OVER (ROWS 1 PRECEDING) FROM data`)
+		if err != nil {
+			t.Fatal(err)
+		}
+		defer stmt.Close()
+
+		want := [...]float64{0, 4.5, 18, 0, 0}
+		for i := 0; stmt.Step(); i++ {
+			if got := stmt.ColumnFloat(0); got != want[i] {
+				t.Errorf("got %v, want %v", got, want[i])
+			}
+			if got := stmt.ColumnType(0); (got == sqlite3.FLOAT) != (want[i] != 0) {
+				t.Errorf("got %v, want %v", got, want[i])
+			}
+		}
+	}
+}
diff --git a/ext/stats/welford.go b/ext/stats/welford.go
new file mode 100644
index 0000000..49c1bda
--- /dev/null
+++ b/ext/stats/welford.go
@@ -0,0 +1,87 @@
+package stats
+
+import "math"
+
+// Welford's algorithm with Kahan summation:
+// https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_online_algorithm
+// https://en.wikipedia.org/wiki/Kahan_summation_algorithm
+
+// welford tracks the running mean (m1) and the sum of squared
+// deviations from the mean (m2) of the n values currently enqueued.
+type welford struct {
+	m1, m2 kahan
+	n      uint64
+}
+
+// average returns the running mean; it is 0 when nothing is enqueued.
+func (w welford) average() float64 {
+	return w.m1.hi
+}
+
+// var_pop returns the population variance, NaN when nothing is enqueued.
+func (w welford) var_pop() float64 {
+	return w.m2.hi / float64(w.n)
+}
+
+// var_samp returns the sample variance,
+// NaN when fewer than two values are enqueued.
+func (w welford) var_samp() float64 {
+	if w.n < 2 {
+		// n-1 would be zero, or wrap around for an unsigned zero n.
+		return math.NaN()
+	}
+	return w.m2.hi / float64(w.n-1)
+}
+
+// stddev_pop returns the population standard deviation.
+func (w welford) stddev_pop() float64 {
+	return math.Sqrt(w.var_pop())
+}
+
+// stddev_samp returns the sample standard deviation.
+func (w welford) stddev_samp() float64 {
+	return math.Sqrt(w.var_samp())
+}
+
+// enqueue adds x to the values being tracked.
+func (w *welford) enqueue(x float64) {
+	w.n++
+	d1 := x - w.m1.hi
+	w.m1.add(d1 / float64(w.n))
+	d2 := x - w.m1.hi
+	w.m2.add(d1 * d2)
+}
+
+// dequeue removes x, which must have been enqueued before.
+func (w *welford) dequeue(x float64) {
+	if w.n <= 1 {
+		// Removing the last value: start over, instead of dividing
+		// by zero and poisoning the mean with NaN or ±Inf.
+		*w = welford{}
+		return
+	}
+	w.n--
+	d1 := x - w.m1.hi
+	w.m1.sub(d1 / float64(w.n))
+	d2 := x - w.m1.hi
+	w.m2.sub(d1 * d2)
+}
+
+// kahan is a compensated sum: hi is the total, lo the running correction.
+type kahan struct{ hi, lo float64 }
+
+// add adds x to the sum.
+func (k *kahan) add(x float64) {
+	y := k.lo + x
+	t := k.hi + y
+	k.lo = y - (t - k.hi)
+	k.hi = t
+}
+
+// sub subtracts x from the sum.
+func (k *kahan) sub(x float64) {
+	y := k.lo - x
+	t := k.hi + y
+	k.lo = y - (t - k.hi)
+	k.hi = t
+}
diff --git a/ext/stats/welford_test.go b/ext/stats/welford_test.go
new file mode 100644
index 0000000..bf4832d
--- /dev/null
+++ b/ext/stats/welford_test.go
@@ -0,0 +1,58 @@
+package stats
+
+import (
+	"math"
+	"testing"
+)
+
+func Test_welford(t *testing.T) {
+	var s1, s2 welford
+
+	s1.enqueue(4)
+	s1.enqueue(7)
+	s1.enqueue(13)
+	s1.enqueue(16)
+	if got := s1.average(); got != 10 {
+		t.Errorf("got %v, want 10", got)
+	}
+	if got := s1.var_samp(); got != 30 {
+		t.Errorf("got %v, want 30", got)
+	}
+	if got := s1.var_pop(); got != 22.5 {
+		t.Errorf("got %v, want 22.5", got)
+	}
+	if got := s1.stddev_samp(); got != math.Sqrt(30) {
+		t.Errorf("got %v, want √30", got)
+	}
+	if got := s1.stddev_pop(); got != math.Sqrt(22.5) {
+		t.Errorf("got %v, want √22.5", got)
+	}
+
+	s1.dequeue(4)
+	s2.enqueue(7)
+	s2.enqueue(13)
+	s2.enqueue(16)
+	if s1 != s2 {
+		t.Errorf("got %v, want %v", s1, s2)
+	}
+}
+
+func Test_welford_empty(t *testing.T) {
+	var s welford
+	if got := s.var_samp(); !math.IsNaN(got) {
+		t.Errorf("got %v, want NaN", got)
+	}
+
+	s.enqueue(4)
+	if got := s.var_samp(); !math.IsNaN(got) {
+		t.Errorf("got %v, want NaN", got)
+	}
+
+	// Emptying the window must not poison later results.
+	s.dequeue(4)
+	s.enqueue(7)
+	s.enqueue(13)
+	if got := s.var_samp(); got != 18 {
+		t.Errorf("got %v, want 18", got)
+	}
+}
diff --git a/ext/unicode/unicode.go b/ext/unicode/unicode.go
index c47ec5f..809031b 100644
--- a/ext/unicode/unicode.go
+++ b/ext/unicode/unicode.go
@@ -1,17 +1,19 @@
 // Package unicode provides an alternative to the SQLite ICU extension.
 //
-// Provides Unicode aware:
-//   - upper and lower functions,
+// Like the [ICU extension], it provides Unicode aware:
+//   - upper() and lower() functions,
 //   - LIKE and REGEXP operators,
 //   - collation sequences.
 //
-// This package is not 100% compatible with the ICU extension:
-//   - upper and lower use [strings.ToUpper], [strings.ToLower] and [cases];
+// The implementation is not 100% compatible with the [ICU extension]:
+//   - upper() and lower() use [strings.ToUpper], [strings.ToLower] and [cases];
 //   - the LIKE operator follows [strings.EqualFold] rules;
 //   - the REGEXP operator uses Go [regex/syntax];
 //   - collation sequences use [collate].
 //
 // Expect subtle differences (e.g.) in the handling of Turkish case folding.
+//
+// [ICU extension]: https://sqlite.org/src/dir/ext/icu
 package unicode
 
 import (
@@ -45,7 +47,7 @@ func Register(db *sqlite3.Conn) {
 				return
 			}
 
-			err := RegisterCollation(db, name, arg[0].Text())
+			err := RegisterCollation(db, arg[0].Text(), name)
 			if err != nil {
 				ctx.ResultError(err)
 				return
@@ -53,8 +55,9 @@ func Register(db *sqlite3.Conn) {
 		})
 }
 
-func RegisterCollation(db *sqlite3.Conn, name, lang string) error {
-	tag, err := language.Parse(lang)
+// RegisterCollation registers a Unicode collation sequence for a database connection.
+func RegisterCollation(db *sqlite3.Conn, locale, name string) error {
+	tag, err := language.Parse(locale)
 	if err != nil {
 		return err
 	}
diff --git a/func_win_test.go b/func_win_test.go
index 5832161..bbe2e19 100644
--- a/func_win_test.go
+++ b/func_win_test.go
@@ -26,7 +26,7 @@ func ExampleConn_CreateWindowFunction() {
 		log.Fatal(err)
 	}
 
-	err = db.CreateWindowFunction("count_ascii", 1, sqlite3.INNOCUOUS, newASCIICounter)
+	err = db.CreateWindowFunction("count_ascii", 1, sqlite3.DETERMINISTIC|sqlite3.INNOCUOUS, newASCIICounter)
 	if err != nil {
 		log.Fatal(err)
 	}