mirror of
https://github.com/ncruces/go-sqlite3.git
synced 2026-01-11 21:49:13 +00:00
Fix numerical issues.
This commit is contained in:
38
ext/stats/moments.nogo
Normal file
38
ext/stats/moments.nogo
Normal file
@@ -0,0 +1,38 @@
|
||||
package stats
|
||||
|
||||
import "math"
|
||||
|
||||
type moment struct {
|
||||
m1, m2, m3, m4 kahan
|
||||
n int64
|
||||
}
|
||||
|
||||
func (w *moment) enqueue(x float64) {
|
||||
n := w.n + 1
|
||||
w.n = n
|
||||
y := x - w.m1.hi - w.m1.lo
|
||||
w.m1.add(y / float64(n))
|
||||
y = math.FMA(y, x, -w.m2.hi) - w.m2.lo
|
||||
w.m2.add(y / float64(n))
|
||||
y = math.FMA(y, x, -w.m3.hi) - w.m3.lo
|
||||
w.m3.add(y / float64(n))
|
||||
y = math.FMA(y, x, -w.m4.hi) - w.m4.lo
|
||||
w.m4.add(y / float64(n))
|
||||
}
|
||||
|
||||
func (w *moment) dequeue(x float64) {
|
||||
n := w.n - 1
|
||||
if n <= 0 {
|
||||
*w = moment{}
|
||||
return
|
||||
}
|
||||
w.n = n
|
||||
y := x - w.m1.hi + w.m1.lo
|
||||
w.m1.sub(y / float64(n))
|
||||
y = math.FMA(y, x, w.m2.hi) + w.m2.lo
|
||||
w.m2.sub(y / float64(n))
|
||||
y = math.FMA(y, x, w.m3.hi) + w.m3.lo
|
||||
w.m3.sub(y / float64(n))
|
||||
y = math.FMA(y, x, w.m4.hi) + w.m4.lo
|
||||
w.m4.sub(y / float64(n))
|
||||
}
|
||||
@@ -11,6 +11,9 @@ import (
|
||||
"github.com/ncruces/sort/quick"
|
||||
)
|
||||
|
||||
// Compatible with:
|
||||
// https://sqlite.org/src/file/ext/misc/percentile.c
|
||||
|
||||
const (
|
||||
median = iota
|
||||
percentile_100
|
||||
|
||||
@@ -17,7 +17,7 @@
|
||||
// - regr_count: count non-null pairs of variables
|
||||
// - regr_slope: slope of the least-squares-fit linear equation
|
||||
// - regr_intercept: y-intercept of the least-squares-fit linear equation
|
||||
// - regr_json: all regr stats as a JSON object
|
||||
// - percentile_disc: discrete quantile
|
||||
// - percentile_cont: continuous quantile
|
||||
// - percentile: continuous percentile
|
||||
@@ -111,6 +111,17 @@ type variance struct {
|
||||
}
|
||||
|
||||
func (fn *variance) Value(ctx sqlite3.Context) {
|
||||
switch fn.n {
|
||||
case 1:
|
||||
switch fn.kind {
|
||||
case var_pop, stddev_pop:
|
||||
ctx.ResultFloat(0)
|
||||
}
|
||||
return
|
||||
case 0:
|
||||
return
|
||||
}
|
||||
|
||||
var r float64
|
||||
switch fn.kind {
|
||||
case var_pop:
|
||||
@@ -151,6 +162,25 @@ type covariance struct {
|
||||
}
|
||||
|
||||
func (fn *covariance) Value(ctx sqlite3.Context) {
|
||||
if fn.kind == regr_count {
|
||||
ctx.ResultInt64(fn.regr_count())
|
||||
return
|
||||
}
|
||||
switch fn.n {
|
||||
case 1:
|
||||
switch fn.kind {
|
||||
case var_pop, stddev_pop, regr_sxx, regr_syy, regr_sxy:
|
||||
ctx.ResultFloat(0)
|
||||
return
|
||||
case regr_avgx, regr_avgy:
|
||||
break
|
||||
default:
|
||||
return
|
||||
}
|
||||
case 0:
|
||||
return
|
||||
}
|
||||
|
||||
var r float64
|
||||
switch fn.kind {
|
||||
case var_pop:
|
||||
@@ -175,11 +205,9 @@ func (fn *covariance) Value(ctx sqlite3.Context) {
|
||||
r = fn.regr_slope()
|
||||
case regr_intercept:
|
||||
r = fn.regr_intercept()
|
||||
case regr_count:
|
||||
ctx.ResultInt64(fn.regr_count())
|
||||
return
|
||||
case regr_json:
|
||||
ctx.ResultText(fn.regr_json())
|
||||
var buf [128]byte
|
||||
ctx.ResultRawText(fn.regr_json(buf[:0]))
|
||||
return
|
||||
}
|
||||
ctx.ResultFloat(r)
|
||||
|
||||
@@ -29,12 +29,23 @@ func TestRegister_variance(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
stmt, _, err := db.Prepare(`SELECT stddev_pop(x) FROM data`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if stmt.Step() {
|
||||
if got := stmt.ColumnType(0); got != sqlite3.NULL {
|
||||
t.Errorf("got %v, want NULL", got)
|
||||
}
|
||||
}
|
||||
stmt.Close()
|
||||
|
||||
err = db.Exec(`INSERT INTO data (x) VALUES (4), (7.0), ('13'), (NULL), (16)`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
stmt, _, err := db.Prepare(`
|
||||
stmt, _, err = db.Prepare(`
|
||||
SELECT
|
||||
sum(x), avg(x),
|
||||
var_samp(x), var_pop(x),
|
||||
@@ -65,7 +76,11 @@ func TestRegister_variance(t *testing.T) {
|
||||
}
|
||||
stmt.Close()
|
||||
|
||||
stmt, _, err = db.Prepare(`SELECT var_samp(x) OVER (ROWS 1 PRECEDING) FROM data`)
|
||||
stmt, _, err = db.Prepare(`
|
||||
SELECT
|
||||
var_samp(x) OVER (ROWS 1 PRECEDING),
|
||||
var_pop(x) OVER (ROWS 1 PRECEDING)
|
||||
FROM data`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -96,12 +111,26 @@ func TestRegister_covariance(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
stmt, _, err := db.Prepare(`SELECT regr_count(y, x), regr_json(y, x) FROM data`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if stmt.Step() {
|
||||
if got := stmt.ColumnInt(0); got != 0 {
|
||||
t.Errorf("got %v, want 0", got)
|
||||
}
|
||||
if got := stmt.ColumnType(1); got != sqlite3.NULL {
|
||||
t.Errorf("got %v, want NULL", got)
|
||||
}
|
||||
}
|
||||
stmt.Close()
|
||||
|
||||
err = db.Exec(`INSERT INTO data (y, x) VALUES (3, 70), (5, 80), (2, 60), (7, 90), (4, 75)`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
stmt, _, err := db.Prepare(`SELECT
|
||||
stmt, _, err = db.Prepare(`SELECT
|
||||
corr(y, x), covar_samp(y, x), covar_pop(y, x),
|
||||
regr_avgy(y, x), regr_avgx(y, x),
|
||||
regr_syy(y, x), regr_sxx(y, x), regr_sxy(y, x),
|
||||
@@ -157,7 +186,12 @@ func TestRegister_covariance(t *testing.T) {
|
||||
}
|
||||
stmt.Close()
|
||||
|
||||
stmt, _, err = db.Prepare(`SELECT covar_samp(y, x) OVER (ROWS 1 PRECEDING) FROM data`)
|
||||
stmt, _, err = db.Prepare(`
|
||||
SELECT
|
||||
covar_samp(y, x) OVER (ROWS 1 PRECEDING),
|
||||
covar_pop(y, x) OVER (ROWS 1 PRECEDING),
|
||||
regr_avgx(y, x) OVER (ROWS 1 PRECEDING)
|
||||
FROM data`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -171,6 +205,9 @@ func TestRegister_covariance(t *testing.T) {
|
||||
t.Errorf("got %v, want %v", got, want[i])
|
||||
}
|
||||
}
|
||||
if stmt.Err() != nil {
|
||||
t.Fatal(stmt.Err())
|
||||
}
|
||||
stmt.Close()
|
||||
}
|
||||
|
||||
|
||||
@@ -3,16 +3,15 @@ package stats
|
||||
import (
|
||||
"math"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/ncruces/go-sqlite3/internal/util"
|
||||
)
|
||||
|
||||
// Welford's algorithm with Kahan summation:
|
||||
// The effect of truncation in statistical computation [van Reeken, AJ 1970]
|
||||
// https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_online_algorithm
|
||||
// https://en.wikipedia.org/wiki/Kahan_summation_algorithm
|
||||
|
||||
// See also:
|
||||
// https://duckdb.org/docs/sql/aggregates.html#statistical-aggregates
|
||||
|
||||
type welford struct {
|
||||
m1, m2 kahan
|
||||
n int64
|
||||
@@ -39,17 +38,23 @@ func (w welford) stddev_samp() float64 {
|
||||
}
|
||||
|
||||
func (w *welford) enqueue(x float64) {
|
||||
w.n++
|
||||
n := w.n + 1
|
||||
w.n = n
|
||||
d1 := x - w.m1.hi - w.m1.lo
|
||||
w.m1.add(d1 / float64(w.n))
|
||||
w.m1.add(d1 / float64(n))
|
||||
d2 := x - w.m1.hi - w.m1.lo
|
||||
w.m2.add(d1 * d2)
|
||||
}
|
||||
|
||||
func (w *welford) dequeue(x float64) {
|
||||
w.n--
|
||||
n := w.n - 1
|
||||
if n <= 0 {
|
||||
*w = welford{}
|
||||
return
|
||||
}
|
||||
w.n = n
|
||||
d1 := x - w.m1.hi - w.m1.lo
|
||||
w.m1.sub(d1 / float64(w.n))
|
||||
w.m1.sub(d1 / float64(n))
|
||||
d2 := x - w.m1.hi - w.m1.lo
|
||||
w.m2.sub(d1 * d2)
|
||||
}
|
||||
@@ -112,38 +117,35 @@ func (w welford2) regr_r2() float64 {
|
||||
return w.cov.hi * w.cov.hi / (w.m2y.hi * w.m2x.hi)
|
||||
}
|
||||
|
||||
func (w welford2) regr_json() string {
|
||||
var json strings.Builder
|
||||
var num [32]byte
|
||||
json.Grow(128)
|
||||
json.WriteString(`{"count":`)
|
||||
json.Write(strconv.AppendInt(num[:0], w.regr_count(), 10))
|
||||
json.WriteString(`,"avgy":`)
|
||||
json.Write(strconv.AppendFloat(num[:0], w.regr_avgy(), 'g', -1, 64))
|
||||
json.WriteString(`,"avgx":`)
|
||||
json.Write(strconv.AppendFloat(num[:0], w.regr_avgx(), 'g', -1, 64))
|
||||
json.WriteString(`,"syy":`)
|
||||
json.Write(strconv.AppendFloat(num[:0], w.regr_syy(), 'g', -1, 64))
|
||||
json.WriteString(`,"sxx":`)
|
||||
json.Write(strconv.AppendFloat(num[:0], w.regr_sxx(), 'g', -1, 64))
|
||||
json.WriteString(`,"sxy":`)
|
||||
json.Write(strconv.AppendFloat(num[:0], w.regr_sxy(), 'g', -1, 64))
|
||||
json.WriteString(`,"slope":`)
|
||||
json.Write(strconv.AppendFloat(num[:0], w.regr_slope(), 'g', -1, 64))
|
||||
json.WriteString(`,"intercept":`)
|
||||
json.Write(strconv.AppendFloat(num[:0], w.regr_intercept(), 'g', -1, 64))
|
||||
json.WriteString(`,"r2":`)
|
||||
json.Write(strconv.AppendFloat(num[:0], w.regr_r2(), 'g', -1, 64))
|
||||
json.WriteByte('}')
|
||||
return json.String()
|
||||
func (w welford2) regr_json(dst []byte) []byte {
|
||||
dst = append(dst, `{"count":`...)
|
||||
dst = strconv.AppendInt(dst, w.regr_count(), 10)
|
||||
dst = append(dst, `,"avgy":`...)
|
||||
dst = util.AppendNumber(dst, w.regr_avgy())
|
||||
dst = append(dst, `,"avgx":`...)
|
||||
dst = util.AppendNumber(dst, w.regr_avgx())
|
||||
dst = append(dst, `,"syy":`...)
|
||||
dst = util.AppendNumber(dst, w.regr_syy())
|
||||
dst = append(dst, `,"sxx":`...)
|
||||
dst = util.AppendNumber(dst, w.regr_sxx())
|
||||
dst = append(dst, `,"sxy":`...)
|
||||
dst = util.AppendNumber(dst, w.regr_sxy())
|
||||
dst = append(dst, `,"slope":`...)
|
||||
dst = util.AppendNumber(dst, w.regr_slope())
|
||||
dst = append(dst, `,"intercept":`...)
|
||||
dst = util.AppendNumber(dst, w.regr_intercept())
|
||||
dst = append(dst, `,"r2":`...)
|
||||
dst = util.AppendNumber(dst, w.regr_r2())
|
||||
return append(dst, '}')
|
||||
}
|
||||
|
||||
func (w *welford2) enqueue(y, x float64) {
|
||||
w.n++
|
||||
n := w.n + 1
|
||||
w.n = n
|
||||
d1y := y - w.m1y.hi - w.m1y.lo
|
||||
d1x := x - w.m1x.hi - w.m1x.lo
|
||||
w.m1y.add(d1y / float64(w.n))
|
||||
w.m1x.add(d1x / float64(w.n))
|
||||
w.m1y.add(d1y / float64(n))
|
||||
w.m1x.add(d1x / float64(n))
|
||||
d2y := y - w.m1y.hi - w.m1y.lo
|
||||
d2x := x - w.m1x.hi - w.m1x.lo
|
||||
w.m2y.add(d1y * d2y)
|
||||
@@ -152,11 +154,16 @@ func (w *welford2) enqueue(y, x float64) {
|
||||
}
|
||||
|
||||
func (w *welford2) dequeue(y, x float64) {
|
||||
w.n--
|
||||
n := w.n - 1
|
||||
if n <= 0 {
|
||||
*w = welford2{}
|
||||
return
|
||||
}
|
||||
w.n = n
|
||||
d1y := y - w.m1y.hi - w.m1y.lo
|
||||
d1x := x - w.m1x.hi - w.m1x.lo
|
||||
w.m1y.sub(d1y / float64(w.n))
|
||||
w.m1x.sub(d1x / float64(w.n))
|
||||
w.m1y.sub(d1y / float64(n))
|
||||
w.m1x.sub(d1x / float64(n))
|
||||
d2y := y - w.m1y.hi - w.m1y.lo
|
||||
d2x := x - w.m1x.hi - w.m1x.lo
|
||||
w.m2y.sub(d1y * d2y)
|
||||
|
||||
@@ -37,6 +37,16 @@ func Test_welford(t *testing.T) {
|
||||
if s1.var_pop() != s2.var_pop() {
|
||||
t.Errorf("got %v, want %v", s1, s2)
|
||||
}
|
||||
|
||||
s1.dequeue(16)
|
||||
s1.dequeue(7)
|
||||
s1.dequeue(13)
|
||||
s1.enqueue(16)
|
||||
s1.enqueue(7)
|
||||
s1.enqueue(13)
|
||||
if s1.var_pop() != s2.var_pop() {
|
||||
t.Errorf("got %v, want %v", s1, s2)
|
||||
}
|
||||
}
|
||||
|
||||
func Test_covar(t *testing.T) {
|
||||
@@ -65,6 +75,18 @@ func Test_covar(t *testing.T) {
|
||||
if c1.covar_pop() != c2.covar_pop() {
|
||||
t.Errorf("got %v, want %v", c1.covar_pop(), c2.covar_pop())
|
||||
}
|
||||
|
||||
c1.dequeue(2, 60)
|
||||
c1.dequeue(5, 80)
|
||||
c1.dequeue(4, 75)
|
||||
c1.dequeue(7, 90)
|
||||
c1.enqueue(2, 60)
|
||||
c1.enqueue(5, 80)
|
||||
c1.enqueue(4, 75)
|
||||
c1.enqueue(7, 90)
|
||||
if c1.covar_pop() != c2.covar_pop() {
|
||||
t.Errorf("got %v, want %v", c1.covar_pop(), c2.covar_pop())
|
||||
}
|
||||
}
|
||||
|
||||
func Test_correlation(t *testing.T) {
|
||||
|
||||
@@ -2,6 +2,7 @@ package util
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"math"
|
||||
"strconv"
|
||||
"time"
|
||||
"unsafe"
|
||||
@@ -20,7 +21,7 @@ func (j JSON) Scan(value any) error {
|
||||
case int64:
|
||||
buf = strconv.AppendInt(nil, v, 10)
|
||||
case float64:
|
||||
buf = strconv.AppendFloat(nil, v, 'g', -1, 64)
|
||||
buf = AppendNumber(nil, v)
|
||||
case time.Time:
|
||||
buf = append(buf, '"')
|
||||
buf = v.AppendFormat(buf, time.RFC3339Nano)
|
||||
@@ -33,3 +34,17 @@ func (j JSON) Scan(value any) error {
|
||||
|
||||
return json.Unmarshal(buf, j.Value)
|
||||
}
|
||||
|
||||
// AppendNumber appends a textual form of f to dst and returns the extended
// slice.
//
// Values that strconv would render as NaN, +Inf or -Inf are replaced:
// NaN becomes null, and the infinities become 9.0e999 and -9.0e999,
// numeric literals whose magnitude exceeds the float64 range. Every other
// value is formatted with strconv.AppendFloat using the 'g' format and the
// shortest precision that round-trips a float64.
func AppendNumber(dst []byte, f float64) []byte {
	switch {
	case math.IsNaN(f):
		return append(dst, "null"...)
	case math.IsInf(f, 1):
		return append(dst, "9.0e999"...)
	case math.IsInf(f, -1):
		return append(dst, "-9.0e999"...)
	}
	return strconv.AppendFloat(dst, f, 'g', -1, 64)
}
|
||||
|
||||
2
stmt.go
2
stmt.go
@@ -609,7 +609,7 @@ func (s *Stmt) ColumnJSON(col int, ptr any) error {
|
||||
case INTEGER:
|
||||
data = strconv.AppendInt(nil, s.ColumnInt64(col), 10)
|
||||
case FLOAT:
|
||||
data = strconv.AppendFloat(nil, s.ColumnFloat(col), 'g', -1, 64)
|
||||
data = util.AppendNumber(nil, s.ColumnFloat(col))
|
||||
default:
|
||||
panic(util.AssertErr())
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user