mirror of
https://github.com/ncruces/go-sqlite3.git
synced 2026-01-12 05:59:14 +00:00
Multiple quantiles.
This commit is contained in:
@@ -1,6 +1,8 @@
|
||||
package stats
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"math"
|
||||
"slices"
|
||||
|
||||
@@ -21,39 +23,67 @@ func newQuantile(kind int) func() sqlite3.AggregateFunction {
|
||||
|
||||
type quantile struct {
|
||||
kind int
|
||||
pos float64
|
||||
list []float64
|
||||
nums []float64
|
||||
arg1 []byte
|
||||
}
|
||||
|
||||
func (q *quantile) Step(ctx sqlite3.Context, arg ...sqlite3.Value) {
|
||||
if a := arg[0]; a.NumericType() != sqlite3.NULL {
|
||||
q.list = append(q.list, a.Float())
|
||||
q.nums = append(q.nums, a.Float())
|
||||
}
|
||||
if q.kind != median {
|
||||
q.pos = arg[1].Float()
|
||||
q.arg1 = arg[1].Blob(q.arg1[:0])
|
||||
}
|
||||
}
|
||||
|
||||
func (q *quantile) Value(ctx sqlite3.Context) {
|
||||
if len(q.list) == 0 {
|
||||
if len(q.nums) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
var (
|
||||
err error
|
||||
float float64
|
||||
floats []float64
|
||||
)
|
||||
if q.kind == median {
|
||||
q.pos = 0.5
|
||||
float, err = getQuantile(q.nums, 0.5, false)
|
||||
ctx.ResultFloat(float)
|
||||
} else if err = json.Unmarshal(q.arg1, &float); err == nil {
|
||||
float, err = getQuantile(q.nums, float, q.kind == quant_disc)
|
||||
ctx.ResultFloat(float)
|
||||
} else if err = json.Unmarshal(q.arg1, &floats); err == nil {
|
||||
err = getQuantiles(q.nums, floats, q.kind == quant_disc)
|
||||
ctx.ResultJSON(floats)
|
||||
}
|
||||
if q.pos < 0 || q.pos > 1 {
|
||||
ctx.ResultError(util.ErrorString("quantile: invalid pos"))
|
||||
return
|
||||
if err != nil {
|
||||
ctx.ResultError(fmt.Errorf("quantile: %w", err))
|
||||
}
|
||||
|
||||
i, f := math.Modf(q.pos * float64(len(q.list)-1))
|
||||
m0 := quick.Select(q.list, int(i))
|
||||
|
||||
if f == 0 || q.kind == quant_disc {
|
||||
ctx.ResultFloat(m0)
|
||||
return
|
||||
}
|
||||
|
||||
m1 := slices.Min(q.list[int(i)+1:])
|
||||
ctx.ResultFloat(math.FMA(f, m1, -math.FMA(f, m0, -m0)))
|
||||
}
|
||||
|
||||
func getQuantile(nums []float64, pos float64, disc bool) (float64, error) {
|
||||
if pos < 0 || pos > 1 {
|
||||
return 0, util.ErrorString("invalid pos")
|
||||
}
|
||||
|
||||
i, f := math.Modf(pos * float64(len(nums)-1))
|
||||
m0 := quick.Select(nums, int(i))
|
||||
|
||||
if f == 0 || disc {
|
||||
return m0, nil
|
||||
}
|
||||
|
||||
m1 := slices.Min(nums[int(i)+1:])
|
||||
return math.FMA(f, m1, -math.FMA(f, m0, -m0)), nil
|
||||
}
|
||||
|
||||
func getQuantiles(nums []float64, pos []float64, disc bool) error {
|
||||
for i := range pos {
|
||||
v, err := getQuantile(nums, pos[i], disc)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
pos[i] = v
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package stats_test
|
||||
|
||||
import (
|
||||
"slices"
|
||||
"testing"
|
||||
|
||||
"github.com/ncruces/go-sqlite3"
|
||||
@@ -34,7 +35,7 @@ func TestRegister_quantile(t *testing.T) {
|
||||
SELECT
|
||||
median(x),
|
||||
quantile_disc(x, 0.5),
|
||||
quantile_cont(x, 0.25)
|
||||
quantile_cont(x, '[0.25, 0.5, 0.75]')
|
||||
FROM data`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
@@ -46,8 +47,12 @@ func TestRegister_quantile(t *testing.T) {
|
||||
if got := stmt.ColumnFloat(1); got != 7 {
|
||||
t.Errorf("got %v, want 7", got)
|
||||
}
|
||||
if got := stmt.ColumnFloat(2); got != 6.25 {
|
||||
t.Errorf("got %v, want 6.25", got)
|
||||
var got []float64
|
||||
if err := stmt.ColumnJSON(2, &got); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if !slices.Equal(got, []float64{6.25, 10, 13.75}) {
|
||||
t.Errorf("got %v, want [6.25 10 13.75]", got)
|
||||
}
|
||||
}
|
||||
stmt.Close()
|
||||
@@ -56,7 +61,7 @@ func TestRegister_quantile(t *testing.T) {
|
||||
SELECT
|
||||
median(x),
|
||||
quantile_disc(x, 0.5),
|
||||
quantile_cont(x, 0.25)
|
||||
quantile_cont(x, '[0.25, 0.5, 0.75]')
|
||||
FROM data
|
||||
WHERE x < 5`)
|
||||
if err != nil {
|
||||
@@ -69,8 +74,12 @@ func TestRegister_quantile(t *testing.T) {
|
||||
if got := stmt.ColumnFloat(1); got != 4 {
|
||||
t.Errorf("got %v, want 4", got)
|
||||
}
|
||||
if got := stmt.ColumnFloat(2); got != 4 {
|
||||
t.Errorf("got %v, want 4", got)
|
||||
var got []float64
|
||||
if err := stmt.ColumnJSON(2, &got); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if !slices.Equal(got, []float64{4, 4, 4}) {
|
||||
t.Errorf("got %v, want [4 4 4]", got)
|
||||
}
|
||||
}
|
||||
stmt.Close()
|
||||
@@ -79,7 +88,7 @@ func TestRegister_quantile(t *testing.T) {
|
||||
SELECT
|
||||
median(x),
|
||||
quantile_disc(x, 0.5),
|
||||
quantile_cont(x, 0.25)
|
||||
quantile_cont(x, '[0.25, 0.5, 0.75]')
|
||||
FROM data
|
||||
WHERE x < 0`)
|
||||
if err != nil {
|
||||
@@ -101,13 +110,15 @@ func TestRegister_quantile(t *testing.T) {
|
||||
stmt, _, err = db.Prepare(`
|
||||
SELECT
|
||||
quantile_disc(x, -2),
|
||||
quantile_cont(x, +2)
|
||||
quantile_cont(x, +2),
|
||||
quantile_cont(x, ''),
|
||||
quantile_cont(x, '[100]')
|
||||
FROM data`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if stmt.Step() {
|
||||
t.Fatal("want error")
|
||||
t.Error("want error")
|
||||
}
|
||||
stmt.Close()
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user