Multiple quantiles.

This commit is contained in:
Nuno Cruces
2024-06-02 13:37:29 +01:00
parent 19bc6e3fac
commit d78a53a789
2 changed files with 70 additions and 29 deletions

View File

@@ -1,6 +1,8 @@
package stats
import (
"encoding/json"
"fmt"
"math"
"slices"
@@ -21,39 +23,67 @@ func newQuantile(kind int) func() sqlite3.AggregateFunction {
type quantile struct {
kind int
pos float64
list []float64
nums []float64
arg1 []byte
}
func (q *quantile) Step(ctx sqlite3.Context, arg ...sqlite3.Value) {
if a := arg[0]; a.NumericType() != sqlite3.NULL {
q.list = append(q.list, a.Float())
q.nums = append(q.nums, a.Float())
}
if q.kind != median {
q.pos = arg[1].Float()
q.arg1 = arg[1].Blob(q.arg1[:0])
}
}
func (q *quantile) Value(ctx sqlite3.Context) {
if len(q.list) == 0 {
if len(q.nums) == 0 {
return
}
var (
err error
float float64
floats []float64
)
if q.kind == median {
q.pos = 0.5
float, err = getQuantile(q.nums, 0.5, false)
ctx.ResultFloat(float)
} else if err = json.Unmarshal(q.arg1, &float); err == nil {
float, err = getQuantile(q.nums, float, q.kind == quant_disc)
ctx.ResultFloat(float)
} else if err = json.Unmarshal(q.arg1, &floats); err == nil {
err = getQuantiles(q.nums, floats, q.kind == quant_disc)
ctx.ResultJSON(floats)
}
if q.pos < 0 || q.pos > 1 {
ctx.ResultError(util.ErrorString("quantile: invalid pos"))
return
if err != nil {
ctx.ResultError(fmt.Errorf("quantile: %w", err))
}
i, f := math.Modf(q.pos * float64(len(q.list)-1))
m0 := quick.Select(q.list, int(i))
if f == 0 || q.kind == quant_disc {
ctx.ResultFloat(m0)
return
}
m1 := slices.Min(q.list[int(i)+1:])
ctx.ResultFloat(math.FMA(f, m1, -math.FMA(f, m0, -m0)))
}
// getQuantile returns the quantile of nums at position pos.
//
// pos must lie in the closed interval [0, 1]; NaN and out-of-range values
// yield an "invalid pos" error. An empty nums yields an error as well,
// because no element can be selected from it.
//
// When disc is true, or when pos lands exactly on an element, the selected
// element itself is returned. Otherwise the result is linearly interpolated
// between that element and the smallest element that follows it.
//
// nums is handed to quick.Select, so callers should expect its elements to be
// reordered.
func getQuantile(nums []float64, pos float64, disc bool) (float64, error) {
	// NaN compares false against everything, so it needs an explicit test;
	// otherwise it would slip through and be converted to an index below.
	if math.IsNaN(pos) || pos < 0 || pos > 1 {
		return 0, util.ErrorString("invalid pos")
	}
	// With no elements len(nums)-1 is negative and there is nothing to select.
	if len(nums) == 0 {
		return 0, util.ErrorString("no values")
	}

	// Split the fractional rank into the index of the lower element (i) and
	// the interpolation weight towards the next one (f).
	i, f := math.Modf(pos * float64(len(nums)-1))
	m0 := quick.Select(nums, int(i))
	if f == 0 || disc {
		return m0, nil
	}

	// NOTE(review): this presumes quick.Select leaves every element after
	// index i no smaller than nums[i], so the minimum of the tail is the next
	// order statistic — confirm against the quick package.
	// f > 0 implies int(i) < len(nums)-1, so the tail is never empty.
	m1 := slices.Min(nums[int(i)+1:])

	// m0 + f*(m1-m0), expressed with fused multiply-adds.
	return math.FMA(f, m1, -math.FMA(f, m0, -m0)), nil
}
// getQuantiles computes one quantile of nums per entry of pos and stores each
// result back into pos at the same index, so on success pos holds quantile
// values instead of positions.
//
// disc is forwarded unchanged to getQuantile. The first error returned by
// getQuantile stops the loop and is returned as is; entries before the
// failing one have already been overwritten at that point.
func getQuantiles(nums []float64, pos []float64, disc bool) error {
	for idx, p := range pos {
		q, err := getQuantile(nums, p, disc)
		if err != nil {
			return err
		}
		pos[idx] = q
	}
	return nil
}

View File

@@ -1,6 +1,7 @@
package stats_test
import (
"slices"
"testing"
"github.com/ncruces/go-sqlite3"
@@ -34,7 +35,7 @@ func TestRegister_quantile(t *testing.T) {
SELECT
median(x),
quantile_disc(x, 0.5),
quantile_cont(x, 0.25)
quantile_cont(x, '[0.25, 0.5, 0.75]')
FROM data`)
if err != nil {
t.Fatal(err)
@@ -46,8 +47,12 @@ func TestRegister_quantile(t *testing.T) {
if got := stmt.ColumnFloat(1); got != 7 {
t.Errorf("got %v, want 7", got)
}
if got := stmt.ColumnFloat(2); got != 6.25 {
t.Errorf("got %v, want 6.25", got)
var got []float64
if err := stmt.ColumnJSON(2, &got); err != nil {
t.Error(err)
}
if !slices.Equal(got, []float64{6.25, 10, 13.75}) {
t.Errorf("got %v, want [6.25 10 13.75]", got)
}
}
stmt.Close()
@@ -56,7 +61,7 @@ func TestRegister_quantile(t *testing.T) {
SELECT
median(x),
quantile_disc(x, 0.5),
quantile_cont(x, 0.25)
quantile_cont(x, '[0.25, 0.5, 0.75]')
FROM data
WHERE x < 5`)
if err != nil {
@@ -69,8 +74,12 @@ func TestRegister_quantile(t *testing.T) {
if got := stmt.ColumnFloat(1); got != 4 {
t.Errorf("got %v, want 4", got)
}
if got := stmt.ColumnFloat(2); got != 4 {
t.Errorf("got %v, want 4", got)
var got []float64
if err := stmt.ColumnJSON(2, &got); err != nil {
t.Error(err)
}
if !slices.Equal(got, []float64{4, 4, 4}) {
t.Errorf("got %v, want [4 4 4]", got)
}
}
stmt.Close()
@@ -79,7 +88,7 @@ func TestRegister_quantile(t *testing.T) {
SELECT
median(x),
quantile_disc(x, 0.5),
quantile_cont(x, 0.25)
quantile_cont(x, '[0.25, 0.5, 0.75]')
FROM data
WHERE x < 0`)
if err != nil {
@@ -101,13 +110,15 @@ func TestRegister_quantile(t *testing.T) {
stmt, _, err = db.Prepare(`
SELECT
quantile_disc(x, -2),
quantile_cont(x, +2)
quantile_cont(x, +2),
quantile_cont(x, ''),
quantile_cont(x, '[100]')
FROM data`)
if err != nil {
t.Fatal(err)
}
if stmt.Step() {
t.Fatal("want error")
t.Error("want error")
}
stmt.Close()
}