Quantiles.

This commit is contained in:
Nuno Cruces
2024-05-31 17:36:16 +01:00
parent dbf93b2171
commit fa7516ce30
6 changed files with 130 additions and 0 deletions

59
ext/stats/quantile.go Normal file
View File

@@ -0,0 +1,59 @@
package stats
import (
"math"
"slices"
"github.com/ncruces/go-sqlite3"
"github.com/ncruces/go-sqlite3/internal/util"
"github.com/ncruces/sort/quick"
)
// Kinds of aggregate implemented by the quantile type; passed to newQuantile.
const (
// median takes no position argument; Value uses a fixed position of 0.5.
median = iota
// quant_cont interpolates between the two values surrounding the position.
quant_cont
// quant_disc returns the selected value itself, without interpolating.
quant_disc
)
// newQuantile builds the constructor handed to the function registration:
// every call of the returned function yields a fresh, empty quantile
// accumulator of the given kind (median, quant_cont or quant_disc).
func newQuantile(kind int) func() sqlite3.AggregateFunction {
	factory := func() sqlite3.AggregateFunction {
		acc := new(quantile)
		acc.kind = kind
		return acc
	}
	return factory
}
// quantile is the aggregate state shared by median, quantile_cont and
// quantile_disc: it collects the numeric inputs and the requested position.
type quantile struct {
kind int // one of median, quant_cont, quant_disc
pos float64 // requested position; Value rejects anything outside [0, 1]
list []float64 // non-NULL values appended by Step
}
// Step consumes one input row.
//
// arg[0] is the sample: it is appended to the list unless its numeric type
// is NULL. For every kind other than median, arg[1] is read as the position
// and overwrites the one stored by previous rows.
func (q *quantile) Step(ctx sqlite3.Context, arg ...sqlite3.Value) {
	sample := arg[0]
	if sample.NumericType() != sqlite3.NULL {
		q.list = append(q.list, sample.Float())
	}

	// median is registered without a position argument, so there is no arg[1].
	if q.kind == median {
		return
	}
	q.pos = arg[1].Float()
}
// Value reports the aggregate result for the values collected by Step.
//
// No result is set when no non-NULL value was collected. A position outside
// [0, 1] (or NaN) is reported as the error "quantile: invalid pos".
// Otherwise the position is scaled to a fractional rank in [0, len-1]:
// quant_disc returns the value at the integer part of that rank, while
// median and quant_cont linearly interpolate towards the next larger value.
func (q *quantile) Value(ctx sqlite3.Context) {
	if len(q.list) == 0 {
		return
	}
	if q.kind == median {
		q.pos = 0.5
	}
	// NaN compares false against everything, so it must be rejected
	// explicitly; otherwise it would reach the int conversion below.
	if math.IsNaN(q.pos) || q.pos < 0 || q.pos > 1 {
		ctx.ResultError(util.ErrorString("quantile: invalid pos"))
		return
	}

	// i is the index of the lower neighbor, f the weight of the upper one.
	i, f := math.Modf(q.pos * float64(len(q.list)-1))
	m0 := quick.Select(q.list, int(i))

	// A whole-number rank needs no interpolation. This also covers the last
	// element (pos == 1, or a single value), where the tail below would be
	// empty and slices.Min would panic.
	if f == 0 || q.kind == quant_disc {
		ctx.ResultFloat(m0)
		return
	}

	// NOTE(review): relies on quick.Select leaving every element after index
	// i no smaller than m0, so the minimum of the tail is the next value in
	// sorted order — confirm against the sort/quick documentation.
	m1 := slices.Min(q.list[int(i)+1:])
	// m0 + f*(m1-m0), evaluated as f*m1 - (f*m0 - m0) with fused multiply-adds.
	ctx.ResultFloat(math.FMA(f, m1, -math.FMA(f, m0, -m0)))
}

View File

@@ -0,0 +1,55 @@
package stats_test
import (
"testing"
"github.com/ncruces/go-sqlite3"
_ "github.com/ncruces/go-sqlite3/embed"
"github.com/ncruces/go-sqlite3/ext/stats"
_ "github.com/ncruces/go-sqlite3/tests/testcfg"
)
// TestRegister_quantile registers the stats functions on an in-memory
// database and checks median, quantile_disc and quantile_cont over a column
// that mixes integer, float, numeric text and NULL values.
func TestRegister_quantile(t *testing.T) {
	t.Parallel()

	db, err := sqlite3.Open(":memory:")
	if err != nil {
		t.Fatal(err)
	}
	defer db.Close()

	stats.Register(db)

	err = db.Exec(`CREATE TABLE data (x)`)
	if err != nil {
		t.Fatal(err)
	}

	err = db.Exec(`INSERT INTO data (x) VALUES (4), (7.0), ('13'), (NULL), (16)`)
	if err != nil {
		t.Fatal(err)
	}

	stmt, _, err := db.Prepare(`
SELECT
median(x),
quantile_disc(x, 0.5),
quantile_cont(x, 0.3)
FROM data`)
	if err != nil {
		t.Fatal(err)
	}
	defer stmt.Close()

	// Fail loudly when the query yields no row: guarding the assertions with
	// a plain `if stmt.Step()` would let a broken query pass silently.
	if !stmt.Step() {
		t.Fatalf("no result row: %v", stmt.Err())
	}

	if got := stmt.ColumnFloat(0); got != 10 {
		t.Errorf("got %v, want 10", got)
	}
	if got := stmt.ColumnFloat(1); got != 7 {
		t.Errorf("got %v, want 7", got)
	}
	// 0.3 is not exactly representable, hence the literal just below 6.7.
	if got := stmt.ColumnFloat(2); got != 6.699999999999999 {
		t.Errorf("got %v, want 6.7", got)
	}
}

View File

@@ -18,6 +18,9 @@
// - regr_slope: slope of the least-squares-fit linear equation
// - regr_intercept: y-intercept of the least-squares-fit linear equation
// - regr_json: all regr stats in a JSON object
// - median: median value
// - quantile_cont: continuous quantile
// - quantile_disc: discrete quantile
//
// These join the [Built-in Aggregate Functions]:
// - count: count rows/values
@@ -27,9 +30,11 @@
// - max: maximum value
//
// See: [ANSI SQL Aggregate Functions]
// See: [DuckDB Aggregate Functions]
//
// [Built-in Aggregate Functions]: https://sqlite.org/lang_aggfunc.html
// [ANSI SQL Aggregate Functions]: https://www.oreilly.com/library/view/sql-in-a/9780596155322/ch04s02.html
// [DuckDB Aggregate Functions]: https://duckdb.org/docs/sql/aggregates.html
package stats
import "github.com/ncruces/go-sqlite3"
@@ -54,6 +59,9 @@ func Register(db *sqlite3.Conn) {
db.CreateWindowFunction("regr_intercept", 2, flags, newCovariance(regr_intercept))
db.CreateWindowFunction("regr_count", 2, flags, newCovariance(regr_count))
db.CreateWindowFunction("regr_json", 2, flags, newCovariance(regr_json))
db.CreateWindowFunction("median", 1, flags, newQuantile(median))
db.CreateWindowFunction("quantile_cont", 2, flags, newQuantile(quant_cont))
db.CreateWindowFunction("quantile_disc", 2, flags, newQuantile(quant_disc))
}
const (

2
go.mod
View File

@@ -13,4 +13,6 @@ require (
lukechampine.com/adiantum v1.1.1
)
require github.com/ncruces/sort v0.1.2
retract v0.4.0 // tagged from the wrong branch

2
go.sum
View File

@@ -1,5 +1,7 @@
github.com/ncruces/julianday v1.0.0 h1:fH0OKwa7NWvniGQtxdJRxAgkBMolni2BjDHaWTxqt7M=
github.com/ncruces/julianday v1.0.0/go.mod h1:Dusn2KvZrrovOMJuOt0TNXL6tB7U2E8kvza5fFc9G7g=
github.com/ncruces/sort v0.1.2 h1:zKQ9CA4fpHPF6xsUhRTfi5EEryspuBpe/QA4VWQOV1U=
github.com/ncruces/sort v0.1.2/go.mod h1:vEJUTBJtebIuCMmXD18GKo5GJGhsay+xZFOoBEIXFmE=
github.com/psanford/httpreadat v0.1.0 h1:VleW1HS2zO7/4c7c7zNl33fO6oYACSagjJIyMIwZLUE=
github.com/psanford/httpreadat v0.1.0/go.mod h1:Zg7P+TlBm3bYbyHTKv/EdtSJZn3qwbPwpfZ/I9GKCRE=
github.com/tetratelabs/wazero v1.7.2 h1:1+z5nXJNwMLPAWaTePFi49SSTL0IMx/i3Fg8Yc25GDc=

4
go.work.sum Normal file
View File

@@ -0,0 +1,4 @@
golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44=
golang.org/x/term v0.20.0/go.mod h1:8UkIAJTvZgivsXaD6/pH6U9ecQzZ45awqEOzuCvwpFY=
golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU=