More stats.

This commit is contained in:
Nuno Cruces
2024-01-10 11:07:16 +00:00
parent af42af2978
commit ee48dd5c96
4 changed files with 91 additions and 29 deletions

View File

@@ -2,7 +2,7 @@
https://www.oreilly.com/library/view/sql-in-a/9780596155322/ch04s02.html
## Built in
## Built in aggregates
- [x] `COUNT(*)`
- [x] `COUNT(expression)`
@@ -13,7 +13,7 @@ https://www.oreilly.com/library/view/sql-in-a/9780596155322/ch04s02.html
https://sqlite.org/lang_aggfunc.html
## Implemented
## Statistical aggregates
- [x] `STDDEV_POP(expression)`
- [x] `STDDEV_SAMP(expression)`
@@ -27,15 +27,15 @@ https://sqlite.org/lang_aggfunc.html
- [X] `REGR_AVGX(dependent, independent)`
- [X] `REGR_AVGY(dependent, independent)`
- [ ] `REGR_COUNT(dependent, independent)`
- [X] `REGR_SXX(dependent, independent)`
- [X] `REGR_SYY(dependent, independent)`
- [X] `REGR_SXY(dependent, independent)`
- [X] `REGR_COUNT(dependent, independent)`
- [X] `REGR_SLOPE(dependent, independent)`
- [X] `REGR_INTERCEPT(dependent, independent)`
- [X] `REGR_R2(dependent, independent)`
- [X] `REGR_SLOPE(dependent, independent)`
- [ ] `REGR_SXX(dependent, independent)`
- [ ] `REGR_SXY(dependent, independent)`
- [ ] `REGR_SYY(dependent, independent)`
## Other
## Ordered set aggregates
- [ ] `CUME_DIST(value_list) WITHIN GROUP (ORDER BY sort_list)`
- [ ] `RANK(value_list) WITHIN GROUP (ORDER BY sort_list)`

View File

@@ -8,6 +8,15 @@
// - covar_pop: population covariance
// - covar_samp: sample covariance
// - corr: correlation coefficient
// - regr_r2: correlation coefficient squared
// - regr_avgx: average of the independent variable
// - regr_avgy: average of the dependent variable
// - regr_sxx: sum of the squares of the independent variable
// - regr_syy: sum of the squares of the dependent variable
// - regr_sxy: sum of the products of each pair of variables
// - regr_count: count non-null pairs of variables
// - regr_slope: slope of the least-squares-fit linear equation
// - regr_intercept: y-intercept of the least-squares-fit linear equation
//
// These join the [Built-in Aggregate Functions]:
// - count: count rows/values
@@ -34,11 +43,15 @@ func Register(db *sqlite3.Conn) {
db.CreateWindowFunction("covar_pop", 2, flags, newCovariance(var_pop))
db.CreateWindowFunction("covar_samp", 2, flags, newCovariance(var_samp))
db.CreateWindowFunction("corr", 2, flags, newCovariance(corr))
db.CreateWindowFunction("regr_r2", 2, flags, newCovariance(regr_r2))
db.CreateWindowFunction("regr_sxx", 2, flags, newCovariance(regr_sxx))
db.CreateWindowFunction("regr_syy", 2, flags, newCovariance(regr_syy))
db.CreateWindowFunction("regr_sxy", 2, flags, newCovariance(regr_sxy))
db.CreateWindowFunction("regr_avgx", 2, flags, newCovariance(regr_avgx))
db.CreateWindowFunction("regr_avgy", 2, flags, newCovariance(regr_avgy))
db.CreateWindowFunction("regr_r2", 2, flags, newCovariance(regr_r2))
db.CreateWindowFunction("regr_slope", 2, flags, newCovariance(regr_slope))
db.CreateWindowFunction("regr_intercept", 2, flags, newCovariance(regr_intercept))
db.CreateWindowFunction("regr_count", 2, flags, newCovariance(regr_count))
}
const (
@@ -47,11 +60,15 @@ const (
stddev_pop
stddev_samp
corr
regr_r2
regr_sxx
regr_syy
regr_sxy
regr_avgx
regr_avgy
regr_r2
regr_slope
regr_intercept
regr_count
)
func newVariance(kind int) func() sqlite3.AggregateFunction {
@@ -108,16 +125,25 @@ func (fn *covariance) Value(ctx sqlite3.Context) {
r = fn.covar_samp()
case corr:
r = fn.correlation()
case regr_r2:
r = fn.regr_r2()
case regr_sxx:
r = fn.regr_sxx()
case regr_syy:
r = fn.regr_syy()
case regr_sxy:
r = fn.regr_sxy()
case regr_avgx:
r = fn.regr_avgx()
case regr_avgy:
r = fn.regr_avgy()
case regr_r2:
r = fn.regr_r2()
case regr_slope:
r = fn.regr_slope()
case regr_intercept:
r = fn.regr_intercept()
case regr_count:
ctx.ResultInt64(fn.regr_count())
return
}
ctx.ResultFloat(r)
}

View File

@@ -92,20 +92,22 @@ func TestRegister_covariance(t *testing.T) {
stats.Register(db)
err = db.Exec(`CREATE TABLE IF NOT EXISTS data (x, y)`)
err = db.Exec(`CREATE TABLE IF NOT EXISTS data (y, x)`)
if err != nil {
t.Fatal(err)
}
err = db.Exec(`INSERT INTO data (x, y) VALUES (3, 70), (5, 80), (2, 60), (7, 90), (4, 75)`)
err = db.Exec(`INSERT INTO data (y, x) VALUES (3, 70), (5, 80), (2, 60), (7, 90), (4, 75)`)
if err != nil {
t.Fatal(err)
}
stmt, _, err := db.Prepare(`SELECT
corr(x, y), covar_samp(x, y), covar_pop(x, y),
regr_avgx(x, y), regr_avgy(x, y), regr_r2(x, y),
regr_slope(x, y), regr_intercept(x, y)
corr(y, x), covar_samp(y, x), covar_pop(y, x),
regr_avgy(y, x), regr_avgx(y, x),
regr_syy(y, x), regr_sxx(y, x), regr_sxy(y, x),
regr_slope(y, x), regr_intercept(y, x), regr_r2(y, x),
regr_count(y, x)
FROM data`)
if err != nil {
t.Fatal(err)
@@ -122,25 +124,37 @@ func TestRegister_covariance(t *testing.T) {
if got := stmt.ColumnFloat(2); got != 17 {
t.Errorf("got %v, want 17", got)
}
if got := stmt.ColumnFloat(3); got != 75 {
t.Errorf("got %v, want 75", got)
}
if got := stmt.ColumnFloat(4); got != 4.2 {
if got := stmt.ColumnFloat(3); got != 4.2 {
t.Errorf("got %v, want 4.2", got)
}
if got := stmt.ColumnFloat(5); got != 0.9763513513513513 {
t.Errorf("got %v, want 0.9763513513513513", got)
if got := stmt.ColumnFloat(4); got != 75 {
t.Errorf("got %v, want 75", got)
}
if got := stmt.ColumnFloat(6); got != 0.17 {
if got := stmt.ColumnFloat(5); got != 14.8 {
t.Errorf("got %v, want 14.8", got)
}
if got := stmt.ColumnFloat(6); got != 500 {
t.Errorf("got %v, want 500", got)
}
if got := stmt.ColumnFloat(7); got != 85 {
t.Errorf("got %v, want 85", got)
}
if got := stmt.ColumnFloat(8); got != 0.17 {
t.Errorf("got %v, want 0.17", got)
}
if got := stmt.ColumnFloat(7); got != -8.55 {
if got := stmt.ColumnFloat(9); got != -8.55 {
t.Errorf("got %v, want -8.55", got)
}
if got := stmt.ColumnFloat(10); got != 0.9763513513513513 {
t.Errorf("got %v, want 0.9763513513513513", got)
}
if got := stmt.ColumnInt(11); got != 5 {
t.Errorf("got %v, want 5", got)
}
}
{
stmt, _, err := db.Prepare(`SELECT covar_samp(x, y) OVER (ROWS 1 PRECEDING) FROM data`)
stmt, _, err := db.Prepare(`SELECT covar_samp(y, x) OVER (ROWS 1 PRECEDING) FROM data`)
if err != nil {
t.Fatal(err)
}

View File

@@ -6,9 +6,12 @@ import "math"
// https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_online_algorithm
// https://en.wikipedia.org/wiki/Kahan_summation_algorithm
// See also:
// https://duckdb.org/docs/sql/aggregates.html#statistical-aggregates
type welford struct {
m1, m2 kahan
n uint64
n int64
}
func (w welford) average() float64 {
@@ -51,7 +54,7 @@ type welford2 struct {
m1y, m2y kahan
m1x, m2x kahan
cov kahan
n uint64
n int64
}
func (w welford2) covar_pop() float64 {
@@ -74,12 +77,31 @@ func (w welford2) regr_avgx() float64 {
return w.m1x.hi
}
func (w welford2) regr_syy() float64 {
return w.m2y.hi
}
func (w welford2) regr_sxx() float64 {
return w.m2x.hi
}
func (w welford2) regr_sxy() float64 {
return w.cov.hi
}
func (w welford2) regr_count() int64 {
return w.n
}
func (w welford2) regr_slope() float64 {
return w.cov.hi / w.m2x.hi
}
func (w welford2) regr_intercept() float64 {
return w.m1y.hi - w.m1x.hi*w.regr_slope()
slope := -w.regr_slope()
hi := math.FMA(slope, w.m1x.hi, w.m1y.hi)
lo := math.FMA(slope, w.m1x.lo, w.m1y.lo)
return hi + lo
}
func (w welford2) regr_r2() float64 {