Better percentile compatibility.

This commit is contained in:
Nuno Cruces
2024-10-16 14:00:22 +01:00
parent ddfaf12cd8
commit c78d00dca0
4 changed files with 32 additions and 13 deletions

View File

@@ -177,6 +177,7 @@ const (
DETERMINISTIC FunctionFlag = 0x000000800
DIRECTONLY FunctionFlag = 0x000080000
INNOCUOUS FunctionFlag = 0x000200000
SELFORDER1 FunctionFlag = 0x002000000
// SUBTYPE FunctionFlag = 0x000100000
// RESULT_SUBTYPE FunctionFlag = 0x001000000
)

View File

@@ -13,6 +13,7 @@ import (
// Kinds of percentile aggregate. The kind is passed to getPercentile and
// decides how the position argument is interpreted and whether the result
// is interpolated.
const (
// median always evaluates position 0.5; registered as the 1-argument
// "median" function.
median = iota
// percentile_100 takes its position on a 0–100 scale: getPercentile
// divides it by 100 before the [0, 1] range check. Registered as
// "percentile".
percentile_100
// percentile_cont takes its position in [0, 1] and, like every kind
// other than percentile_disc, does not short-circuit on a fractional
// rank. Registered as "percentile_cont".
percentile_cont
// percentile_disc takes its position in [0, 1] and always returns the
// selected element itself, even when the rank has a fractional part.
// Registered as "percentile_disc".
percentile_disc
)
@@ -54,13 +55,13 @@ func (q *percentile) Value(ctx sqlite3.Context) {
floats []float64
)
if q.kind == median {
float, err = getPercentile(q.nums, 0.5, false)
float, err = getPercentile(q.nums, 0.5, q.kind)
ctx.ResultFloat(float)
} else if err = json.Unmarshal(q.arg1, &float); err == nil {
float, err = getPercentile(q.nums, float, q.kind == percentile_disc)
float, err = getPercentile(q.nums, float, q.kind)
ctx.ResultFloat(float)
} else if err = json.Unmarshal(q.arg1, &floats); err == nil {
err = getPercentiles(q.nums, floats, q.kind == percentile_disc)
err = getPercentiles(q.nums, floats, q.kind)
ctx.ResultJSON(floats)
}
if err != nil {
@@ -68,7 +69,10 @@ func (q *percentile) Value(ctx sqlite3.Context) {
}
}
func getPercentile(nums []float64, pos float64, disc bool) (float64, error) {
func getPercentile(nums []float64, pos float64, kind int) (float64, error) {
if kind == percentile_100 {
pos = pos / 100
}
if pos < 0 || pos > 1 {
return 0, util.ErrorString("invalid pos")
}
@@ -76,7 +80,7 @@ func getPercentile(nums []float64, pos float64, disc bool) (float64, error) {
i, f := math.Modf(pos * float64(len(nums)-1))
m0 := quick.Select(nums, int(i))
if f == 0 || disc {
if f == 0 || kind == percentile_disc {
return m0, nil
}
@@ -84,9 +88,9 @@ func getPercentile(nums []float64, pos float64, disc bool) (float64, error) {
return math.FMA(f, m1, math.FMA(-f, m0, m0)), nil
}
func getPercentiles(nums []float64, pos []float64, disc bool) error {
func getPercentiles(nums []float64, pos []float64, kind int) error {
for i := range pos {
v, err := getPercentile(nums, pos[i], disc)
v, err := getPercentile(nums, pos[i], kind)
if err != nil {
return err
}

View File

@@ -31,6 +31,7 @@ func TestRegister_percentile(t *testing.T) {
stmt, _, err := db.Prepare(`
SELECT
median(x),
percentile(x, 50),
percentile_disc(x, 0.5),
percentile_cont(x, '[0.25, 0.5, 0.75]')
FROM data`)
@@ -41,11 +42,14 @@ func TestRegister_percentile(t *testing.T) {
if got := stmt.ColumnFloat(0); got != 10 {
t.Errorf("got %v, want 10", got)
}
if got := stmt.ColumnFloat(1); got != 7 {
if got := stmt.ColumnFloat(1); got != 10 {
t.Errorf("got %v, want 10", got)
}
if got := stmt.ColumnFloat(2); got != 7 {
t.Errorf("got %v, want 7", got)
}
var got []float64
if err := stmt.ColumnJSON(2, &got); err != nil {
if err := stmt.ColumnJSON(3, &got); err != nil {
t.Error(err)
}
if !slices.Equal(got, []float64{6.25, 10, 13.75}) {
@@ -57,6 +61,7 @@ func TestRegister_percentile(t *testing.T) {
stmt, _, err = db.Prepare(`
SELECT
median(x),
percentile(x, 50),
percentile_disc(x, 0.5),
percentile_cont(x, '[0.25, 0.5, 0.75]')
FROM data
@@ -71,8 +76,11 @@ func TestRegister_percentile(t *testing.T) {
if got := stmt.ColumnFloat(1); got != 4 {
t.Errorf("got %v, want 4", got)
}
if got := stmt.ColumnFloat(2); got != 4 {
t.Errorf("got %v, want 4", got)
}
var got []float64
if err := stmt.ColumnJSON(2, &got); err != nil {
if err := stmt.ColumnJSON(3, &got); err != nil {
t.Error(err)
}
if !slices.Equal(got, []float64{4, 4, 4}) {
@@ -84,6 +92,7 @@ func TestRegister_percentile(t *testing.T) {
stmt, _, err = db.Prepare(`
SELECT
median(x),
percentile(x, 50),
percentile_disc(x, 0.5),
percentile_cont(x, '[0.25, 0.5, 0.75]')
FROM data
@@ -101,6 +110,9 @@ func TestRegister_percentile(t *testing.T) {
if got := stmt.ColumnType(2); got != sqlite3.NULL {
t.Error("want NULL")
}
if got := stmt.ColumnType(3); got != sqlite3.NULL {
t.Error("want NULL")
}
}
stmt.Close()

View File

@@ -53,6 +53,7 @@ import (
// Register registers statistics functions.
func Register(db *sqlite3.Conn) error {
const flags = sqlite3.DETERMINISTIC | sqlite3.INNOCUOUS
const order = sqlite3.SELFORDER1 | flags
return errors.Join(
db.CreateWindowFunction("var_pop", 1, flags, newVariance(var_pop)),
db.CreateWindowFunction("var_samp", 1, flags, newVariance(var_samp)),
@@ -71,9 +72,10 @@ func Register(db *sqlite3.Conn) error {
db.CreateWindowFunction("regr_intercept", 2, flags, newCovariance(regr_intercept)),
db.CreateWindowFunction("regr_count", 2, flags, newCovariance(regr_count)),
db.CreateWindowFunction("regr_json", 2, flags, newCovariance(regr_json)),
db.CreateWindowFunction("median", 1, flags, newPercentile(median)),
db.CreateWindowFunction("percentile_cont", 2, flags, newPercentile(percentile_cont)),
db.CreateWindowFunction("percentile_disc", 2, flags, newPercentile(percentile_disc)),
db.CreateWindowFunction("median", 1, order, newPercentile(median)),
db.CreateWindowFunction("percentile", 2, order, newPercentile(percentile_100)),
db.CreateWindowFunction("percentile_cont", 2, order, newPercentile(percentile_cont)),
db.CreateWindowFunction("percentile_disc", 2, order, newPercentile(percentile_disc)),
db.CreateWindowFunction("every", 1, flags, newBoolean(every)),
db.CreateWindowFunction("some", 1, flags, newBoolean(some)))
}