Files
sqlite3/ext/stats/stats_test.go

174 lines
4.0 KiB
Go
Raw Normal View History

2023-12-11 14:48:15 +00:00
package stats_test
2023-08-31 16:30:52 +01:00
import (
"math"
"testing"
"github.com/ncruces/go-sqlite3"
_ "github.com/ncruces/go-sqlite3/embed"
2023-12-11 14:48:15 +00:00
"github.com/ncruces/go-sqlite3/ext/stats"
2023-08-31 16:30:52 +01:00
)
2023-09-01 02:26:30 +01:00
func TestRegister_variance(t *testing.T) {
2023-08-31 16:30:52 +01:00
t.Parallel()
db, err := sqlite3.Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
2023-12-11 14:48:15 +00:00
stats.Register(db)
2023-08-31 16:30:52 +01:00
2023-09-01 02:26:30 +01:00
err = db.Exec(`CREATE TABLE IF NOT EXISTS data (x)`)
2023-08-31 16:30:52 +01:00
if err != nil {
t.Fatal(err)
}
2023-09-01 02:26:30 +01:00
err = db.Exec(`INSERT INTO data (x) VALUES (4), (7.0), ('13'), (NULL), (16)`)
2023-08-31 16:30:52 +01:00
if err != nil {
t.Fatal(err)
}
2023-09-01 02:26:30 +01:00
stmt, _, err := db.Prepare(`
SELECT
sum(x), avg(x),
var_samp(x), var_pop(x),
stddev_samp(x), stddev_pop(x)
FROM data`)
2023-08-31 16:30:52 +01:00
if err != nil {
t.Fatal(err)
}
defer stmt.Close()
if stmt.Step() {
if got := stmt.ColumnFloat(0); got != 40 {
t.Errorf("got %v, want 40", got)
}
if got := stmt.ColumnFloat(1); got != 10 {
t.Errorf("got %v, want 10", got)
}
if got := stmt.ColumnFloat(2); got != 30 {
t.Errorf("got %v, want 30", got)
}
if got := stmt.ColumnFloat(3); got != 22.5 {
t.Errorf("got %v, want 22.5", got)
}
if got := stmt.ColumnFloat(4); got != math.Sqrt(30) {
t.Errorf("got %v, want √30", got)
}
if got := stmt.ColumnFloat(5); got != math.Sqrt(22.5) {
t.Errorf("got %v, want √22.5", got)
}
}
{
2023-09-01 02:26:30 +01:00
stmt, _, err := db.Prepare(`SELECT var_samp(x) OVER (ROWS 1 PRECEDING) FROM data`)
2023-08-31 16:30:52 +01:00
if err != nil {
t.Fatal(err)
}
defer stmt.Close()
want := [...]float64{0, 4.5, 18, 0, 0}
for i := 0; stmt.Step(); i++ {
if got := stmt.ColumnFloat(0); got != want[i] {
t.Errorf("got %v, want %v", got, want[i])
}
if got := stmt.ColumnType(0); (got == sqlite3.FLOAT) != (want[i] != 0) {
t.Errorf("got %v, want %v", got, want[i])
}
}
}
}
2023-09-01 02:26:30 +01:00
func TestRegister_covariance(t *testing.T) {
t.Parallel()
db, err := sqlite3.Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
2023-12-11 14:48:15 +00:00
stats.Register(db)
2023-09-01 02:26:30 +01:00
2024-01-10 11:07:16 +00:00
err = db.Exec(`CREATE TABLE IF NOT EXISTS data (y, x)`)
2023-09-01 02:26:30 +01:00
if err != nil {
t.Fatal(err)
}
2024-01-10 11:07:16 +00:00
err = db.Exec(`INSERT INTO data (y, x) VALUES (3, 70), (5, 80), (2, 60), (7, 90), (4, 75)`)
2023-09-01 02:26:30 +01:00
if err != nil {
t.Fatal(err)
}
stmt, _, err := db.Prepare(`SELECT
2024-01-10 11:07:16 +00:00
corr(y, x), covar_samp(y, x), covar_pop(y, x),
regr_avgy(y, x), regr_avgx(y, x),
regr_syy(y, x), regr_sxx(y, x), regr_sxy(y, x),
regr_slope(y, x), regr_intercept(y, x), regr_r2(y, x),
regr_count(y, x)
2024-01-09 03:20:59 +00:00
FROM data`)
2023-09-01 02:26:30 +01:00
if err != nil {
t.Fatal(err)
}
defer stmt.Close()
if stmt.Step() {
2023-09-02 00:48:55 +01:00
if got := stmt.ColumnFloat(0); got != 0.9881049293224639 {
t.Errorf("got %v, want 0.9881049293224639", got)
}
if got := stmt.ColumnFloat(1); got != 21.25 {
2023-09-01 02:26:30 +01:00
t.Errorf("got %v, want 21.25", got)
}
2023-09-02 00:48:55 +01:00
if got := stmt.ColumnFloat(2); got != 17 {
2023-09-01 02:26:30 +01:00
t.Errorf("got %v, want 17", got)
}
2024-01-10 11:07:16 +00:00
if got := stmt.ColumnFloat(3); got != 4.2 {
t.Errorf("got %v, want 4.2", got)
}
if got := stmt.ColumnFloat(4); got != 75 {
2024-01-09 03:20:59 +00:00
t.Errorf("got %v, want 75", got)
}
2024-01-10 11:07:16 +00:00
if got := stmt.ColumnFloat(5); got != 14.8 {
t.Errorf("got %v, want 14.8", got)
2024-01-09 03:20:59 +00:00
}
2024-01-10 11:07:16 +00:00
if got := stmt.ColumnFloat(6); got != 500 {
t.Errorf("got %v, want 500", got)
2024-01-09 03:20:59 +00:00
}
2024-01-10 11:07:16 +00:00
if got := stmt.ColumnFloat(7); got != 85 {
t.Errorf("got %v, want 85", got)
}
if got := stmt.ColumnFloat(8); got != 0.17 {
2024-01-09 03:20:59 +00:00
t.Errorf("got %v, want 0.17", got)
}
2024-01-10 11:07:16 +00:00
if got := stmt.ColumnFloat(9); got != -8.55 {
2024-01-09 03:20:59 +00:00
t.Errorf("got %v, want -8.55", got)
}
2024-01-10 11:07:16 +00:00
if got := stmt.ColumnFloat(10); got != 0.9763513513513513 {
t.Errorf("got %v, want 0.9763513513513513", got)
}
if got := stmt.ColumnInt(11); got != 5 {
t.Errorf("got %v, want 5", got)
}
2023-09-01 02:26:30 +01:00
}
{
2024-01-10 11:07:16 +00:00
stmt, _, err := db.Prepare(`SELECT covar_samp(y, x) OVER (ROWS 1 PRECEDING) FROM data`)
2023-09-01 02:26:30 +01:00
if err != nil {
t.Fatal(err)
}
defer stmt.Close()
want := [...]float64{0, 10, 30, 75, 22.5}
for i := 0; stmt.Step(); i++ {
if got := stmt.ColumnFloat(0); got != want[i] {
t.Errorf("got %v, want %v", got, want[i])
}
if got := stmt.ColumnType(0); (got == sqlite3.FLOAT) != (want[i] != 0) {
t.Errorf("got %v, want %v", got, want[i])
}
}
}
}