diff --git a/ext/stats/mode.go b/ext/stats/mode.go index 21eedba..1ba3c6e 100644 --- a/ext/stats/mode.go +++ b/ext/stats/mode.go @@ -19,8 +19,8 @@ type mode struct { func (m mode) Value(ctx sqlite3.Context) { var ( - max = 0 typ = sqlite3.NULL + max uint i64 int64 f64 float64 str string @@ -32,7 +32,6 @@ func (m mode) Value(ctx sqlite3.Context) { i64 = k } } - f64 = float64(i64) for k, v := range m.reals { if v > max || v == max && k < f64 { typ = sqlite3.FLOAT @@ -66,33 +65,45 @@ func (m mode) Value(ctx sqlite3.Context) { } } -func (b *mode) Step(ctx sqlite3.Context, arg ...sqlite3.Value) { +func (m *mode) Step(ctx sqlite3.Context, arg ...sqlite3.Value) { switch arg[0].Type() { case sqlite3.INTEGER: - b.ints.add(arg[0].Int64()) + if m.reals == nil { + m.ints.add(arg[0].Int64()) + break + } + fallthrough case sqlite3.FLOAT: - b.reals.add(arg[0].Float()) + m.reals.add(arg[0].Float()) + for k, v := range m.ints { + m.reals[float64(k)] += v + } + m.ints = nil case sqlite3.TEXT: - b.texts.add(arg[0].Text()) + m.texts.add(arg[0].Text()) case sqlite3.BLOB: - b.blobs.add(string(arg[0].RawBlob())) + m.blobs.add(string(arg[0].RawBlob())) } } -func (b *mode) Inverse(ctx sqlite3.Context, arg ...sqlite3.Value) { +func (m *mode) Inverse(ctx sqlite3.Context, arg ...sqlite3.Value) { switch arg[0].Type() { case sqlite3.INTEGER: - b.ints.del(arg[0].Int64()) + if m.reals == nil { + m.ints.del(arg[0].Int64()) + break + } + fallthrough case sqlite3.FLOAT: - b.reals.del(arg[0].Float()) + m.reals.del(arg[0].Float()) case sqlite3.TEXT: - b.texts.del(arg[0].Text()) + m.texts.del(arg[0].Text()) case sqlite3.BLOB: - b.blobs.del(string(arg[0].RawBlob())) + m.blobs.del(string(arg[0].RawBlob())) } } -type counter[T comparable] map[T]int +type counter[T comparable] map[T]uint func (c *counter[T]) add(k T) { if (*c) == nil { @@ -102,11 +113,9 @@ func (c *counter[T]) add(k T) { } func (c counter[T]) del(k T) { - switch n := c[k]; n { - default: - c[k] = n - 1 - case 1: + if n := c[k]; n == 1 { delete(c, k) - case 0: + } else { + c[k] = n - 1 } } diff --git a/ext/stats/mode_test.go b/ext/stats/mode_test.go index 3e2bfae..63a440a 100644 --- a/ext/stats/mode_test.go +++ b/ext/stats/mode_test.go @@ -82,4 +82,20 @@ func TestRegister_mode(t *testing.T) { for stmt.Step() { } stmt.Close() + + stmt, _, err = db.Prepare(`SELECT mode(column1) FROM (VALUES (?), (?), (?), (?), (?))`) + if err != nil { + t.Fatal(err) + } + stmt.BindInt(1, 1) + stmt.BindInt(2, 1) + stmt.BindInt(3, 2) + stmt.BindFloat(4, 2) + stmt.BindFloat(5, 2) + if stmt.Step() { + if got := stmt.ColumnInt(0); got != 2 { + t.Errorf("got %v, want 2", got) + } + } + stmt.Close() }