Pearson correlation.

This commit is contained in:
Nuno Cruces
2023-09-02 00:48:55 +01:00
parent 9d75c39dcc
commit 9d997552ad
4 changed files with 53 additions and 27 deletions

View File

@@ -7,6 +7,7 @@
// - var_samp: sample variance
// - covar_pop: population covariance
// - covar_samp: sample covariance
// - corr: correlation coefficient
//
// See: [ANSI SQL Aggregate Functions]
//
@@ -24,6 +25,7 @@ func Register(db *sqlite3.Conn) {
db.CreateWindowFunction("stddev_samp", 1, flags, newVariance(stddev_samp))
db.CreateWindowFunction("covar_pop", 2, flags, newCovariance(var_pop))
db.CreateWindowFunction("covar_samp", 2, flags, newCovariance(var_samp))
db.CreateWindowFunction("corr", 2, flags, newCovariance(corr))
}
const (
@@ -31,6 +33,7 @@ const (
var_samp
stddev_pop
stddev_samp
corr
)
func newVariance(kind int) func() sqlite3.AggregateFunction {
@@ -85,6 +88,8 @@ func (fn *covariance) Value(ctx sqlite3.Context) {
r = fn.covar_pop()
case var_samp:
r = fn.covar_samp()
case corr:
r = fn.correlation()
}
ctx.ResultFloat(r)
}

View File

@@ -102,17 +102,20 @@ func TestRegister_covariance(t *testing.T) {
}
stmt, _, err := db.Prepare(`SELECT
covar_samp(x, y), covar_pop(x, y) FROM data`)
corr(x, y), covar_samp(x, y), covar_pop(x, y) FROM data`)
if err != nil {
t.Fatal(err)
}
defer stmt.Close()
if stmt.Step() {
if got := stmt.ColumnFloat(0); got != 21.25 {
if got := stmt.ColumnFloat(0); got != 0.9881049293224639 {
t.Errorf("got %v, want 0.9881049293224639", got)
}
if got := stmt.ColumnFloat(1); got != 21.25 {
t.Errorf("got %v, want 21.25", got)
}
if got := stmt.ColumnFloat(1); got != 17 {
if got := stmt.ColumnFloat(2); got != 17 {
t.Errorf("got %v, want 17", got)
}
}

View File

@@ -48,36 +48,48 @@ func (w *welford) dequeue(x float64) {
}
type welford2 struct {
x, y, c kahan
n uint64
m1x, m2x kahan
m1y, m2y kahan
cov kahan
n uint64
}
func (w welford2) covar_pop() float64 {
return w.c.hi / float64(w.n)
return w.cov.hi / float64(w.n)
}
func (w welford2) covar_samp() float64 {
return w.c.hi / float64(w.n-1) // Bessel's correction
return w.cov.hi / float64(w.n-1) // Bessel's correction
}
// correlation returns the correlation coefficient of the accumulated
// (x, y) pairs: the co-moment cov divided by the square root of the
// product of the two second moments m2x and m2y.
//
// Only the hi parts of the accumulators are read. The sample count n
// is not used here. When the product of the second moments is zero the
// division is by zero, so the result is not a finite number.
func (w welford2) correlation() float64 {
	// Multiply first and take a single square root, keeping the
	// rounding identical to one Sqrt over the product.
	momentProduct := w.m2x.hi * w.m2y.hi
	return w.cov.hi / math.Sqrt(momentProduct)
}
func (w *welford2) enqueue(x, y float64) {
w.n++
dx := x - w.x.hi - w.x.lo
dy := y - w.y.hi - w.y.lo
w.x.add(dx / float64(w.n))
w.y.add(dy / float64(w.n))
d2 := y - w.y.hi - w.y.lo
w.c.add(dx * d2)
d1x := x - w.m1x.hi - w.m1x.lo
d1y := y - w.m1y.hi - w.m1y.lo
w.m1x.add(d1x / float64(w.n))
w.m1y.add(d1y / float64(w.n))
d2x := x - w.m1x.hi - w.m1x.lo
d2y := y - w.m1y.hi - w.m1y.lo
w.m2x.add(d1x * d2x)
w.m2y.add(d1y * d2y)
w.cov.add(d1x * d2y)
}
func (w *welford2) dequeue(x, y float64) {
w.n--
dx := x - w.x.hi - w.x.lo
dy := y - w.y.hi - w.y.lo
w.x.sub(dx / float64(w.n))
w.y.sub(dy / float64(w.n))
d2 := y - w.y.hi - w.y.lo
w.c.sub(dx * d2)
d1x := x - w.m1x.hi - w.m1x.lo
d1y := y - w.m1y.hi - w.m1y.lo
w.m1x.sub(d1x / float64(w.n))
w.m1y.sub(d1y / float64(w.n))
d2x := x - w.m1x.hi - w.m1x.lo
d2y := y - w.m1y.hi - w.m1y.lo
w.m2x.sub(d1x * d2x)
w.m2y.sub(d1y * d2y)
w.cov.sub(d1x * d2y)
}
// kahan is a pair of float64 values used as an accumulator by the
// running-statistics types in this file.
//
// NOTE(review): presumably hi is the running sum and lo the
// compensation term of a compensated (Kahan-style) summation; confirm
// against the add/sub methods, which are not visible here.
type kahan struct {
	hi float64
	lo float64
}

View File

@@ -32,9 +32,7 @@ func Test_welford(t *testing.T) {
s2.enqueue(7)
s2.enqueue(13)
s2.enqueue(16)
s1.m1.lo, s2.m1.lo = 0, 0
s1.m2.lo, s2.m2.lo = 0, 0
if s1 != s2 {
if s1.var_pop() != s2.var_pop() {
t.Errorf("got %v, want %v", s1, s2)
}
}
@@ -60,10 +58,18 @@ func Test_covar(t *testing.T) {
c2.enqueue(2, 60)
c2.enqueue(7, 90)
c2.enqueue(4, 75)
c1.x.lo, c2.x.lo = 0, 0
c1.y.lo, c2.y.lo = 0, 0
c1.c.lo, c2.c.lo = 0, 0
if c1 != c2 {
t.Errorf("got %v, want %v", c1, c2)
if c1.covar_pop() != c2.covar_pop() {
t.Errorf("got %v, want %v", c1.covar_pop(), c2.covar_pop())
}
}
// Test_correlation feeds three pairs in which y decreases by one as x
// increases by one, and expects the correlation to be exactly -1.
func Test_correlation(t *testing.T) {
	pairs := [][2]float64{
		{1, 3},
		{2, 2},
		{3, 1},
	}

	var acc welford2
	for _, p := range pairs {
		acc.enqueue(p[0], p[1])
	}

	got := acc.correlation()
	if got == -1 {
		return
	}
	t.Errorf("got %v, want -1", got)
}