mirror of
https://github.com/ncruces/go-sqlite3.git
synced 2026-01-11 21:49:13 +00:00
Fix #215.
This commit is contained in:
112
ext/stats/mode.go
Normal file
112
ext/stats/mode.go
Normal file
@@ -0,0 +1,112 @@
|
||||
package stats
|
||||
|
||||
import (
|
||||
"unsafe"
|
||||
|
||||
"github.com/ncruces/go-sqlite3"
|
||||
)
|
||||
|
||||
func newMode() sqlite3.AggregateFunction {
|
||||
return &mode{}
|
||||
}
|
||||
|
||||
type mode struct {
|
||||
ints counter[int64]
|
||||
reals counter[float64]
|
||||
texts counter[string]
|
||||
blobs counter[string]
|
||||
}
|
||||
|
||||
func (m mode) Value(ctx sqlite3.Context) {
|
||||
var (
|
||||
max = 0
|
||||
typ = sqlite3.NULL
|
||||
i64 int64
|
||||
f64 float64
|
||||
str string
|
||||
)
|
||||
for k, v := range m.ints {
|
||||
if v > max || v == max && k < i64 {
|
||||
typ = sqlite3.INTEGER
|
||||
max = v
|
||||
i64 = k
|
||||
}
|
||||
}
|
||||
f64 = float64(i64)
|
||||
for k, v := range m.reals {
|
||||
if v > max || v == max && k < f64 {
|
||||
typ = sqlite3.FLOAT
|
||||
max = v
|
||||
f64 = k
|
||||
}
|
||||
}
|
||||
for k, v := range m.texts {
|
||||
if v > max || v == max && typ == sqlite3.TEXT && k < str {
|
||||
typ = sqlite3.TEXT
|
||||
max = v
|
||||
str = k
|
||||
}
|
||||
}
|
||||
for k, v := range m.blobs {
|
||||
if v > max || v == max && typ == sqlite3.BLOB && k < str {
|
||||
typ = sqlite3.BLOB
|
||||
max = v
|
||||
str = k
|
||||
}
|
||||
}
|
||||
switch typ {
|
||||
case sqlite3.INTEGER:
|
||||
ctx.ResultInt64(i64)
|
||||
case sqlite3.FLOAT:
|
||||
ctx.ResultFloat(f64)
|
||||
case sqlite3.TEXT:
|
||||
ctx.ResultText(str)
|
||||
case sqlite3.BLOB:
|
||||
ctx.ResultBlob(unsafe.Slice(unsafe.StringData(str), len(str)))
|
||||
}
|
||||
}
|
||||
|
||||
func (b *mode) Step(ctx sqlite3.Context, arg ...sqlite3.Value) {
|
||||
switch arg[0].Type() {
|
||||
case sqlite3.INTEGER:
|
||||
b.ints.add(arg[0].Int64())
|
||||
case sqlite3.FLOAT:
|
||||
b.reals.add(arg[0].Float())
|
||||
case sqlite3.TEXT:
|
||||
b.texts.add(arg[0].Text())
|
||||
case sqlite3.BLOB:
|
||||
b.blobs.add(string(arg[0].RawBlob()))
|
||||
}
|
||||
}
|
||||
|
||||
func (b *mode) Inverse(ctx sqlite3.Context, arg ...sqlite3.Value) {
|
||||
switch arg[0].Type() {
|
||||
case sqlite3.INTEGER:
|
||||
b.ints.del(arg[0].Int64())
|
||||
case sqlite3.FLOAT:
|
||||
b.reals.del(arg[0].Float())
|
||||
case sqlite3.TEXT:
|
||||
b.texts.del(arg[0].Text())
|
||||
case sqlite3.BLOB:
|
||||
b.blobs.del(string(arg[0].RawBlob()))
|
||||
}
|
||||
}
|
||||
|
||||
type counter[T comparable] map[T]int
|
||||
|
||||
func (c *counter[T]) add(k T) {
|
||||
if (*c) == nil {
|
||||
(*c) = make(counter[T])
|
||||
}
|
||||
(*c)[k]++
|
||||
}
|
||||
|
||||
func (c counter[T]) del(k T) {
|
||||
switch n := c[k]; n {
|
||||
default:
|
||||
c[k] = n - 1
|
||||
case 1:
|
||||
delete(c, k)
|
||||
case 0:
|
||||
}
|
||||
}
|
||||
85
ext/stats/mode_test.go
Normal file
85
ext/stats/mode_test.go
Normal file
@@ -0,0 +1,85 @@
|
||||
package stats_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/ncruces/go-sqlite3"
|
||||
_ "github.com/ncruces/go-sqlite3/embed"
|
||||
_ "github.com/ncruces/go-sqlite3/internal/testcfg"
|
||||
)
|
||||
|
||||
func TestRegister_mode(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, err := sqlite3.Open(":memory:")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
stmt, _, err := db.Prepare(`SELECT mode(column1) FROM (VALUES (NULL), (1), (NULL), (2), (NULL), (3), (3))`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if stmt.Step() {
|
||||
if got := stmt.ColumnInt(0); got != 3 {
|
||||
t.Errorf("got %v, want 3", got)
|
||||
}
|
||||
}
|
||||
stmt.Close()
|
||||
|
||||
stmt, _, err = db.Prepare(`SELECT mode(column1) FROM (VALUES (1), (1), (2), (2), (3))`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if stmt.Step() {
|
||||
if got := stmt.ColumnInt(0); got != 1 {
|
||||
t.Errorf("got %v, want 1", got)
|
||||
}
|
||||
}
|
||||
stmt.Close()
|
||||
|
||||
stmt, _, err = db.Prepare(`SELECT mode(column1) FROM (VALUES (0.5), (1), (2.5), (2), (2.5))`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if stmt.Step() {
|
||||
if got := stmt.ColumnFloat(0); got != 2.5 {
|
||||
t.Errorf("got %v, want 2.5", got)
|
||||
}
|
||||
}
|
||||
stmt.Close()
|
||||
|
||||
stmt, _, err = db.Prepare(`SELECT mode(column1) FROM (VALUES ('red'), ('green'), ('blue'), ('red'))`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if stmt.Step() {
|
||||
if got := stmt.ColumnText(0); got != "red" {
|
||||
t.Errorf("got %q, want red", got)
|
||||
}
|
||||
}
|
||||
stmt.Close()
|
||||
|
||||
stmt, _, err = db.Prepare(`SELECT mode(column1) FROM (VALUES (X'cafebabe'), ('green'), ('blue'), (X'cafebabe'))`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if stmt.Step() {
|
||||
if got := stmt.ColumnText(0); got != "\xca\xfe\xba\xbe" {
|
||||
t.Errorf("got %q, want cafebabe", got)
|
||||
}
|
||||
}
|
||||
stmt.Close()
|
||||
|
||||
stmt, _, err = db.Prepare(`
|
||||
SELECT mode(column1) OVER (ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING)
|
||||
FROM (VALUES (1), (1), (2.5), ('blue'), (X'cafebabe'), (1), (1))
|
||||
`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
for stmt.Step() {
|
||||
}
|
||||
stmt.Close()
|
||||
}
|
||||
@@ -18,9 +18,11 @@
|
||||
// - regr_slope: slope of the least-squares-fit linear equation
|
||||
// - regr_intercept: y-intercept of the least-squares-fit linear equation
|
||||
// - regr_json: all regr stats in a JSON object
|
||||
// - percentile_disc: discrete percentile
|
||||
// - percentile_cont: continuous percentile
|
||||
// - median: median value
|
||||
// - percentile_disc: discrete quantile
|
||||
// - percentile_cont: continuous quantile
|
||||
// - percentile: continuous percentile
|
||||
// - median: middle value
|
||||
// - mode: most frequent value
|
||||
// - every: boolean and
|
||||
// - some: boolean or
|
||||
//
|
||||
@@ -77,7 +79,8 @@ func Register(db *sqlite3.Conn) error {
|
||||
db.CreateWindowFunction("percentile_cont", 2, order, newPercentile(percentile_cont)),
|
||||
db.CreateWindowFunction("percentile_disc", 2, order, newPercentile(percentile_disc)),
|
||||
db.CreateWindowFunction("every", 1, flags, newBoolean(every)),
|
||||
db.CreateWindowFunction("some", 1, flags, newBoolean(some)))
|
||||
db.CreateWindowFunction("some", 1, flags, newBoolean(some)),
|
||||
db.CreateWindowFunction("mode", 1, order, newMode))
|
||||
}
|
||||
|
||||
const (
|
||||
|
||||
Reference in New Issue
Block a user