Statistics functions.

This commit is contained in:
Nuno Cruces
2023-08-31 16:30:52 +01:00
parent b71cd295c2
commit 312d3b58f2
6 changed files with 258 additions and 8 deletions

66
ext/stats/stats.go Normal file
View File

@@ -0,0 +1,66 @@
// Package stats provides aggregate functions for statistics.
//
// Functions:
// - var_samp: sample variance
// - var_pop: population variance
// - stddev_samp: sample standard deviation
// - stddev_pop: population standard deviation
//
// See: [ANSI SQL Aggregate Functions]
//
// [ANSI SQL Aggregate Functions]: https://www.oreilly.com/library/view/sql-in-a/9780596155322/ch04s02.html
package stats
import "github.com/ncruces/go-sqlite3"
// Register installs the var_pop, var_samp, stddev_pop and stddev_samp
// window functions on db. Each takes a single argument and is flagged
// as deterministic and innocuous.
func Register(db *sqlite3.Conn) {
	flags := sqlite3.DETERMINISTIC | sqlite3.INNOCUOUS
	functions := [...]struct {
		name string
		kind int
	}{
		{"var_pop", var_pop},
		{"var_samp", var_samp},
		{"stddev_pop", stddev_pop},
		{"stddev_samp", stddev_samp},
	}
	for _, fn := range functions {
		// NOTE(review): the error returned by CreateWindowFunction is
		// discarded because Register has no way to report it.
		db.CreateWindowFunction(fn.name, 1, flags, create(fn.kind))
	}
}
// Kinds of statistic a state can report. A kind is passed to create
// and selects the computation performed by (*state).Value.
const (
// population variance
var_pop = iota
// sample variance
var_samp
// population standard deviation
stddev_pop
// sample standard deviation
stddev_samp
)
// create returns a factory that builds a fresh, empty state reporting
// the statistic identified by kind every time it is called.
func create(kind int) func() sqlite3.AggregateFunction {
	factory := func() sqlite3.AggregateFunction {
		s := new(state)
		s.kind = kind
		return s
	}
	return factory
}
// state is the function implementation handed out by create.
// It accumulates values in the embedded welford and reports
// the statistic selected by kind.
type state struct {
// kind is one of var_pop, var_samp, stddev_pop or stddev_samp.
kind int
// welford is the running accumulator; its methods are promoted to state.
welford
}
// Value reports the statistic selected by f.kind, computed over the
// values currently accumulated, as a float result on ctx.
// An unrecognized kind reports 0.
func (f *state) Value(ctx sqlite3.Context) {
	switch f.kind {
	case var_pop:
		ctx.ResultFloat(f.var_pop())
	case var_samp:
		ctx.ResultFloat(f.var_samp())
	case stddev_pop:
		ctx.ResultFloat(f.stddev_pop())
	case stddev_samp:
		ctx.ResultFloat(f.stddev_samp())
	default:
		ctx.ResultFloat(0)
	}
}
// Step adds the first argument to the accumulator.
// NULL arguments are skipped.
func (f *state) Step(ctx sqlite3.Context, arg ...sqlite3.Value) {
	value := arg[0]
	if value.Type() == sqlite3.NULL {
		return
	}
	f.enqueue(value.Float())
}
// Inverse removes the first argument from the accumulator,
// undoing an earlier Step for the same value.
// NULL arguments are skipped, mirroring Step.
func (f *state) Inverse(ctx sqlite3.Context, arg ...sqlite3.Value) {
	value := arg[0]
	if value.Type() == sqlite3.NULL {
		return
	}
	f.dequeue(value.Float())
}

79
ext/stats/stats_test.go Normal file
View File

@@ -0,0 +1,79 @@
package stats
import (
"math"
"testing"
"github.com/ncruces/go-sqlite3"
_ "github.com/ncruces/go-sqlite3/embed"
)
// TestRegister registers the statistics functions on an in-memory
// database and checks them both as plain aggregates and as window
// functions over a sliding two-row frame.
func TestRegister(t *testing.T) {
	t.Parallel()

	db, err := sqlite3.Open(":memory:")
	if err != nil {
		t.Fatal(err)
	}
	defer db.Close()

	Register(db)

	err = db.Exec(`CREATE TABLE IF NOT EXISTS data (col)`)
	if err != nil {
		t.Fatal(err)
	}

	err = db.Exec(`INSERT INTO data (col) VALUES (4), (7.0), ('13'), (NULL), (16)`)
	if err != nil {
		t.Fatal(err)
	}

	stmt, _, err := db.Prepare(`SELECT
		sum(col), avg(col),
		var_samp(col), var_pop(col),
		stddev_samp(col), stddev_pop(col) FROM data`)
	if err != nil {
		t.Fatal(err)
	}
	defer stmt.Close()

	// A missing row must fail the test instead of silently skipping
	// every assertion below.
	if !stmt.Step() {
		t.Fatalf("aggregate query returned no row: %v", stmt.Err())
	}
	if got := stmt.ColumnFloat(0); got != 40 {
		t.Errorf("got %v, want 40", got)
	}
	if got := stmt.ColumnFloat(1); got != 10 {
		t.Errorf("got %v, want 10", got)
	}
	if got := stmt.ColumnFloat(2); got != 30 {
		t.Errorf("got %v, want 30", got)
	}
	if got := stmt.ColumnFloat(3); got != 22.5 {
		t.Errorf("got %v, want 22.5", got)
	}
	if got := stmt.ColumnFloat(4); got != math.Sqrt(30) {
		t.Errorf("got %v, want √30", got)
	}
	if got := stmt.ColumnFloat(5); got != math.Sqrt(22.5) {
		t.Errorf("got %v, want √22.5", got)
	}

	{
		stmt, _, err := db.Prepare(`SELECT var_samp(col) OVER (ROWS 1 PRECEDING) FROM data`)
		if err != nil {
			t.Fatal(err)
		}
		defer stmt.Close()

		// Frames holding a single non-NULL value have no sample
		// variance: they read back as 0 and must not be typed FLOAT.
		want := [...]float64{0, 4.5, 18, 0, 0}

		rows := 0
		for ; stmt.Step(); rows++ {
			// Guard the index so an unexpected extra row is a test
			// failure rather than a panic.
			if rows >= len(want) {
				t.Fatalf("got more than %d rows", len(want))
			}
			if got := stmt.ColumnFloat(0); got != want[rows] {
				t.Errorf("row %d: got %v, want %v", rows, got, want[rows])
			}
			wantFloat := want[rows] != 0
			if got := stmt.ColumnType(0); (got == sqlite3.FLOAT) != wantFloat {
				t.Errorf("row %d: got type %v, want FLOAT: %v", rows, got, wantFloat)
			}
		}
		if err := stmt.Err(); err != nil {
			t.Fatal(err)
		}
		// Without this check an empty result set would pass.
		if rows != len(want) {
			t.Errorf("got %d rows, want %d", rows, len(want))
		}
	}
}

64
ext/stats/welford.go Normal file
View File

@@ -0,0 +1,64 @@
package stats
import "math"
// welford accumulates the running mean and the sum of squared
// deviations from the mean of a multiset of values, using
// Welford's algorithm with Kahan summation:
// https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_online_algorithm
// https://en.wikipedia.org/wiki/Kahan_summation_algorithm
//
// Values are added with enqueue and removed again with dequeue.
// The zero value is an empty accumulator.
type welford struct {
	// m1 is the running mean, m2 the running sum of squared
	// deviations from the mean.
	m1, m2 kahan
	// n is the number of values currently accumulated.
	n uint64
}

// average returns the running mean, or 0 for an empty accumulator.
func (w welford) average() float64 {
	return w.m1.hi
}

// var_pop returns the population variance, m2/n.
// It is NaN when no values are accumulated.
func (w welford) var_pop() float64 {
	if w.n == 0 {
		return math.NaN()
	}
	return w.m2.hi / float64(w.n)
}

// var_samp returns the sample variance, m2/(n-1).
// It is NaN when fewer than two values are accumulated; the explicit
// check also keeps n-1 from wrapping around when n is zero.
func (w welford) var_samp() float64 {
	if w.n < 2 {
		return math.NaN()
	}
	return w.m2.hi / float64(w.n-1)
}

// stddev_pop returns the population standard deviation,
// the square root of var_pop.
func (w welford) stddev_pop() float64 {
	return math.Sqrt(w.var_pop())
}

// stddev_samp returns the sample standard deviation,
// the square root of var_samp.
func (w welford) stddev_samp() float64 {
	return math.Sqrt(w.var_samp())
}

// enqueue adds x to the accumulator.
func (w *welford) enqueue(x float64) {
	w.n++
	d1 := x - w.m1.hi // deviation from the mean before the update
	w.m1.add(d1 / float64(w.n))
	d2 := x - w.m1.hi // deviation from the mean after the update
	w.m2.add(d1 * d2)
}

// dequeue removes a previously enqueued x from the accumulator.
// Removing the last remaining value (or calling dequeue on an empty
// accumulator) resets it to the zero value.
func (w *welford) dequeue(x float64) {
	if w.n <= 1 {
		// The general update below would divide by zero here and leave
		// NaN/Inf in the mean, corrupting every later enqueue; it would
		// also wrap the counter around if n were already zero.
		*w = welford{}
		return
	}
	w.n--
	d1 := x - w.m1.hi // deviation from the mean before the update
	w.m1.sub(d1 / float64(w.n))
	d2 := x - w.m1.hi // deviation from the mean after the update
	w.m2.sub(d1 * d2)
}

// kahan is a compensated sum: hi holds the running total and lo the
// rounding error that has not yet been absorbed into hi.
type kahan struct{ hi, lo float64 }

// add adds x to the sum.
func (k *kahan) add(x float64) {
	y := k.lo + x
	t := k.hi + y
	k.lo = y - (t - k.hi) // the part of y that was lost when forming t
	k.hi = t
}

// sub subtracts x from the sum.
func (k *kahan) sub(x float64) {
	k.add(-x)
}

38
ext/stats/welford_test.go Normal file
View File

@@ -0,0 +1,38 @@
package stats
import (
"math"
"testing"
)
// Test_welford checks the statistics of a small data set and that
// removing a value yields exactly the same accumulator as never
// having added it.
func Test_welford(t *testing.T) {
	var s1, s2 welford

	for _, x := range [...]float64{4, 7, 13, 16} {
		s1.enqueue(x)
	}

	if avg := s1.average(); avg != 10 {
		t.Errorf("got %v, want 10", avg)
	}
	if v := s1.var_samp(); v != 30 {
		t.Errorf("got %v, want 30", v)
	}
	if v := s1.var_pop(); v != 22.5 {
		t.Errorf("got %v, want 22.5", v)
	}
	if sd := s1.stddev_samp(); sd != math.Sqrt(30) {
		t.Errorf("got %v, want √30", sd)
	}
	if sd := s1.stddev_pop(); sd != math.Sqrt(22.5) {
		t.Errorf("got %v, want √22.5", sd)
	}

	// Drop the first value, then build the same remaining set from scratch.
	s1.dequeue(4)
	for _, x := range [...]float64{7, 13, 16} {
		s2.enqueue(x)
	}

	if s1 == s2 {
		return
	}
	t.Errorf("got %v, want %v", s1, s2)
}

View File

@@ -1,17 +1,19 @@
// Package unicode provides an alternative to the SQLite ICU extension.
//
// Provides Unicode aware:
// - upper and lower functions,
// Like the [ICU extension], it provides Unicode aware:
// - upper() and lower() functions,
// - LIKE and REGEXP operators,
// - collation sequences.
//
// This package is not 100% compatible with the ICU extension:
// - upper and lower use [strings.ToUpper], [strings.ToLower] and [cases];
// The implementation is not 100% compatible with the [ICU extension]:
// - upper() and lower() use [strings.ToUpper], [strings.ToLower] and [cases];
// - the LIKE operator follows [strings.EqualFold] rules;
// - the REGEXP operator uses Go [regex/syntax];
// - collation sequences use [collate].
//
// Expect subtle differences (e.g.) in the handling of Turkish case folding.
//
// [ICU extension]: https://sqlite.org/src/dir/ext/icu
package unicode
import (
@@ -45,7 +47,7 @@ func Register(db *sqlite3.Conn) {
return
}
err := RegisterCollation(db, name, arg[0].Text())
err := RegisterCollation(db, arg[0].Text(), name)
if err != nil {
ctx.ResultError(err)
return
@@ -53,8 +55,9 @@ func Register(db *sqlite3.Conn) {
})
}
func RegisterCollation(db *sqlite3.Conn, name, lang string) error {
tag, err := language.Parse(lang)
// RegisterCollation registers a Unicode collation sequence for a database connection.
func RegisterCollation(db *sqlite3.Conn, locale, name string) error {
tag, err := language.Parse(locale)
if err != nil {
return err
}

View File

@@ -26,7 +26,7 @@ func ExampleConn_CreateWindowFunction() {
log.Fatal(err)
}
err = db.CreateWindowFunction("count_ascii", 1, sqlite3.INNOCUOUS, newASCIICounter)
err = db.CreateWindowFunction("count_ascii", 1, sqlite3.DETERMINISTIC|sqlite3.INNOCUOUS, newASCIICounter)
if err != nil {
log.Fatal(err)
}