From e2da469834d89a68d8478a0052906bfa85ea336e Mon Sep 17 00:00:00 2001 From: Nuno Cruces Date: Mon, 20 Jan 2025 14:39:36 +0000 Subject: [PATCH] Fix numerical issues. --- ext/stats/moments.nogo | 38 ++++++++++++++++++ ext/stats/percentile.go | 3 ++ ext/stats/stats.go | 38 +++++++++++++++--- ext/stats/stats_test.go | 45 +++++++++++++++++++-- ext/stats/welford.go | 83 +++++++++++++++++++++------------------ ext/stats/welford_test.go | 22 +++++++++++ internal/util/json.go | 17 +++++++- stmt.go | 2 +- value.go | 2 +- 9 files changed, 200 insertions(+), 50 deletions(-) create mode 100644 ext/stats/moments.nogo diff --git a/ext/stats/moments.nogo b/ext/stats/moments.nogo new file mode 100644 index 0000000..31e143f --- /dev/null +++ b/ext/stats/moments.nogo @@ -0,0 +1,38 @@ +package stats + +import "math" + +type moment struct { + m1, m2, m3, m4 kahan + n int64 +} + +func (w *moment) enqueue(x float64) { + n := w.n + 1 + w.n = n + y := x - w.m1.hi - w.m1.lo + w.m1.add(y / float64(n)) + y = math.FMA(y, x, -w.m2.hi) - w.m2.lo + w.m2.add(y / float64(n)) + y = math.FMA(y, x, -w.m3.hi) - w.m3.lo + w.m3.add(y / float64(n)) + y = math.FMA(y, x, -w.m4.hi) - w.m4.lo + w.m4.add(y / float64(n)) +} + +func (w *moment) dequeue(x float64) { + n := w.n - 1 + if n <= 0 { + *w = moment{} + return + } + w.n = n + y := x - w.m1.hi + w.m1.lo + w.m1.sub(y / float64(n)) + y = math.FMA(y, x, w.m2.hi) + w.m2.lo + w.m2.sub(y / float64(n)) + y = math.FMA(y, x, w.m3.hi) + w.m3.lo + w.m3.sub(y / float64(n)) + y = math.FMA(y, x, w.m4.hi) + w.m4.lo + w.m4.sub(y / float64(n)) +} diff --git a/ext/stats/percentile.go b/ext/stats/percentile.go index 51075e4..2e63f74 100644 --- a/ext/stats/percentile.go +++ b/ext/stats/percentile.go @@ -11,6 +11,9 @@ import ( "github.com/ncruces/sort/quick" ) +// Compatible with: +// https://sqlite.org/src/file/ext/misc/percentile.c + const ( median = iota percentile_100 diff --git a/ext/stats/stats.go b/ext/stats/stats.go index 2110a52..371a390 100644 --- a/ext/stats/stats.go +++ b/ext/stats/stats.go @@ -17,7 +17,7 @@ // - regr_count: count non-null pairs of variables // - regr_slope: slope of the least-squares-fit linear equation // - regr_intercept: y-intercept of the least-squares-fit linear equation -// - regr_json: all regr stats in a JSON object +// - regr_json: all regr stats as a JSON object // - percentile_disc: discrete quantile // - percentile_cont: continuous quantile // - percentile: continuous percentile @@ -111,6 +111,17 @@ type variance struct { } func (fn *variance) Value(ctx sqlite3.Context) { + switch fn.n { + case 1: + switch fn.kind { + case var_pop, stddev_pop: + ctx.ResultFloat(0) + } + return + case 0: + return + } + var r float64 switch fn.kind { case var_pop: @@ -151,6 +162,25 @@ type covariance struct { } func (fn *covariance) Value(ctx sqlite3.Context) { + if fn.kind == regr_count { + ctx.ResultInt64(fn.regr_count()) + return + } + switch fn.n { + case 1: + switch fn.kind { + case var_pop, stddev_pop, regr_sxx, regr_syy, regr_sxy: + ctx.ResultFloat(0) + return + case regr_avgx, regr_avgy: + break + default: + return + } + case 0: + return + } + var r float64 switch fn.kind { case var_pop: @@ -175,11 +205,9 @@ func (fn *covariance) Value(ctx sqlite3.Context) { r = fn.regr_slope() case regr_intercept: r = fn.regr_intercept() - case regr_count: - ctx.ResultInt64(fn.regr_count()) - return case regr_json: - ctx.ResultText(fn.regr_json()) + var buf [128]byte + ctx.ResultRawText(fn.regr_json(buf[:0])) return } ctx.ResultFloat(r) diff --git a/ext/stats/stats_test.go b/ext/stats/stats_test.go index 5a5426c..fec5a7c 100644 --- a/ext/stats/stats_test.go +++ b/ext/stats/stats_test.go @@ -29,12 +29,23 @@ func TestRegister_variance(t *testing.T) { t.Fatal(err) } + stmt, _, err := db.Prepare(`SELECT stddev_pop(x) FROM data`) + if err != nil { + t.Fatal(err) + } + if stmt.Step() { + if got := stmt.ColumnType(0); got != sqlite3.NULL { + t.Errorf("got %v, want NULL", got) + } + } + stmt.Close() + err = db.Exec(`INSERT INTO data (x) VALUES (4), (7.0), ('13'), (NULL), (16)`) if err != nil { t.Fatal(err) } - stmt, _, err := db.Prepare(` + stmt, _, err = db.Prepare(` SELECT sum(x), avg(x), var_samp(x), var_pop(x), @@ -65,7 +76,11 @@ func TestRegister_variance(t *testing.T) { } stmt.Close() - stmt, _, err = db.Prepare(`SELECT var_samp(x) OVER (ROWS 1 PRECEDING) FROM data`) + stmt, _, err = db.Prepare(` + SELECT + var_samp(x) OVER (ROWS 1 PRECEDING), + var_pop(x) OVER (ROWS 1 PRECEDING) + FROM data`) if err != nil { t.Fatal(err) } @@ -96,12 +111,26 @@ func TestRegister_covariance(t *testing.T) { t.Fatal(err) } + stmt, _, err := db.Prepare(`SELECT regr_count(y, x), regr_json(y, x) FROM data`) + if err != nil { + t.Fatal(err) + } + if stmt.Step() { + if got := stmt.ColumnInt(0); got != 0 { + t.Errorf("got %v, want 0", got) + } + if got := stmt.ColumnType(1); got != sqlite3.NULL { + t.Errorf("got %v, want NULL", got) + } + } + stmt.Close() + err = db.Exec(`INSERT INTO data (y, x) VALUES (3, 70), (5, 80), (2, 60), (7, 90), (4, 75)`) if err != nil { t.Fatal(err) } - stmt, _, err := db.Prepare(`SELECT + stmt, _, err = db.Prepare(`SELECT corr(y, x), covar_samp(y, x), covar_pop(y, x), regr_avgy(y, x), regr_avgx(y, x), regr_syy(y, x), regr_sxx(y, x), regr_sxy(y, x), @@ -157,7 +186,12 @@ func TestRegister_covariance(t *testing.T) { } stmt.Close() - stmt, _, err = db.Prepare(`SELECT covar_samp(y, x) OVER (ROWS 1 PRECEDING) FROM data`) + stmt, _, err = db.Prepare(` + SELECT + covar_samp(y, x) OVER (ROWS 1 PRECEDING), + covar_pop(y, x) OVER (ROWS 1 PRECEDING), + regr_avgx(y, x) OVER (ROWS 1 PRECEDING) + FROM data`) if err != nil { t.Fatal(err) } @@ -171,6 +205,9 @@ func TestRegister_covariance(t *testing.T) { t.Errorf("got %v, want %v", got, want[i]) } } + if stmt.Err() != nil { + t.Fatal(stmt.Err()) + } stmt.Close() } diff --git a/ext/stats/welford.go b/ext/stats/welford.go index d2d74ad..0469f6f 100644 --- a/ext/stats/welford.go +++ b/ext/stats/welford.go @@ -3,16 +3,15 @@ package stats import ( "math" "strconv" - "strings" + + "github.com/ncruces/go-sqlite3/internal/util" ) // Welford's algorithm with Kahan summation: +// The effect of truncation in statistical computation [van Reeken, AJ 1970] // https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_online_algorithm // https://en.wikipedia.org/wiki/Kahan_summation_algorithm -// See also: -// https://duckdb.org/docs/sql/aggregates.html#statistical-aggregates - type welford struct { m1, m2 kahan n int64 @@ -39,17 +38,23 @@ func (w welford) stddev_samp() float64 { } func (w *welford) enqueue(x float64) { - w.n++ + n := w.n + 1 + w.n = n d1 := x - w.m1.hi - w.m1.lo - w.m1.add(d1 / float64(w.n)) + w.m1.add(d1 / float64(n)) d2 := x - w.m1.hi - w.m1.lo w.m2.add(d1 * d2) } func (w *welford) dequeue(x float64) { - w.n-- + n := w.n - 1 + if n <= 0 { + *w = welford{} + return + } + w.n = n d1 := x - w.m1.hi - w.m1.lo - w.m1.sub(d1 / float64(w.n)) + w.m1.sub(d1 / float64(n)) d2 := x - w.m1.hi - w.m1.lo w.m2.sub(d1 * d2) } @@ -112,38 +117,35 @@ func (w welford2) regr_r2() float64 { return w.cov.hi * w.cov.hi / (w.m2y.hi * w.m2x.hi) } -func (w welford2) regr_json() string { - var json strings.Builder - var num [32]byte - json.Grow(128) - json.WriteString(`{"count":`) - json.Write(strconv.AppendInt(num[:0], w.regr_count(), 10)) - json.WriteString(`,"avgy":`) - json.Write(strconv.AppendFloat(num[:0], w.regr_avgy(), 'g', -1, 64)) - json.WriteString(`,"avgx":`) - json.Write(strconv.AppendFloat(num[:0], w.regr_avgx(), 'g', -1, 64)) - json.WriteString(`,"syy":`) - json.Write(strconv.AppendFloat(num[:0], w.regr_syy(), 'g', -1, 64)) - json.WriteString(`,"sxx":`) - json.Write(strconv.AppendFloat(num[:0], w.regr_sxx(), 'g', -1, 64)) - json.WriteString(`,"sxy":`) - json.Write(strconv.AppendFloat(num[:0], w.regr_sxy(), 'g', -1, 64)) - json.WriteString(`,"slope":`) - json.Write(strconv.AppendFloat(num[:0], w.regr_slope(), 'g', -1, 64)) - json.WriteString(`,"intercept":`) - json.Write(strconv.AppendFloat(num[:0], w.regr_intercept(), 'g', -1, 64)) - json.WriteString(`,"r2":`) - json.Write(strconv.AppendFloat(num[:0], w.regr_r2(), 'g', -1, 64)) - json.WriteByte('}') - return json.String() +func (w welford2) regr_json(dst []byte) []byte { + dst = append(dst, `{"count":`...) + dst = strconv.AppendInt(dst, w.regr_count(), 10) + dst = append(dst, `,"avgy":`...) + dst = util.AppendNumber(dst, w.regr_avgy()) + dst = append(dst, `,"avgx":`...) + dst = util.AppendNumber(dst, w.regr_avgx()) + dst = append(dst, `,"syy":`...) + dst = util.AppendNumber(dst, w.regr_syy()) + dst = append(dst, `,"sxx":`...) + dst = util.AppendNumber(dst, w.regr_sxx()) + dst = append(dst, `,"sxy":`...) + dst = util.AppendNumber(dst, w.regr_sxy()) + dst = append(dst, `,"slope":`...) + dst = util.AppendNumber(dst, w.regr_slope()) + dst = append(dst, `,"intercept":`...) + dst = util.AppendNumber(dst, w.regr_intercept()) + dst = append(dst, `,"r2":`...) + dst = util.AppendNumber(dst, w.regr_r2()) + return append(dst, '}') } func (w *welford2) enqueue(y, x float64) { - w.n++ + n := w.n + 1 + w.n = n d1y := y - w.m1y.hi - w.m1y.lo d1x := x - w.m1x.hi - w.m1x.lo - w.m1y.add(d1y / float64(w.n)) - w.m1x.add(d1x / float64(w.n)) + w.m1y.add(d1y / float64(n)) + w.m1x.add(d1x / float64(n)) d2y := y - w.m1y.hi - w.m1y.lo d2x := x - w.m1x.hi - w.m1x.lo w.m2y.add(d1y * d2y) @@ -152,11 +154,16 @@ func (w *welford2) enqueue(y, x float64) { } func (w *welford2) dequeue(y, x float64) { - w.n-- + n := w.n - 1 + if n <= 0 { + *w = welford2{} + return + } + w.n = n d1y := y - w.m1y.hi - w.m1y.lo d1x := x - w.m1x.hi - w.m1x.lo - w.m1y.sub(d1y / float64(w.n)) - w.m1x.sub(d1x / float64(w.n)) + w.m1y.sub(d1y / float64(n)) + w.m1x.sub(d1x / float64(n)) d2y := y - w.m1y.hi - w.m1y.lo d2x := x - w.m1x.hi - w.m1x.lo w.m2y.sub(d1y * d2y) diff --git a/ext/stats/welford_test.go b/ext/stats/welford_test.go index 93c6336..1d5e5a2 100644 --- a/ext/stats/welford_test.go +++ b/ext/stats/welford_test.go @@ -37,6 +37,16 @@ func Test_welford(t *testing.T) { if s1.var_pop() != s2.var_pop() { t.Errorf("got %v, want %v", s1, s2) } + + s1.dequeue(16) + s1.dequeue(7) + s1.dequeue(13) + s1.enqueue(16) + s1.enqueue(7) + s1.enqueue(13) + if s1.var_pop() != s2.var_pop() { + t.Errorf("got %v, want %v", s1, s2) + } } func Test_covar(t *testing.T) { @@ -65,6 +75,18 @@ func Test_covar(t *testing.T) { if c1.covar_pop() != c2.covar_pop() { t.Errorf("got %v, want %v", c1.covar_pop(), c2.covar_pop()) } + + c1.dequeue(2, 60) + c1.dequeue(5, 80) + c1.dequeue(4, 75) + c1.dequeue(7, 90) + c1.enqueue(2, 60) + c1.enqueue(5, 80) + c1.enqueue(4, 75) + c1.enqueue(7, 90) + if c1.covar_pop() != c2.covar_pop() { + t.Errorf("got %v, want %v", c1.covar_pop(), c2.covar_pop()) + } } func Test_correlation(t *testing.T) { diff --git a/internal/util/json.go b/internal/util/json.go index 7f6849a..8462374 100644 --- a/internal/util/json.go +++ b/internal/util/json.go @@ -2,6 +2,7 @@ package util import ( "encoding/json" + "math" "strconv" "time" "unsafe" @@ -20,7 +21,7 @@ func (j JSON) Scan(value any) error { case int64: buf = strconv.AppendInt(nil, v, 10) case float64: - buf = strconv.AppendFloat(nil, v, 'g', -1, 64) + buf = AppendNumber(nil, v) case time.Time: buf = append(buf, '"') buf = v.AppendFormat(buf, time.RFC3339Nano) @@ -33,3 +34,17 @@ func (j JSON) Scan(value any) error { return json.Unmarshal(buf, j.Value) } + +func AppendNumber(dst []byte, f float64) []byte { + switch { + case math.IsNaN(f): + dst = append(dst, "null"...) + case math.IsInf(f, 1): + dst = append(dst, "9.0e999"...) + case math.IsInf(f, -1): + dst = append(dst, "-9.0e999"...) + default: + return strconv.AppendFloat(dst, f, 'g', -1, 64) + } + return dst +} diff --git a/stmt.go b/stmt.go index 47ef0d2..fdb13dc 100644 --- a/stmt.go +++ b/stmt.go @@ -609,7 +609,7 @@ func (s *Stmt) ColumnJSON(col int, ptr any) error { case INTEGER: data = strconv.AppendInt(nil, s.ColumnInt64(col), 10) case FLOAT: - data = strconv.AppendFloat(nil, s.ColumnFloat(col), 'g', -1, 64) + data = util.AppendNumber(nil, s.ColumnFloat(col)) default: panic(util.AssertErr()) } diff --git a/value.go b/value.go index 86f6689..43b1a0f 100644 --- a/value.go +++ b/value.go @@ -185,7 +185,7 @@ func (v Value) JSON(ptr any) error { case INTEGER: data = strconv.AppendInt(nil, v.Int64(), 10) case FLOAT: - data = strconv.AppendFloat(nil, v.Float(), 'g', -1, 64) + data = util.AppendNumber(nil, v.Float()) default: panic(util.AssertErr()) }