From 9d997552add8009766da63dbe82bc22af8936c41 Mon Sep 17 00:00:00 2001 From: Nuno Cruces Date: Sat, 2 Sep 2023 00:48:55 +0100 Subject: [PATCH] Pearson correlation. --- ext/stats/stats.go | 5 +++++ ext/stats/stats_test.go | 9 +++++--- ext/stats/welford.go | 44 +++++++++++++++++++++++++-------------- ext/stats/welford_test.go | 22 +++++++++++++------- 4 files changed, 53 insertions(+), 27 deletions(-) diff --git a/ext/stats/stats.go b/ext/stats/stats.go index 87e822b..bd10ff6 100644 --- a/ext/stats/stats.go +++ b/ext/stats/stats.go @@ -7,6 +7,7 @@ // - var_samp: sample variance // - covar_pop: population covariance // - covar_samp: sample covariance +// - corr: correlation coefficient // // See: [ANSI SQL Aggregate Functions] // @@ -24,6 +25,7 @@ func Register(db *sqlite3.Conn) { db.CreateWindowFunction("stddev_samp", 1, flags, newVariance(stddev_samp)) db.CreateWindowFunction("covar_pop", 2, flags, newCovariance(var_pop)) db.CreateWindowFunction("covar_samp", 2, flags, newCovariance(var_samp)) + db.CreateWindowFunction("corr", 2, flags, newCovariance(corr)) } const ( @@ -31,6 +33,7 @@ const ( var_samp stddev_pop stddev_samp + corr ) func newVariance(kind int) func() sqlite3.AggregateFunction { @@ -85,6 +88,8 @@ func (fn *covariance) Value(ctx sqlite3.Context) { r = fn.covar_pop() case var_samp: r = fn.covar_samp() + case corr: + r = fn.correlation() } ctx.ResultFloat(r) } diff --git a/ext/stats/stats_test.go b/ext/stats/stats_test.go index 7395f29..cdd9a39 100644 --- a/ext/stats/stats_test.go +++ b/ext/stats/stats_test.go @@ -102,17 +102,20 @@ func TestRegister_covariance(t *testing.T) { } stmt, _, err := db.Prepare(`SELECT - covar_samp(x, y), covar_pop(x, y) FROM data`) + corr(x, y), covar_samp(x, y), covar_pop(x, y) FROM data`) if err != nil { t.Fatal(err) } defer stmt.Close() if stmt.Step() { - if got := stmt.ColumnFloat(0); got != 21.25 { + if got := stmt.ColumnFloat(0); got != 0.9881049293224639 { + t.Errorf("got %v, want 0.9881049293224639", got) + } + if got := stmt.ColumnFloat(1); got != 21.25 { t.Errorf("got %v, want 21.25", got) } - if got := stmt.ColumnFloat(1); got != 17 { + if got := stmt.ColumnFloat(2); got != 17 { t.Errorf("got %v, want 17", got) } } diff --git a/ext/stats/welford.go b/ext/stats/welford.go index a6afab0..62885a8 100644 --- a/ext/stats/welford.go +++ b/ext/stats/welford.go @@ -48,36 +48,48 @@ func (w *welford) dequeue(x float64) { } type welford2 struct { - x, y, c kahan - n uint64 + m1x, m2x kahan + m1y, m2y kahan + cov kahan + n uint64 } func (w welford2) covar_pop() float64 { - return w.c.hi / float64(w.n) + return w.cov.hi / float64(w.n) } func (w welford2) covar_samp() float64 { - return w.c.hi / float64(w.n-1) // Bessel's correction + return w.cov.hi / float64(w.n-1) // Bessel's correction +} + +func (w welford2) correlation() float64 { + return w.cov.hi / math.Sqrt(w.m2x.hi*w.m2y.hi) } func (w *welford2) enqueue(x, y float64) { w.n++ - dx := x - w.x.hi - w.x.lo - dy := y - w.y.hi - w.y.lo - w.x.add(dx / float64(w.n)) - w.y.add(dy / float64(w.n)) - d2 := y - w.y.hi - w.y.lo - w.c.add(dx * d2) + d1x := x - w.m1x.hi - w.m1x.lo + d1y := y - w.m1y.hi - w.m1y.lo + w.m1x.add(d1x / float64(w.n)) + w.m1y.add(d1y / float64(w.n)) + d2x := x - w.m1x.hi - w.m1x.lo + d2y := y - w.m1y.hi - w.m1y.lo + w.m2x.add(d1x * d2x) + w.m2y.add(d1y * d2y) + w.cov.add(d1x * d2y) } func (w *welford2) dequeue(x, y float64) { w.n-- - dx := x - w.x.hi - w.x.lo - dy := y - w.y.hi - w.y.lo - w.x.sub(dx / float64(w.n)) - w.y.sub(dy / float64(w.n)) - d2 := y - w.y.hi - w.y.lo - w.c.sub(dx * d2) + d1x := x - w.m1x.hi - w.m1x.lo + d1y := y - w.m1y.hi - w.m1y.lo + w.m1x.sub(d1x / float64(w.n)) + w.m1y.sub(d1y / float64(w.n)) + d2x := x - w.m1x.hi - w.m1x.lo + d2y := y - w.m1y.hi - w.m1y.lo + w.m2x.sub(d1x * d2x) + w.m2y.sub(d1y * d2y) + w.cov.sub(d1x * d2y) } type kahan struct{ hi, lo float64 } diff --git a/ext/stats/welford_test.go b/ext/stats/welford_test.go index 4d7d445..405053e 100644 --- a/ext/stats/welford_test.go +++ b/ext/stats/welford_test.go @@ -32,9 +32,7 @@ func Test_welford(t *testing.T) { s2.enqueue(7) s2.enqueue(13) s2.enqueue(16) - s1.m1.lo, s2.m1.lo = 0, 0 - s1.m2.lo, s2.m2.lo = 0, 0 - if s1 != s2 { + if s1.var_pop() != s2.var_pop() { t.Errorf("got %v, want %v", s1, s2) } } @@ -60,10 +58,18 @@ func Test_covar(t *testing.T) { c2.enqueue(2, 60) c2.enqueue(7, 90) c2.enqueue(4, 75) - c1.x.lo, c2.x.lo = 0, 0 - c1.y.lo, c2.y.lo = 0, 0 - c1.c.lo, c2.c.lo = 0, 0 - if c1 != c2 { - t.Errorf("got %v, want %v", c1, c2) + if c1.covar_pop() != c2.covar_pop() { + t.Errorf("got %v, want %v", c1.covar_pop(), c2.covar_pop()) + } +} + +func Test_correlation(t *testing.T) { + var c welford2 + c.enqueue(1, 3) + c.enqueue(2, 2) + c.enqueue(3, 1) + + if got := c.correlation(); got != -1 { + t.Errorf("got %v, want -1", got) } }