mirror of
https://github.com/ncruces/go-sqlite3.git
synced 2026-01-12 05:59:14 +00:00
Pearson correlation.
This commit is contained in:
@@ -7,6 +7,7 @@
|
||||
// - var_samp: sample variance
|
||||
// - covar_pop: population covariance
|
||||
// - covar_samp: sample covariance
|
||||
// - corr: correlation coefficient
|
||||
//
|
||||
// See: [ANSI SQL Aggregate Functions]
|
||||
//
|
||||
@@ -24,6 +25,7 @@ func Register(db *sqlite3.Conn) {
|
||||
db.CreateWindowFunction("stddev_samp", 1, flags, newVariance(stddev_samp))
|
||||
db.CreateWindowFunction("covar_pop", 2, flags, newCovariance(var_pop))
|
||||
db.CreateWindowFunction("covar_samp", 2, flags, newCovariance(var_samp))
|
||||
db.CreateWindowFunction("corr", 2, flags, newCovariance(corr))
|
||||
}
|
||||
|
||||
const (
|
||||
@@ -31,6 +33,7 @@ const (
|
||||
var_samp
|
||||
stddev_pop
|
||||
stddev_samp
|
||||
corr
|
||||
)
|
||||
|
||||
func newVariance(kind int) func() sqlite3.AggregateFunction {
|
||||
@@ -85,6 +88,8 @@ func (fn *covariance) Value(ctx sqlite3.Context) {
|
||||
r = fn.covar_pop()
|
||||
case var_samp:
|
||||
r = fn.covar_samp()
|
||||
case corr:
|
||||
r = fn.correlation()
|
||||
}
|
||||
ctx.ResultFloat(r)
|
||||
}
|
||||
|
||||
@@ -102,17 +102,20 @@ func TestRegister_covariance(t *testing.T) {
|
||||
}
|
||||
|
||||
stmt, _, err := db.Prepare(`SELECT
|
||||
covar_samp(x, y), covar_pop(x, y) FROM data`)
|
||||
corr(x, y), covar_samp(x, y), covar_pop(x, y) FROM data`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
if stmt.Step() {
|
||||
if got := stmt.ColumnFloat(0); got != 21.25 {
|
||||
if got := stmt.ColumnFloat(0); got != 0.9881049293224639 {
|
||||
t.Errorf("got %v, want 0.9881049293224639", got)
|
||||
}
|
||||
if got := stmt.ColumnFloat(1); got != 21.25 {
|
||||
t.Errorf("got %v, want 21.25", got)
|
||||
}
|
||||
if got := stmt.ColumnFloat(1); got != 17 {
|
||||
if got := stmt.ColumnFloat(2); got != 17 {
|
||||
t.Errorf("got %v, want 17", got)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -48,36 +48,48 @@ func (w *welford) dequeue(x float64) {
|
||||
}
|
||||
|
||||
type welford2 struct {
|
||||
x, y, c kahan
|
||||
n uint64
|
||||
m1x, m2x kahan
|
||||
m1y, m2y kahan
|
||||
cov kahan
|
||||
n uint64
|
||||
}
|
||||
|
||||
func (w welford2) covar_pop() float64 {
|
||||
return w.c.hi / float64(w.n)
|
||||
return w.cov.hi / float64(w.n)
|
||||
}
|
||||
|
||||
func (w welford2) covar_samp() float64 {
|
||||
return w.c.hi / float64(w.n-1) // Bessel's correction
|
||||
return w.cov.hi / float64(w.n-1) // Bessel's correction
|
||||
}
|
||||
|
||||
func (w welford2) correlation() float64 {
|
||||
return w.cov.hi / math.Sqrt(w.m2x.hi*w.m2y.hi)
|
||||
}
|
||||
|
||||
func (w *welford2) enqueue(x, y float64) {
|
||||
w.n++
|
||||
dx := x - w.x.hi - w.x.lo
|
||||
dy := y - w.y.hi - w.y.lo
|
||||
w.x.add(dx / float64(w.n))
|
||||
w.y.add(dy / float64(w.n))
|
||||
d2 := y - w.y.hi - w.y.lo
|
||||
w.c.add(dx * d2)
|
||||
d1x := x - w.m1x.hi - w.m1x.lo
|
||||
d1y := y - w.m1y.hi - w.m1y.lo
|
||||
w.m1x.add(d1x / float64(w.n))
|
||||
w.m1y.add(d1y / float64(w.n))
|
||||
d2x := x - w.m1x.hi - w.m1x.lo
|
||||
d2y := y - w.m1y.hi - w.m1y.lo
|
||||
w.m2x.add(d1x * d2x)
|
||||
w.m2y.add(d1y * d2y)
|
||||
w.cov.add(d1x * d2y)
|
||||
}
|
||||
|
||||
func (w *welford2) dequeue(x, y float64) {
|
||||
w.n--
|
||||
dx := x - w.x.hi - w.x.lo
|
||||
dy := y - w.y.hi - w.y.lo
|
||||
w.x.sub(dx / float64(w.n))
|
||||
w.y.sub(dy / float64(w.n))
|
||||
d2 := y - w.y.hi - w.y.lo
|
||||
w.c.sub(dx * d2)
|
||||
d1x := x - w.m1x.hi - w.m1x.lo
|
||||
d1y := y - w.m1y.hi - w.m1y.lo
|
||||
w.m1x.sub(d1x / float64(w.n))
|
||||
w.m1y.sub(d1y / float64(w.n))
|
||||
d2x := x - w.m1x.hi - w.m1x.lo
|
||||
d2y := y - w.m1y.hi - w.m1y.lo
|
||||
w.m2x.sub(d1x * d2x)
|
||||
w.m2y.sub(d1y * d2y)
|
||||
w.cov.sub(d1x * d2y)
|
||||
}
|
||||
|
||||
type kahan struct{ hi, lo float64 }
|
||||
|
||||
@@ -32,9 +32,7 @@ func Test_welford(t *testing.T) {
|
||||
s2.enqueue(7)
|
||||
s2.enqueue(13)
|
||||
s2.enqueue(16)
|
||||
s1.m1.lo, s2.m1.lo = 0, 0
|
||||
s1.m2.lo, s2.m2.lo = 0, 0
|
||||
if s1 != s2 {
|
||||
if s1.var_pop() != s2.var_pop() {
|
||||
t.Errorf("got %v, want %v", s1, s2)
|
||||
}
|
||||
}
|
||||
@@ -60,10 +58,18 @@ func Test_covar(t *testing.T) {
|
||||
c2.enqueue(2, 60)
|
||||
c2.enqueue(7, 90)
|
||||
c2.enqueue(4, 75)
|
||||
c1.x.lo, c2.x.lo = 0, 0
|
||||
c1.y.lo, c2.y.lo = 0, 0
|
||||
c1.c.lo, c2.c.lo = 0, 0
|
||||
if c1 != c2 {
|
||||
t.Errorf("got %v, want %v", c1, c2)
|
||||
if c1.covar_pop() != c2.covar_pop() {
|
||||
t.Errorf("got %v, want %v", c1.covar_pop(), c2.covar_pop())
|
||||
}
|
||||
}
|
||||
|
||||
func Test_correlation(t *testing.T) {
|
||||
var c welford2
|
||||
c.enqueue(1, 3)
|
||||
c.enqueue(2, 2)
|
||||
c.enqueue(3, 1)
|
||||
|
||||
if got := c.correlation(); got != -1 {
|
||||
t.Errorf("got %v, want -1", got)
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user