Files
sqlite3/ext/stats/stats_test.go

306 lines
6.7 KiB
Go
Raw Permalink Normal View History

2023-12-11 14:48:15 +00:00
package stats_test
2023-08-31 16:30:52 +01:00
import (
"math"
2025-04-04 10:56:12 +01:00
"os"
2023-08-31 16:30:52 +01:00
"testing"
"github.com/ncruces/go-sqlite3"
_ "github.com/ncruces/go-sqlite3/embed"
2023-12-11 14:48:15 +00:00
"github.com/ncruces/go-sqlite3/ext/stats"
2024-06-02 10:33:20 +01:00
_ "github.com/ncruces/go-sqlite3/internal/testcfg"
2023-08-31 16:30:52 +01:00
)
2024-08-30 01:27:22 +01:00
func TestMain(m *testing.M) {
2024-07-08 12:06:57 +01:00
sqlite3.AutoExtension(stats.Register)
2025-04-04 10:56:12 +01:00
os.Exit(m.Run())
2024-07-08 12:06:57 +01:00
}
2023-09-01 02:26:30 +01:00
func TestRegister_variance(t *testing.T) {
2023-08-31 16:30:52 +01:00
t.Parallel()
db, err := sqlite3.Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
2024-04-04 01:25:52 +01:00
err = db.Exec(`CREATE TABLE data (x)`)
2023-08-31 16:30:52 +01:00
if err != nil {
t.Fatal(err)
}
2025-01-20 14:39:36 +00:00
stmt, _, err := db.Prepare(`SELECT stddev_pop(x) FROM data`)
if err != nil {
t.Fatal(err)
}
2025-07-17 01:11:16 +01:00
if !stmt.Step() {
t.Fatal(stmt.Err())
} else if got := stmt.ColumnType(0); got != sqlite3.NULL {
t.Errorf("got %v, want NULL", got)
2025-01-20 14:39:36 +00:00
}
stmt.Close()
2023-09-01 02:26:30 +01:00
err = db.Exec(`INSERT INTO data (x) VALUES (4), (7.0), ('13'), (NULL), (16)`)
2023-08-31 16:30:52 +01:00
if err != nil {
t.Fatal(err)
}
2025-01-20 14:39:36 +00:00
stmt, _, err = db.Prepare(`
2023-09-01 02:26:30 +01:00
SELECT
sum(x), avg(x),
var_samp(x), var_pop(x),
2025-01-22 12:09:20 +00:00
stddev_samp(x), stddev_pop(x),
skewness_samp(x), skewness_pop(x),
kurtosis_samp(x), kurtosis_pop(x)
2023-09-01 02:26:30 +01:00
FROM data`)
2023-08-31 16:30:52 +01:00
if err != nil {
t.Fatal(err)
}
2025-07-17 01:11:16 +01:00
if !stmt.Step() {
t.Fatal(stmt.Err())
} else {
2023-08-31 16:30:52 +01:00
if got := stmt.ColumnFloat(0); got != 40 {
t.Errorf("got %v, want 40", got)
}
if got := stmt.ColumnFloat(1); got != 10 {
t.Errorf("got %v, want 10", got)
}
if got := stmt.ColumnFloat(2); got != 30 {
t.Errorf("got %v, want 30", got)
}
if got := stmt.ColumnFloat(3); got != 22.5 {
t.Errorf("got %v, want 22.5", got)
}
if got := stmt.ColumnFloat(4); got != math.Sqrt(30) {
t.Errorf("got %v, want √30", got)
}
if got := stmt.ColumnFloat(5); got != math.Sqrt(22.5) {
t.Errorf("got %v, want √22.5", got)
}
2025-01-22 12:09:20 +00:00
if got := stmt.ColumnFloat(6); got != 0 {
t.Errorf("got %v, want zero", got)
}
if got := stmt.ColumnFloat(7); got != 0 {
t.Errorf("got %v, want zero", got)
}
if got := stmt.ColumnFloat(8); float32(got) != -3.3 {
t.Errorf("got %v, want -3.3", got)
}
if got := stmt.ColumnFloat(9); got != -1.64 {
t.Errorf("got %v, want -1.64", got)
}
2023-08-31 16:30:52 +01:00
}
2024-06-06 16:41:20 +01:00
stmt.Close()
2023-08-31 16:30:52 +01:00
2025-01-20 14:39:36 +00:00
stmt, _, err = db.Prepare(`
SELECT
var_samp(x) OVER (ROWS 1 PRECEDING),
2025-01-22 12:09:20 +00:00
var_pop(x) OVER (ROWS 1 PRECEDING),
skewness_pop(x) OVER (ROWS 1 PRECEDING)
2025-01-20 14:39:36 +00:00
FROM data`)
2024-06-06 16:41:20 +01:00
if err != nil {
t.Fatal(err)
}
2023-08-31 16:30:52 +01:00
2024-06-06 16:41:20 +01:00
want := [...]float64{0, 4.5, 18, 0, 0}
for i := 0; stmt.Step(); i++ {
if got := stmt.ColumnFloat(0); got != want[i] {
t.Errorf("got %v, want %v", got, want[i])
}
if got := stmt.ColumnType(0); (got == sqlite3.FLOAT) != (want[i] != 0) {
t.Errorf("got %v, want %v", got, want[i])
2023-08-31 16:30:52 +01:00
}
}
2024-06-06 16:41:20 +01:00
stmt.Close()
2023-08-31 16:30:52 +01:00
}
2023-09-01 02:26:30 +01:00
func TestRegister_covariance(t *testing.T) {
t.Parallel()
db, err := sqlite3.Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
2024-04-04 01:25:52 +01:00
err = db.Exec(`CREATE TABLE data (y, x)`)
2023-09-01 02:26:30 +01:00
if err != nil {
t.Fatal(err)
}
2025-01-20 14:39:36 +00:00
stmt, _, err := db.Prepare(`SELECT regr_count(y, x), regr_json(y, x) FROM data`)
if err != nil {
t.Fatal(err)
}
2025-07-17 01:11:16 +01:00
if !stmt.Step() {
t.Fatal(stmt.Err())
} else {
2025-01-20 14:39:36 +00:00
if got := stmt.ColumnInt(0); got != 0 {
t.Errorf("got %v, want 0", got)
}
if got := stmt.ColumnType(1); got != sqlite3.NULL {
t.Errorf("got %v, want NULL", got)
}
}
stmt.Close()
2024-01-10 11:07:16 +00:00
err = db.Exec(`INSERT INTO data (y, x) VALUES (3, 70), (5, 80), (2, 60), (7, 90), (4, 75)`)
2023-09-01 02:26:30 +01:00
if err != nil {
t.Fatal(err)
}
2025-01-20 14:39:36 +00:00
stmt, _, err = db.Prepare(`SELECT
2024-01-10 11:07:16 +00:00
corr(y, x), covar_samp(y, x), covar_pop(y, x),
regr_avgy(y, x), regr_avgx(y, x),
regr_syy(y, x), regr_sxx(y, x), regr_sxy(y, x),
regr_slope(y, x), regr_intercept(y, x), regr_r2(y, x),
2024-01-25 11:03:33 +00:00
regr_count(y, x), regr_json(y, x)
2024-01-09 03:20:59 +00:00
FROM data`)
2023-09-01 02:26:30 +01:00
if err != nil {
t.Fatal(err)
}
2025-07-04 17:24:45 +01:00
if !stmt.Step() {
t.Fatal(stmt.Err())
}
if got := stmt.ColumnFloat(0); got != 0.9881049293224639 {
t.Errorf("got %v, want 0.9881049293224639", got)
}
if got := stmt.ColumnFloat(1); got != 21.25 {
t.Errorf("got %v, want 21.25", got)
}
if got := stmt.ColumnFloat(2); got != 17 {
t.Errorf("got %v, want 17", got)
}
if got := stmt.ColumnFloat(3); got != 4.2 {
t.Errorf("got %v, want 4.2", got)
}
if got := stmt.ColumnFloat(4); got != 75 {
t.Errorf("got %v, want 75", got)
}
if got := stmt.ColumnFloat(5); got != 14.8 {
t.Errorf("got %v, want 14.8", got)
}
if got := stmt.ColumnFloat(6); got != 500 {
t.Errorf("got %v, want 500", got)
}
if got := stmt.ColumnFloat(7); got != 85 {
t.Errorf("got %v, want 85", got)
}
if got := stmt.ColumnFloat(8); got != 0.17 {
t.Errorf("got %v, want 0.17", got)
}
if got := stmt.ColumnFloat(9); got != -8.55 {
t.Errorf("got %v, want -8.55", got)
}
if got := stmt.ColumnFloat(10); got != 0.9763513513513513 {
t.Errorf("got %v, want 0.9763513513513513", got)
}
if got := stmt.ColumnInt(11); got != 5 {
t.Errorf("got %v, want 5", got)
}
var a map[string]float64
if err := stmt.ColumnJSON(12, &a); err != nil {
t.Error(err)
} else if got := a["count"]; got != 5 {
t.Errorf("got %v, want 5", got)
2023-09-01 02:26:30 +01:00
}
2024-06-06 16:41:20 +01:00
stmt.Close()
2023-09-01 02:26:30 +01:00
2025-01-20 14:39:36 +00:00
stmt, _, err = db.Prepare(`
SELECT
covar_samp(y, x) OVER (ROWS 1 PRECEDING),
covar_pop(y, x) OVER (ROWS 1 PRECEDING),
regr_avgx(y, x) OVER (ROWS 1 PRECEDING)
FROM data`)
2024-06-06 16:41:20 +01:00
if err != nil {
t.Fatal(err)
}
2023-09-01 02:26:30 +01:00
2024-06-06 16:41:20 +01:00
want := [...]float64{0, 10, 30, 75, 22.5}
for i := 0; stmt.Step(); i++ {
if got := stmt.ColumnFloat(0); got != want[i] {
t.Errorf("got %v, want %v", got, want[i])
}
if got := stmt.ColumnType(0); (got == sqlite3.FLOAT) != (want[i] != 0) {
t.Errorf("got %v, want %v", got, want[i])
2023-09-01 02:26:30 +01:00
}
}
2025-01-20 14:39:36 +00:00
if stmt.Err() != nil {
t.Fatal(stmt.Err())
}
2024-06-06 16:41:20 +01:00
stmt.Close()
2023-09-01 02:26:30 +01:00
}
2024-01-10 12:27:19 +00:00
func Benchmark_average(b *testing.B) {
2024-05-03 12:38:40 +01:00
sqlite3.Initialize()
b.ResetTimer()
2024-01-10 12:27:19 +00:00
db, err := sqlite3.Open(":memory:")
if err != nil {
b.Fatal(err)
}
defer db.Close()
stmt, _, err := db.Prepare(`SELECT avg(value) FROM generate_series(0, ?)`)
if err != nil {
b.Fatal(err)
}
defer stmt.Close()
err = stmt.BindInt(1, b.N)
if err != nil {
b.Fatal(err)
}
2025-07-17 01:11:16 +01:00
if !stmt.Step() {
b.Fatal(stmt.Err())
} else {
2024-01-10 12:27:19 +00:00
want := float64(b.N) / 2
if got := stmt.ColumnFloat(0); got != want {
b.Errorf("got %v, want %v", got, want)
}
}
err = stmt.Err()
if err != nil {
b.Error(err)
}
}
func Benchmark_variance(b *testing.B) {
2024-05-03 12:38:40 +01:00
sqlite3.Initialize()
b.ResetTimer()
2024-01-10 12:27:19 +00:00
db, err := sqlite3.Open(":memory:")
if err != nil {
b.Fatal(err)
}
defer db.Close()
stmt, _, err := db.Prepare(`SELECT var_pop(value) FROM generate_series(0, ?)`)
if err != nil {
b.Fatal(err)
}
defer stmt.Close()
err = stmt.BindInt(1, b.N)
if err != nil {
b.Fatal(err)
}
2025-07-17 01:11:16 +01:00
if !stmt.Step() {
b.Fatal(stmt.Err())
} else if b.N > 100 {
2024-01-10 12:27:19 +00:00
want := float64(b.N*b.N) / 12
if got := stmt.ColumnFloat(0); want > (got-want)*float64(b.N) {
b.Errorf("got %v, want %v", got, want)
}
}
err = stmt.Err()
if err != nil {
b.Error(err)
}
}