mirror of
https://github.com/ncruces/go-sqlite3.git
synced 2026-01-12 05:59:14 +00:00
Quantiles.
This commit is contained in:
59
ext/stats/quantile.go
Normal file
59
ext/stats/quantile.go
Normal file
@@ -0,0 +1,59 @@
|
||||
package stats
|
||||
|
||||
import (
|
||||
"math"
|
||||
"slices"
|
||||
|
||||
"github.com/ncruces/go-sqlite3"
|
||||
"github.com/ncruces/go-sqlite3/internal/util"
|
||||
"github.com/ncruces/sort/quick"
|
||||
)
|
||||
|
||||
const (
|
||||
median = iota
|
||||
quant_cont
|
||||
quant_disc
|
||||
)
|
||||
|
||||
func newQuantile(kind int) func() sqlite3.AggregateFunction {
|
||||
return func() sqlite3.AggregateFunction { return &quantile{kind: kind} }
|
||||
}
|
||||
|
||||
type quantile struct {
|
||||
kind int
|
||||
pos float64
|
||||
list []float64
|
||||
}
|
||||
|
||||
func (q *quantile) Step(ctx sqlite3.Context, arg ...sqlite3.Value) {
|
||||
if a := arg[0]; a.NumericType() != sqlite3.NULL {
|
||||
q.list = append(q.list, a.Float())
|
||||
}
|
||||
if q.kind != median {
|
||||
q.pos = arg[1].Float()
|
||||
}
|
||||
}
|
||||
|
||||
func (q *quantile) Value(ctx sqlite3.Context) {
|
||||
if q.list == nil {
|
||||
return
|
||||
}
|
||||
if q.kind == median {
|
||||
q.pos = 0.5
|
||||
}
|
||||
if q.pos < 0 || q.pos > 1 {
|
||||
ctx.ResultError(util.ErrorString("quantile: invalid pos"))
|
||||
return
|
||||
}
|
||||
|
||||
i, f := math.Modf(q.pos * float64(len(q.list)-1))
|
||||
m0 := quick.Select(q.list, int(i))
|
||||
|
||||
if q.kind == quant_disc {
|
||||
ctx.ResultFloat(m0)
|
||||
return
|
||||
}
|
||||
|
||||
m1 := slices.Min(q.list[int(i)+1:])
|
||||
ctx.ResultFloat(math.FMA(f, m1, -math.FMA(f, m0, -m0)))
|
||||
}
|
||||
55
ext/stats/quantile_test.go
Normal file
55
ext/stats/quantile_test.go
Normal file
@@ -0,0 +1,55 @@
|
||||
package stats_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/ncruces/go-sqlite3"
|
||||
_ "github.com/ncruces/go-sqlite3/embed"
|
||||
"github.com/ncruces/go-sqlite3/ext/stats"
|
||||
_ "github.com/ncruces/go-sqlite3/tests/testcfg"
|
||||
)
|
||||
|
||||
func TestRegister_quantile(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, err := sqlite3.Open(":memory:")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
stats.Register(db)
|
||||
|
||||
err = db.Exec(`CREATE TABLE data (x)`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = db.Exec(`INSERT INTO data (x) VALUES (4), (7.0), ('13'), (NULL), (16)`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
stmt, _, err := db.Prepare(`
|
||||
SELECT
|
||||
median(x),
|
||||
quantile_disc(x, 0.5),
|
||||
quantile_cont(x, 0.3)
|
||||
FROM data`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
if stmt.Step() {
|
||||
if got := stmt.ColumnFloat(0); got != 10 {
|
||||
t.Errorf("got %v, want 10", got)
|
||||
}
|
||||
if got := stmt.ColumnFloat(1); got != 7 {
|
||||
t.Errorf("got %v, want 7", got)
|
||||
}
|
||||
if got := stmt.ColumnFloat(2); got != 6.699999999999999 {
|
||||
t.Errorf("got %v, want 6.7", got)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -18,6 +18,9 @@
|
||||
// - regr_slope: slope of the least-squares-fit linear equation
|
||||
// - regr_intercept: y-intercept of the least-squares-fit linear equation
|
||||
// - regr_json: all regr stats in a JSON object
|
||||
// - median: median value
|
||||
// - quantile_cont: continuous quantile
|
||||
// - quantile_disc: discrete quantile
|
||||
//
|
||||
// These join the [Built-in Aggregate Functions]:
|
||||
// - count: count rows/values
|
||||
@@ -27,9 +30,11 @@
|
||||
// - max: maximum value
|
||||
//
|
||||
// See: [ANSI SQL Aggregate Functions]
|
||||
// See: [DuckDB Aggregate Functions]
|
||||
//
|
||||
// [Built-in Aggregate Functions]: https://sqlite.org/lang_aggfunc.html
|
||||
// [ANSI SQL Aggregate Functions]: https://www.oreilly.com/library/view/sql-in-a/9780596155322/ch04s02.html
|
||||
// [DuckDB Aggregate Functions]: https://duckdb.org/docs/sql/aggregates.html
|
||||
package stats
|
||||
|
||||
import "github.com/ncruces/go-sqlite3"
|
||||
@@ -54,6 +59,9 @@ func Register(db *sqlite3.Conn) {
|
||||
db.CreateWindowFunction("regr_intercept", 2, flags, newCovariance(regr_intercept))
|
||||
db.CreateWindowFunction("regr_count", 2, flags, newCovariance(regr_count))
|
||||
db.CreateWindowFunction("regr_json", 2, flags, newCovariance(regr_json))
|
||||
db.CreateWindowFunction("median", 1, flags, newQuantile(median))
|
||||
db.CreateWindowFunction("quantile_cont", 2, flags, newQuantile(quant_cont))
|
||||
db.CreateWindowFunction("quantile_disc", 2, flags, newQuantile(quant_disc))
|
||||
}
|
||||
|
||||
const (
|
||||
|
||||
Reference in New Issue
Block a user