mirror of
https://github.com/ncruces/go-sqlite3.git
synced 2026-01-11 21:49:13 +00:00
Statistics functions.
This commit is contained in:
66
ext/stats/stats.go
Normal file
66
ext/stats/stats.go
Normal file
@@ -0,0 +1,66 @@
|
||||
// Package stats provides aggregate functions for statistics.
|
||||
//
|
||||
// Functions:
|
||||
// - var_samp: sample variance
|
||||
// - var_pop: population variance
|
||||
// - stddev_samp: sample standard deviation
|
||||
// - stddev_pop: population standard deviation
|
||||
//
|
||||
// See: [ANSI SQL Aggregate Functions]
|
||||
//
|
||||
// [ANSI SQL Aggregate Functions]: https://www.oreilly.com/library/view/sql-in-a/9780596155322/ch04s02.html
|
||||
package stats
|
||||
|
||||
import "github.com/ncruces/go-sqlite3"
|
||||
|
||||
// Register registers statistics functions.
|
||||
func Register(db *sqlite3.Conn) {
|
||||
flags := sqlite3.DETERMINISTIC | sqlite3.INNOCUOUS
|
||||
db.CreateWindowFunction("var_pop", 1, flags, create(var_pop))
|
||||
db.CreateWindowFunction("var_samp", 1, flags, create(var_samp))
|
||||
db.CreateWindowFunction("stddev_pop", 1, flags, create(stddev_pop))
|
||||
db.CreateWindowFunction("stddev_samp", 1, flags, create(stddev_samp))
|
||||
}
|
||||
|
||||
const (
|
||||
var_pop = iota
|
||||
var_samp
|
||||
stddev_pop
|
||||
stddev_samp
|
||||
)
|
||||
|
||||
func create(kind int) func() sqlite3.AggregateFunction {
|
||||
return func() sqlite3.AggregateFunction { return &state{kind: kind} }
|
||||
}
|
||||
|
||||
type state struct {
|
||||
kind int
|
||||
welford
|
||||
}
|
||||
|
||||
func (f *state) Value(ctx sqlite3.Context) {
|
||||
var r float64
|
||||
switch f.kind {
|
||||
case var_pop:
|
||||
r = f.var_pop()
|
||||
case var_samp:
|
||||
r = f.var_samp()
|
||||
case stddev_pop:
|
||||
r = f.stddev_pop()
|
||||
case stddev_samp:
|
||||
r = f.stddev_samp()
|
||||
}
|
||||
ctx.ResultFloat(r)
|
||||
}
|
||||
|
||||
func (f *state) Step(ctx sqlite3.Context, arg ...sqlite3.Value) {
|
||||
if a := arg[0]; a.Type() != sqlite3.NULL {
|
||||
f.enqueue(a.Float())
|
||||
}
|
||||
}
|
||||
|
||||
func (f *state) Inverse(ctx sqlite3.Context, arg ...sqlite3.Value) {
|
||||
if a := arg[0]; a.Type() != sqlite3.NULL {
|
||||
f.dequeue(a.Float())
|
||||
}
|
||||
}
|
||||
79
ext/stats/stats_test.go
Normal file
79
ext/stats/stats_test.go
Normal file
@@ -0,0 +1,79 @@
|
||||
package stats
|
||||
|
||||
import (
|
||||
"math"
|
||||
"testing"
|
||||
|
||||
"github.com/ncruces/go-sqlite3"
|
||||
_ "github.com/ncruces/go-sqlite3/embed"
|
||||
)
|
||||
|
||||
func TestRegister(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, err := sqlite3.Open(":memory:")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
Register(db)
|
||||
|
||||
err = db.Exec(`CREATE TABLE IF NOT EXISTS data (col)`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = db.Exec(`INSERT INTO data (col) VALUES (4), (7.0), ('13'), (NULL), (16)`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
stmt, _, err := db.Prepare(`SELECT
|
||||
sum(col), avg(col),
|
||||
var_samp(col), var_pop(col),
|
||||
stddev_samp(col), stddev_pop(col) FROM data`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
if stmt.Step() {
|
||||
if got := stmt.ColumnFloat(0); got != 40 {
|
||||
t.Errorf("got %v, want 40", got)
|
||||
}
|
||||
if got := stmt.ColumnFloat(1); got != 10 {
|
||||
t.Errorf("got %v, want 10", got)
|
||||
}
|
||||
if got := stmt.ColumnFloat(2); got != 30 {
|
||||
t.Errorf("got %v, want 30", got)
|
||||
}
|
||||
if got := stmt.ColumnFloat(3); got != 22.5 {
|
||||
t.Errorf("got %v, want 22.5", got)
|
||||
}
|
||||
if got := stmt.ColumnFloat(4); got != math.Sqrt(30) {
|
||||
t.Errorf("got %v, want √30", got)
|
||||
}
|
||||
if got := stmt.ColumnFloat(5); got != math.Sqrt(22.5) {
|
||||
t.Errorf("got %v, want √22.5", got)
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
stmt, _, err := db.Prepare(`SELECT var_samp(col) OVER (ROWS 1 PRECEDING) FROM data`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
want := [...]float64{0, 4.5, 18, 0, 0}
|
||||
for i := 0; stmt.Step(); i++ {
|
||||
if got := stmt.ColumnFloat(0); got != want[i] {
|
||||
t.Errorf("got %v, want %v", got, want[i])
|
||||
}
|
||||
if got := stmt.ColumnType(0); (got == sqlite3.FLOAT) != (want[i] != 0) {
|
||||
t.Errorf("got %v, want %v", got, want[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
64
ext/stats/welford.go
Normal file
64
ext/stats/welford.go
Normal file
@@ -0,0 +1,64 @@
|
||||
package stats
|
||||
|
||||
import "math"
|
||||
|
||||
// Welford's algorithm with Kahan summation:
|
||||
// https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_online_algorithm
|
||||
// https://en.wikipedia.org/wiki/Kahan_summation_algorithm
|
||||
|
||||
type welford struct {
|
||||
m1, m2 kahan
|
||||
n uint64
|
||||
}
|
||||
|
||||
func (w welford) average() float64 {
|
||||
return w.m1.hi
|
||||
}
|
||||
|
||||
func (w welford) var_pop() float64 {
|
||||
return w.m2.hi / float64(w.n)
|
||||
}
|
||||
|
||||
func (w welford) var_samp() float64 {
|
||||
return w.m2.hi / float64(w.n-1)
|
||||
}
|
||||
|
||||
func (w welford) stddev_pop() float64 {
|
||||
return math.Sqrt(w.var_pop())
|
||||
}
|
||||
|
||||
func (w welford) stddev_samp() float64 {
|
||||
return math.Sqrt(w.var_samp())
|
||||
}
|
||||
|
||||
func (w *welford) enqueue(x float64) {
|
||||
w.n++
|
||||
d1 := x - w.m1.hi
|
||||
w.m1.add(d1 / float64(w.n))
|
||||
d2 := x - w.m1.hi
|
||||
w.m2.add(d1 * d2)
|
||||
}
|
||||
|
||||
func (w *welford) dequeue(x float64) {
|
||||
w.n--
|
||||
d1 := x - w.m1.hi
|
||||
w.m1.sub(d1 / float64(w.n))
|
||||
d2 := x - w.m1.hi
|
||||
w.m2.sub(d1 * d2)
|
||||
}
|
||||
|
||||
type kahan struct{ hi, lo float64 }
|
||||
|
||||
func (k *kahan) add(x float64) {
|
||||
y := k.lo + x
|
||||
t := k.hi + y
|
||||
k.lo = y - (t - k.hi)
|
||||
k.hi = t
|
||||
}
|
||||
|
||||
func (k *kahan) sub(x float64) {
|
||||
y := k.lo - x
|
||||
t := k.hi + y
|
||||
k.lo = y - (t - k.hi)
|
||||
k.hi = t
|
||||
}
|
||||
38
ext/stats/welford_test.go
Normal file
38
ext/stats/welford_test.go
Normal file
@@ -0,0 +1,38 @@
|
||||
package stats
|
||||
|
||||
import (
|
||||
"math"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func Test_welford(t *testing.T) {
|
||||
var s1, s2 welford
|
||||
|
||||
s1.enqueue(4)
|
||||
s1.enqueue(7)
|
||||
s1.enqueue(13)
|
||||
s1.enqueue(16)
|
||||
if got := s1.average(); got != 10 {
|
||||
t.Errorf("got %v, want 10", got)
|
||||
}
|
||||
if got := s1.var_samp(); got != 30 {
|
||||
t.Errorf("got %v, want 30", got)
|
||||
}
|
||||
if got := s1.var_pop(); got != 22.5 {
|
||||
t.Errorf("got %v, want 22.5", got)
|
||||
}
|
||||
if got := s1.stddev_samp(); got != math.Sqrt(30) {
|
||||
t.Errorf("got %v, want √30", got)
|
||||
}
|
||||
if got := s1.stddev_pop(); got != math.Sqrt(22.5) {
|
||||
t.Errorf("got %v, want √22.5", got)
|
||||
}
|
||||
|
||||
s1.dequeue(4)
|
||||
s2.enqueue(7)
|
||||
s2.enqueue(13)
|
||||
s2.enqueue(16)
|
||||
if s1 != s2 {
|
||||
t.Errorf("got %v, want %v", s1, s2)
|
||||
}
|
||||
}
|
||||
@@ -1,17 +1,19 @@
|
||||
// Package unicode provides an alternative to the SQLite ICU extension.
|
||||
//
|
||||
// Provides Unicode aware:
|
||||
// - upper and lower functions,
|
||||
// Like the [ICU extension], it provides Unicode aware:
|
||||
// - upper() and lower() functions,
|
||||
// - LIKE and REGEXP operators,
|
||||
// - collation sequences.
|
||||
//
|
||||
// This package is not 100% compatible with the ICU extension:
|
||||
// - upper and lower use [strings.ToUpper], [strings.ToLower] and [cases];
|
||||
// The implementation is not 100% compatible with the [ICU extension]:
|
||||
// - upper() and lower() use [strings.ToUpper], [strings.ToLower] and [cases];
|
||||
// - the LIKE operator follows [strings.EqualFold] rules;
|
||||
// - the REGEXP operator uses Go [regex/syntax];
|
||||
// - collation sequences use [collate].
|
||||
//
|
||||
// Expect subtle differences (e.g.) in the handling of Turkish case folding.
|
||||
//
|
||||
// [ICU extension]: https://sqlite.org/src/dir/ext/icu
|
||||
package unicode
|
||||
|
||||
import (
|
||||
@@ -45,7 +47,7 @@ func Register(db *sqlite3.Conn) {
|
||||
return
|
||||
}
|
||||
|
||||
err := RegisterCollation(db, name, arg[0].Text())
|
||||
err := RegisterCollation(db, arg[0].Text(), name)
|
||||
if err != nil {
|
||||
ctx.ResultError(err)
|
||||
return
|
||||
@@ -53,8 +55,9 @@ func Register(db *sqlite3.Conn) {
|
||||
})
|
||||
}
|
||||
|
||||
func RegisterCollation(db *sqlite3.Conn, name, lang string) error {
|
||||
tag, err := language.Parse(lang)
|
||||
// RegisterCollation registers a Unicode collation sequence for a database connection.
|
||||
func RegisterCollation(db *sqlite3.Conn, locale, name string) error {
|
||||
tag, err := language.Parse(locale)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -26,7 +26,7 @@ func ExampleConn_CreateWindowFunction() {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
err = db.CreateWindowFunction("count_ascii", 1, sqlite3.INNOCUOUS, newASCIICounter)
|
||||
err = db.CreateWindowFunction("count_ascii", 1, sqlite3.DETERMINISTIC|sqlite3.INNOCUOUS, newASCIICounter)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user