From d78a53a789d899987b0ab4ea59f993ab8e16a16b Mon Sep 17 00:00:00 2001 From: Nuno Cruces Date: Sun, 2 Jun 2024 13:37:29 +0100 Subject: [PATCH] Multiple quantiles. --- ext/stats/quantile.go | 70 +++++++++++++++++++++++++++----------- ext/stats/quantile_test.go | 29 +++++++++++----- 2 files changed, 70 insertions(+), 29 deletions(-) diff --git a/ext/stats/quantile.go b/ext/stats/quantile.go index 77d4217..bcda239 100644 --- a/ext/stats/quantile.go +++ b/ext/stats/quantile.go @@ -1,6 +1,8 @@ package stats import ( + "encoding/json" + "fmt" "math" "slices" @@ -21,39 +23,67 @@ func newQuantile(kind int) func() sqlite3.AggregateFunction { type quantile struct { kind int - pos float64 - list []float64 + nums []float64 + arg1 []byte } func (q *quantile) Step(ctx sqlite3.Context, arg ...sqlite3.Value) { if a := arg[0]; a.NumericType() != sqlite3.NULL { - q.list = append(q.list, a.Float()) + q.nums = append(q.nums, a.Float()) } if q.kind != median { - q.pos = arg[1].Float() + q.arg1 = arg[1].Blob(q.arg1[:0]) } } func (q *quantile) Value(ctx sqlite3.Context) { - if len(q.list) == 0 { + if len(q.nums) == 0 { return } + + var ( + err error + float float64 + floats []float64 + ) if q.kind == median { - q.pos = 0.5 + float, err = getQuantile(q.nums, 0.5, false) + ctx.ResultFloat(float) + } else if err = json.Unmarshal(q.arg1, &float); err == nil { + float, err = getQuantile(q.nums, float, q.kind == quant_disc) + ctx.ResultFloat(float) + } else if err = json.Unmarshal(q.arg1, &floats); err == nil { + err = getQuantiles(q.nums, floats, q.kind == quant_disc) + ctx.ResultJSON(floats) } - if q.pos < 0 || q.pos > 1 { - ctx.ResultError(util.ErrorString("quantile: invalid pos")) - return + if err != nil { + ctx.ResultError(fmt.Errorf("quantile: %w", err)) } - - i, f := math.Modf(q.pos * float64(len(q.list)-1)) - m0 := quick.Select(q.list, int(i)) - - if f == 0 || q.kind == quant_disc { - ctx.ResultFloat(m0) - return - } - - m1 := slices.Min(q.list[int(i)+1:]) - ctx.ResultFloat(math.FMA(f, m1, -math.FMA(f, m0, -m0))) +} + +func getQuantile(nums []float64, pos float64, disc bool) (float64, error) { + if pos < 0 || pos > 1 { + return 0, util.ErrorString("invalid pos") + } + + i, f := math.Modf(pos * float64(len(nums)-1)) + m0 := quick.Select(nums, int(i)) + + if f == 0 || disc { + return m0, nil + } + + m1 := slices.Min(nums[int(i)+1:]) + return math.FMA(f, m1, -math.FMA(f, m0, -m0)), nil +} + +func getQuantiles(nums []float64, pos []float64, disc bool) error { + for i := range pos { + v, err := getQuantile(nums, pos[i], disc) + if err != nil { + return err + } + pos[i] = v + } + return nil } diff --git a/ext/stats/quantile_test.go b/ext/stats/quantile_test.go index f43e9c3..eb83043 100644 --- a/ext/stats/quantile_test.go +++ b/ext/stats/quantile_test.go @@ -1,6 +1,7 @@ package stats_test import ( + "slices" "testing" "github.com/ncruces/go-sqlite3" @@ -34,7 +35,7 @@ func TestRegister_quantile(t *testing.T) { SELECT median(x), quantile_disc(x, 0.5), - quantile_cont(x, 0.25) + quantile_cont(x, '[0.25, 0.5, 0.75]') FROM data`) if err != nil { t.Fatal(err) @@ -46,8 +47,12 @@ func TestRegister_quantile(t *testing.T) { if got := stmt.ColumnFloat(1); got != 7 { t.Errorf("got %v, want 7", got) } - if got := stmt.ColumnFloat(2); got != 6.25 { - t.Errorf("got %v, want 6.25", got) + var got []float64 + if err := stmt.ColumnJSON(2, &got); err != nil { + t.Error(err) + } + if !slices.Equal(got, []float64{6.25, 10, 13.75}) { + t.Errorf("got %v, want [6.25 10 13.75]", got) } } stmt.Close() @@ -56,7 +61,7 @@ func TestRegister_quantile(t *testing.T) { SELECT median(x), quantile_disc(x, 0.5), - quantile_cont(x, 0.25) + quantile_cont(x, '[0.25, 0.5, 0.75]') FROM data WHERE x < 5`) if err != nil { @@ -69,8 +74,12 @@ func TestRegister_quantile(t *testing.T) { if got := stmt.ColumnFloat(1); got != 4 { t.Errorf("got %v, want 4", got) } - if got := stmt.ColumnFloat(2); got != 4 { - t.Errorf("got %v, want 4", got) + var got []float64 + if err := stmt.ColumnJSON(2, &got); err != nil { + t.Error(err) + } + if !slices.Equal(got, []float64{4, 4, 4}) { + t.Errorf("got %v, want [4 4 4]", got) } } stmt.Close() @@ -79,7 +88,7 @@ func TestRegister_quantile(t *testing.T) { SELECT median(x), quantile_disc(x, 0.5), - quantile_cont(x, 0.25) + quantile_cont(x, '[0.25, 0.5, 0.75]') FROM data WHERE x < 0`) if err != nil { @@ -101,13 +110,15 @@ func TestRegister_quantile(t *testing.T) { stmt, _, err = db.Prepare(` SELECT quantile_disc(x, -2), - quantile_cont(x, +2) + quantile_cont(x, +2), + quantile_cont(x, ''), + quantile_cont(x, '[100]') FROM data`) if err != nil { t.Fatal(err) } if stmt.Step() { - t.Fatal("want error") + t.Error("want error") } stmt.Close() }