mirror of
https://github.com/ncruces/go-sqlite3.git
synced 2026-01-12 05:59:14 +00:00
Covariance.
This commit is contained in:
@@ -1,10 +1,12 @@
|
||||
// Package stats provides aggregate functions for statistics.
|
||||
//
|
||||
// Functions:
|
||||
// - var_samp: sample variance
|
||||
// - var_pop: population variance
|
||||
// - stddev_samp: sample standard deviation
|
||||
// - stddev_pop: population standard deviation
|
||||
// - stddev_samp: sample standard deviation
|
||||
// - var_pop: population variance
|
||||
// - var_samp: sample variance
|
||||
// - covar_pop: population covariance
|
||||
// - covar_samp: sample covariance
|
||||
//
|
||||
// See: [ANSI SQL Aggregate Functions]
|
||||
//
|
||||
@@ -16,10 +18,12 @@ import "github.com/ncruces/go-sqlite3"
|
||||
// Register registers statistics functions.
|
||||
func Register(db *sqlite3.Conn) {
|
||||
flags := sqlite3.DETERMINISTIC | sqlite3.INNOCUOUS
|
||||
db.CreateWindowFunction("var_pop", 1, flags, create(var_pop))
|
||||
db.CreateWindowFunction("var_samp", 1, flags, create(var_samp))
|
||||
db.CreateWindowFunction("stddev_pop", 1, flags, create(stddev_pop))
|
||||
db.CreateWindowFunction("stddev_samp", 1, flags, create(stddev_samp))
|
||||
db.CreateWindowFunction("var_pop", 1, flags, newVariance(var_pop))
|
||||
db.CreateWindowFunction("var_samp", 1, flags, newVariance(var_samp))
|
||||
db.CreateWindowFunction("stddev_pop", 1, flags, newVariance(stddev_pop))
|
||||
db.CreateWindowFunction("stddev_samp", 1, flags, newVariance(stddev_samp))
|
||||
db.CreateWindowFunction("covar_pop", 2, flags, newCovariance(var_pop))
|
||||
db.CreateWindowFunction("covar_samp", 2, flags, newCovariance(var_samp))
|
||||
}
|
||||
|
||||
const (
|
||||
@@ -29,38 +33,72 @@ const (
|
||||
stddev_samp
|
||||
)
|
||||
|
||||
func create(kind int) func() sqlite3.AggregateFunction {
|
||||
return func() sqlite3.AggregateFunction { return &state{kind: kind} }
|
||||
func newVariance(kind int) func() sqlite3.AggregateFunction {
|
||||
return func() sqlite3.AggregateFunction { return &variance{kind: kind} }
|
||||
}
|
||||
|
||||
type state struct {
|
||||
type variance struct {
|
||||
kind int
|
||||
welford
|
||||
}
|
||||
|
||||
func (f *state) Value(ctx sqlite3.Context) {
|
||||
func (fn *variance) Value(ctx sqlite3.Context) {
|
||||
var r float64
|
||||
switch f.kind {
|
||||
switch fn.kind {
|
||||
case var_pop:
|
||||
r = f.var_pop()
|
||||
r = fn.var_pop()
|
||||
case var_samp:
|
||||
r = f.var_samp()
|
||||
r = fn.var_samp()
|
||||
case stddev_pop:
|
||||
r = f.stddev_pop()
|
||||
r = fn.stddev_pop()
|
||||
case stddev_samp:
|
||||
r = f.stddev_samp()
|
||||
r = fn.stddev_samp()
|
||||
}
|
||||
ctx.ResultFloat(r)
|
||||
}
|
||||
|
||||
func (f *state) Step(ctx sqlite3.Context, arg ...sqlite3.Value) {
|
||||
func (fn *variance) Step(ctx sqlite3.Context, arg ...sqlite3.Value) {
|
||||
if a := arg[0]; a.Type() != sqlite3.NULL {
|
||||
f.enqueue(a.Float())
|
||||
fn.enqueue(a.Float())
|
||||
}
|
||||
}
|
||||
|
||||
func (f *state) Inverse(ctx sqlite3.Context, arg ...sqlite3.Value) {
|
||||
func (fn *variance) Inverse(ctx sqlite3.Context, arg ...sqlite3.Value) {
|
||||
if a := arg[0]; a.Type() != sqlite3.NULL {
|
||||
f.dequeue(a.Float())
|
||||
fn.dequeue(a.Float())
|
||||
}
|
||||
}
|
||||
|
||||
func newCovariance(kind int) func() sqlite3.AggregateFunction {
|
||||
return func() sqlite3.AggregateFunction { return &covariance{kind: kind} }
|
||||
}
|
||||
|
||||
type covariance struct {
|
||||
kind int
|
||||
welford2
|
||||
}
|
||||
|
||||
func (fn *covariance) Value(ctx sqlite3.Context) {
|
||||
var r float64
|
||||
switch fn.kind {
|
||||
case var_pop:
|
||||
r = fn.covar_pop()
|
||||
case var_samp:
|
||||
r = fn.covar_samp()
|
||||
}
|
||||
ctx.ResultFloat(r)
|
||||
}
|
||||
|
||||
func (fn *covariance) Step(ctx sqlite3.Context, arg ...sqlite3.Value) {
|
||||
a, b := arg[0], arg[1]
|
||||
if a.Type() != sqlite3.NULL && b.Type() != sqlite3.NULL {
|
||||
fn.enqueue(a.Float(), b.Float())
|
||||
}
|
||||
}
|
||||
|
||||
func (fn *covariance) Inverse(ctx sqlite3.Context, arg ...sqlite3.Value) {
|
||||
a, b := arg[0], arg[1]
|
||||
if a.Type() != sqlite3.NULL && b.Type() != sqlite3.NULL {
|
||||
fn.dequeue(a.Float(), b.Float())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -8,7 +8,7 @@ import (
|
||||
_ "github.com/ncruces/go-sqlite3/embed"
|
||||
)
|
||||
|
||||
func TestRegister(t *testing.T) {
|
||||
func TestRegister_variance(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, err := sqlite3.Open(":memory:")
|
||||
@@ -19,20 +19,22 @@ func TestRegister(t *testing.T) {
|
||||
|
||||
Register(db)
|
||||
|
||||
err = db.Exec(`CREATE TABLE IF NOT EXISTS data (col)`)
|
||||
err = db.Exec(`CREATE TABLE IF NOT EXISTS data (x)`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = db.Exec(`INSERT INTO data (col) VALUES (4), (7.0), ('13'), (NULL), (16)`)
|
||||
err = db.Exec(`INSERT INTO data (x) VALUES (4), (7.0), ('13'), (NULL), (16)`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
stmt, _, err := db.Prepare(`SELECT
|
||||
sum(col), avg(col),
|
||||
var_samp(col), var_pop(col),
|
||||
stddev_samp(col), stddev_pop(col) FROM data`)
|
||||
stmt, _, err := db.Prepare(`
|
||||
SELECT
|
||||
sum(x), avg(x),
|
||||
var_samp(x), var_pop(x),
|
||||
stddev_samp(x), stddev_pop(x)
|
||||
FROM data`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -60,7 +62,7 @@ func TestRegister(t *testing.T) {
|
||||
}
|
||||
|
||||
{
|
||||
stmt, _, err := db.Prepare(`SELECT var_samp(col) OVER (ROWS 1 PRECEDING) FROM data`)
|
||||
stmt, _, err := db.Prepare(`SELECT var_samp(x) OVER (ROWS 1 PRECEDING) FROM data`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -77,3 +79,59 @@ func TestRegister(t *testing.T) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegister_covariance(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, err := sqlite3.Open(":memory:")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
Register(db)
|
||||
|
||||
err = db.Exec(`CREATE TABLE IF NOT EXISTS data (x, y)`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = db.Exec(`INSERT INTO data (x, y) VALUES (3, 70), (5, 80), (2, 60), (7, 90), (4, 75)`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
stmt, _, err := db.Prepare(`SELECT
|
||||
covar_samp(x, y), covar_pop(x, y) FROM data`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
if stmt.Step() {
|
||||
if got := stmt.ColumnFloat(0); got != 21.25 {
|
||||
t.Errorf("got %v, want 21.25", got)
|
||||
}
|
||||
if got := stmt.ColumnFloat(1); got != 17 {
|
||||
t.Errorf("got %v, want 17", got)
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
stmt, _, err := db.Prepare(`SELECT covar_samp(x, y) OVER (ROWS 1 PRECEDING) FROM data`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
want := [...]float64{0, 10, 30, 75, 22.5}
|
||||
for i := 0; stmt.Step(); i++ {
|
||||
if got := stmt.ColumnFloat(0); got != want[i] {
|
||||
t.Errorf("got %v, want %v", got, want[i])
|
||||
}
|
||||
if got := stmt.ColumnType(0); (got == sqlite3.FLOAT) != (want[i] != 0) {
|
||||
t.Errorf("got %v, want %v", got, want[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -20,7 +20,7 @@ func (w welford) var_pop() float64 {
|
||||
}
|
||||
|
||||
func (w welford) var_samp() float64 {
|
||||
return w.m2.hi / float64(w.n-1)
|
||||
return w.m2.hi / float64(w.n-1) // Bessel's correction
|
||||
}
|
||||
|
||||
func (w welford) stddev_pop() float64 {
|
||||
@@ -33,20 +33,53 @@ func (w welford) stddev_samp() float64 {
|
||||
|
||||
func (w *welford) enqueue(x float64) {
|
||||
w.n++
|
||||
d1 := x - w.m1.hi
|
||||
d1 := x - w.m1.hi - w.m1.lo
|
||||
w.m1.add(d1 / float64(w.n))
|
||||
d2 := x - w.m1.hi
|
||||
d2 := x - w.m1.hi - w.m1.lo
|
||||
w.m2.add(d1 * d2)
|
||||
}
|
||||
|
||||
func (w *welford) dequeue(x float64) {
|
||||
w.n--
|
||||
d1 := x - w.m1.hi
|
||||
d1 := x - w.m1.hi - w.m1.lo
|
||||
w.m1.sub(d1 / float64(w.n))
|
||||
d2 := x - w.m1.hi
|
||||
d2 := x - w.m1.hi - w.m1.lo
|
||||
w.m2.sub(d1 * d2)
|
||||
}
|
||||
|
||||
type welford2 struct {
|
||||
x, y, c kahan
|
||||
n uint64
|
||||
}
|
||||
|
||||
func (w welford2) covar_pop() float64 {
|
||||
return w.c.hi / float64(w.n)
|
||||
}
|
||||
|
||||
func (w welford2) covar_samp() float64 {
|
||||
return w.c.hi / float64(w.n-1) // Bessel's correction
|
||||
}
|
||||
|
||||
func (w *welford2) enqueue(x, y float64) {
|
||||
w.n++
|
||||
dx := x - w.x.hi - w.x.lo
|
||||
dy := y - w.y.hi - w.y.lo
|
||||
w.x.add(dx / float64(w.n))
|
||||
w.y.add(dy / float64(w.n))
|
||||
d2 := y - w.y.hi - w.y.lo
|
||||
w.c.add(dx * d2)
|
||||
}
|
||||
|
||||
func (w *welford2) dequeue(x, y float64) {
|
||||
w.n--
|
||||
dx := x - w.x.hi - w.x.lo
|
||||
dy := y - w.y.hi - w.y.lo
|
||||
w.x.sub(dx / float64(w.n))
|
||||
w.y.sub(dy / float64(w.n))
|
||||
d2 := y - w.y.hi - w.y.lo
|
||||
w.c.sub(dx * d2)
|
||||
}
|
||||
|
||||
// kahan is a pair of float64 parts that the accumulators above read
// together, as hi plus lo.
// NOTE(review): presumably hi is the running sum and lo its compensation
// term; the add/sub methods are not visible here, so confirm against them.
type kahan struct {
	hi float64
	lo float64
}
|
||||
|
||||
func (k *kahan) add(x float64) {
|
||||
|
||||
@@ -32,7 +32,38 @@ func Test_welford(t *testing.T) {
|
||||
s2.enqueue(7)
|
||||
s2.enqueue(13)
|
||||
s2.enqueue(16)
|
||||
s1.m1.lo, s2.m1.lo = 0, 0
|
||||
s1.m2.lo, s2.m2.lo = 0, 0
|
||||
if s1 != s2 {
|
||||
t.Errorf("got %v, want %v", s1, s2)
|
||||
}
|
||||
}
|
||||
|
||||
func Test_covar(t *testing.T) {
|
||||
var c1, c2 welford2
|
||||
|
||||
c1.enqueue(3, 70)
|
||||
c1.enqueue(5, 80)
|
||||
c1.enqueue(2, 60)
|
||||
c1.enqueue(7, 90)
|
||||
c1.enqueue(4, 75)
|
||||
|
||||
if got := c1.covar_samp(); got != 21.25 {
|
||||
t.Errorf("got %v, want 21.25", got)
|
||||
}
|
||||
if got := c1.covar_pop(); got != 17 {
|
||||
t.Errorf("got %v, want 17", got)
|
||||
}
|
||||
|
||||
c1.dequeue(3, 70)
|
||||
c2.enqueue(5, 80)
|
||||
c2.enqueue(2, 60)
|
||||
c2.enqueue(7, 90)
|
||||
c2.enqueue(4, 75)
|
||||
c1.x.lo, c2.x.lo = 0, 0
|
||||
c1.y.lo, c2.y.lo = 0, 0
|
||||
c1.c.lo, c2.c.lo = 0, 0
|
||||
if c1 != c2 {
|
||||
t.Errorf("got %v, want %v", c1, c2)
|
||||
}
|
||||
}
|
||||
|
||||
2
func.go
2
func.go
@@ -12,7 +12,7 @@ import (
|
||||
// for any unknown collating sequence.
|
||||
// The fake collating function works like BINARY.
|
||||
//
|
||||
// This extension can be used to load schemas that contain
|
||||
// This can be used to load schemas that contain
|
||||
// one or more unknown collating sequences.
|
||||
func (c *Conn) AnyCollationNeeded() {
|
||||
c.call(c.api.anyCollation, uint64(c.handle), 0, 0)
|
||||
|
||||
Reference in New Issue
Block a user