From 746a84965e6bd6808feac32cb108524f1c490407 Mon Sep 17 00:00:00 2001 From: Nuno Cruces Date: Fri, 1 Sep 2023 02:26:30 +0100 Subject: [PATCH] Covariance. --- ext/stats/stats.go | 78 +++++++++++++++++++++++++++++---------- ext/stats/stats_test.go | 74 +++++++++++++++++++++++++++++++++---- ext/stats/welford.go | 43 ++++++++++++++++++--- ext/stats/welford_test.go | 31 ++++++++++++++++ func.go | 2 +- 5 files changed, 194 insertions(+), 34 deletions(-) diff --git a/ext/stats/stats.go b/ext/stats/stats.go index 6d76b09..87e822b 100644 --- a/ext/stats/stats.go +++ b/ext/stats/stats.go @@ -1,10 +1,12 @@ // Package stats provides aggregate functions for statistics. // // Functions: -// - var_samp: sample variance -// - var_pop: population variance -// - stddev_samp: sample standard deviation // - stddev_pop: population standard deviation +// - stddev_samp: sample standard deviation +// - var_pop: population variance +// - var_samp: sample variance +// - covar_pop: population covariance +// - covar_samp: sample covariance // // See: [ANSI SQL Aggregate Functions] // @@ -16,10 +18,12 @@ import "github.com/ncruces/go-sqlite3" // Register registers statistics functions. func Register(db *sqlite3.Conn) { flags := sqlite3.DETERMINISTIC | sqlite3.INNOCUOUS - db.CreateWindowFunction("var_pop", 1, flags, create(var_pop)) - db.CreateWindowFunction("var_samp", 1, flags, create(var_samp)) - db.CreateWindowFunction("stddev_pop", 1, flags, create(stddev_pop)) - db.CreateWindowFunction("stddev_samp", 1, flags, create(stddev_samp)) + db.CreateWindowFunction("var_pop", 1, flags, newVariance(var_pop)) + db.CreateWindowFunction("var_samp", 1, flags, newVariance(var_samp)) + db.CreateWindowFunction("stddev_pop", 1, flags, newVariance(stddev_pop)) + db.CreateWindowFunction("stddev_samp", 1, flags, newVariance(stddev_samp)) + db.CreateWindowFunction("covar_pop", 2, flags, newCovariance(var_pop)) + db.CreateWindowFunction("covar_samp", 2, flags, newCovariance(var_samp)) } const ( @@ -29,38 +33,72 @@ const ( stddev_samp ) -func create(kind int) func() sqlite3.AggregateFunction { - return func() sqlite3.AggregateFunction { return &state{kind: kind} } +func newVariance(kind int) func() sqlite3.AggregateFunction { + return func() sqlite3.AggregateFunction { return &variance{kind: kind} } } -type state struct { +type variance struct { kind int welford } -func (f *state) Value(ctx sqlite3.Context) { +func (fn *variance) Value(ctx sqlite3.Context) { var r float64 - switch f.kind { + switch fn.kind { case var_pop: - r = f.var_pop() + r = fn.var_pop() case var_samp: - r = f.var_samp() + r = fn.var_samp() case stddev_pop: - r = f.stddev_pop() + r = fn.stddev_pop() case stddev_samp: - r = f.stddev_samp() + r = fn.stddev_samp() } ctx.ResultFloat(r) } -func (f *state) Step(ctx sqlite3.Context, arg ...sqlite3.Value) { +func (fn *variance) Step(ctx sqlite3.Context, arg ...sqlite3.Value) { if a := arg[0]; a.Type() != sqlite3.NULL { - f.enqueue(a.Float()) + fn.enqueue(a.Float()) } } -func (f *state) Inverse(ctx sqlite3.Context, arg ...sqlite3.Value) { +func (fn *variance) Inverse(ctx sqlite3.Context, arg ...sqlite3.Value) { if a := arg[0]; a.Type() != sqlite3.NULL { - f.dequeue(a.Float()) + fn.dequeue(a.Float()) + } +} + +func newCovariance(kind int) func() sqlite3.AggregateFunction { + return func() sqlite3.AggregateFunction { return &covariance{kind: kind} } +} + +type covariance struct { + kind int + welford2 +} + +func (fn *covariance) Value(ctx sqlite3.Context) { + var r float64 + switch fn.kind { + case var_pop: + r = fn.covar_pop() + case var_samp: + r = fn.covar_samp() + } + ctx.ResultFloat(r) +} + +func (fn *covariance) Step(ctx sqlite3.Context, arg ...sqlite3.Value) { + a, b := arg[0], arg[1] + if a.Type() != sqlite3.NULL && b.Type() != sqlite3.NULL { + fn.enqueue(a.Float(), b.Float()) + } +} + +func (fn *covariance) Inverse(ctx sqlite3.Context, arg ...sqlite3.Value) { + a, b := arg[0], arg[1] + if a.Type() != sqlite3.NULL && b.Type() != sqlite3.NULL { + fn.dequeue(a.Float(), b.Float()) } } diff --git a/ext/stats/stats_test.go b/ext/stats/stats_test.go index 35779ed..7395f29 100644 --- a/ext/stats/stats_test.go +++ b/ext/stats/stats_test.go @@ -8,7 +8,7 @@ import ( _ "github.com/ncruces/go-sqlite3/embed" ) -func TestRegister(t *testing.T) { +func TestRegister_variance(t *testing.T) { t.Parallel() db, err := sqlite3.Open(":memory:") @@ -19,20 +19,22 @@ func TestRegister(t *testing.T) { Register(db) - err = db.Exec(`CREATE TABLE IF NOT EXISTS data (col)`) + err = db.Exec(`CREATE TABLE IF NOT EXISTS data (x)`) if err != nil { t.Fatal(err) } - err = db.Exec(`INSERT INTO data (col) VALUES (4), (7.0), ('13'), (NULL), (16)`) + err = db.Exec(`INSERT INTO data (x) VALUES (4), (7.0), ('13'), (NULL), (16)`) if err != nil { t.Fatal(err) } - stmt, _, err := db.Prepare(`SELECT - sum(col), avg(col), - var_samp(col), var_pop(col), - stddev_samp(col), stddev_pop(col) FROM data`) + stmt, _, err := db.Prepare(` + SELECT + sum(x), avg(x), + var_samp(x), var_pop(x), + stddev_samp(x), stddev_pop(x) + FROM data`) if err != nil { t.Fatal(err) } @@ -60,7 +62,7 @@ func TestRegister(t *testing.T) { } { - stmt, _, err := db.Prepare(`SELECT var_samp(col) OVER (ROWS 1 PRECEDING) FROM data`) + stmt, _, err := db.Prepare(`SELECT var_samp(x) OVER (ROWS 1 PRECEDING) FROM data`) if err != nil { t.Fatal(err) } @@ -77,3 +79,59 @@ func TestRegister(t *testing.T) { } } } + +func TestRegister_covariance(t *testing.T) { + t.Parallel() + + db, err := sqlite3.Open(":memory:") + if err != nil { + t.Fatal(err) + } + defer db.Close() + + Register(db) + + err = db.Exec(`CREATE TABLE IF NOT EXISTS data (x, y)`) + if err != nil { + t.Fatal(err) + } + + err = db.Exec(`INSERT INTO data (x, y) VALUES (3, 70), (5, 80), (2, 60), (7, 90), (4, 75)`) + if err != nil { + t.Fatal(err) + } + + stmt, _, err := db.Prepare(`SELECT + covar_samp(x, y), covar_pop(x, y) FROM data`) + if err != nil { + t.Fatal(err) + } + defer stmt.Close() + + if stmt.Step() { + if got := stmt.ColumnFloat(0); got != 21.25 { + t.Errorf("got %v, want 21.25", got) + } + if got := stmt.ColumnFloat(1); got != 17 { + t.Errorf("got %v, want 17", got) + } + } + + { + stmt, _, err := db.Prepare(`SELECT covar_samp(x, y) OVER (ROWS 1 PRECEDING) FROM data`) + if err != nil { + t.Fatal(err) + } + defer stmt.Close() + + want := [...]float64{0, 10, 30, 75, 22.5} + for i := 0; stmt.Step(); i++ { + if got := stmt.ColumnFloat(0); got != want[i] { + t.Errorf("got %v, want %v", got, want[i]) + } + if got := stmt.ColumnType(0); (got == sqlite3.FLOAT) != (want[i] != 0) { + t.Errorf("got %v, want %v", got, want[i]) + } + } + } +} diff --git a/ext/stats/welford.go b/ext/stats/welford.go index 49c1bda..a6afab0 100644 --- a/ext/stats/welford.go +++ b/ext/stats/welford.go @@ -20,7 +20,7 @@ func (w welford) var_pop() float64 { } func (w welford) var_samp() float64 { - return w.m2.hi / float64(w.n-1) + return w.m2.hi / float64(w.n-1) // Bessel's correction } func (w welford) stddev_pop() float64 { @@ -33,20 +33,53 @@ func (w welford) stddev_samp() float64 { func (w *welford) enqueue(x float64) { w.n++ - d1 := x - w.m1.hi + d1 := x - w.m1.hi - w.m1.lo w.m1.add(d1 / float64(w.n)) - d2 := x - w.m1.hi + d2 := x - w.m1.hi - w.m1.lo w.m2.add(d1 * d2) } func (w *welford) dequeue(x float64) { w.n-- - d1 := x - w.m1.hi + d1 := x - w.m1.hi - w.m1.lo w.m1.sub(d1 / float64(w.n)) - d2 := x - w.m1.hi + d2 := x - w.m1.hi - w.m1.lo w.m2.sub(d1 * d2) } +type welford2 struct { + x, y, c kahan + n uint64 +} + +func (w welford2) covar_pop() float64 { + return w.c.hi / float64(w.n) +} + +func (w welford2) covar_samp() float64 { + return w.c.hi / float64(w.n-1) // Bessel's correction +} + +func (w *welford2) enqueue(x, y float64) { + w.n++ + dx := x - w.x.hi - w.x.lo + dy := y - w.y.hi - w.y.lo + w.x.add(dx / float64(w.n)) + w.y.add(dy / float64(w.n)) + d2 := y - w.y.hi - w.y.lo + w.c.add(dx * d2) +} + +func (w *welford2) dequeue(x, y float64) { + w.n-- + dx := x - w.x.hi - w.x.lo + dy := y - w.y.hi - w.y.lo + w.x.sub(dx / float64(w.n)) + w.y.sub(dy / float64(w.n)) + d2 := y - w.y.hi - w.y.lo + w.c.sub(dx * d2) +} + type kahan struct{ hi, lo float64 } func (k *kahan) add(x float64) { diff --git a/ext/stats/welford_test.go b/ext/stats/welford_test.go index bf4832d..4d7d445 100644 --- a/ext/stats/welford_test.go +++ b/ext/stats/welford_test.go @@ -32,7 +32,38 @@ func Test_welford(t *testing.T) { s2.enqueue(7) s2.enqueue(13) s2.enqueue(16) + s1.m1.lo, s2.m1.lo = 0, 0 + s1.m2.lo, s2.m2.lo = 0, 0 if s1 != s2 { t.Errorf("got %v, want %v", s1, s2) } } + +func Test_covar(t *testing.T) { + var c1, c2 welford2 + + c1.enqueue(3, 70) + c1.enqueue(5, 80) + c1.enqueue(2, 60) + c1.enqueue(7, 90) + c1.enqueue(4, 75) + + if got := c1.covar_samp(); got != 21.25 { + t.Errorf("got %v, want 21.25", got) + } + if got := c1.covar_pop(); got != 17 { + t.Errorf("got %v, want 17", got) + } + + c1.dequeue(3, 70) + c2.enqueue(5, 80) + c2.enqueue(2, 60) + c2.enqueue(7, 90) + c2.enqueue(4, 75) + c1.x.lo, c2.x.lo = 0, 0 + c1.y.lo, c2.y.lo = 0, 0 + c1.c.lo, c2.c.lo = 0, 0 + if c1 != c2 { + t.Errorf("got %v, want %v", c1, c2) + } +} diff --git a/func.go b/func.go index 205943e..4de8b98 100644 --- a/func.go +++ b/func.go @@ -12,7 +12,7 @@ import ( // for any unknown collating sequence. // The fake collating function works like BINARY. // -// This extension can be used to load schemas that contain +// This can be used to load schemas that contain // one or more unknown collating sequences. func (c *Conn) AnyCollationNeeded() { c.call(c.api.anyCollation, uint64(c.handle), 0, 0)