From f2d6bdb8b780df9700d2e99b48e8d45b3447a3f5 Mon Sep 17 00:00:00 2001 From: Nuno Cruces Date: Thu, 23 Nov 2023 09:54:18 +0000 Subject: [PATCH] Tests. --- ext/array/array.go | 2 +- ext/array/array_test.go | 43 +++++++++++ ext/csv/csv.go | 47 +++++++++--- ext/csv/csv_test.go | 107 ++++++++++++++++++++++++++- ext/csv/params.go | 4 +- ext/csv/params_test.go | 102 +++++++++++++++++++++++++ ext/csv/schema_test.go | 24 ++++++ ext/csv/{ => testdata}/eurofxref.csv | 0 sqlite.go | 2 + vtab.go | 29 ++++++++ vtab_test.go | 5 ++ 11 files changed, 349 insertions(+), 16 deletions(-) create mode 100644 ext/csv/params_test.go create mode 100644 ext/csv/schema_test.go rename ext/csv/{ => testdata}/eurofxref.csv (100%) diff --git a/ext/array/array.go b/ext/array/array.go index 828e396..0e48fd9 100644 --- a/ext/array/array.go +++ b/ext/array/array.go @@ -97,7 +97,7 @@ func (c *cursor) Column(ctx *sqlite3.Context, n int) error { case k == reflect.String: ctx.ResultText(v.String()) - case (k == reflect.Slice || k == reflect.Array) && + case (k == reflect.Slice || k == reflect.Array && v.CanAddr()) && v.Type().Elem().Kind() == reflect.Uint8: ctx.ResultBlob(v.Bytes()) diff --git a/ext/array/array_test.go b/ext/array/array_test.go index acae086..d1d5b23 100644 --- a/ext/array/array_test.go +++ b/ext/array/array_test.go @@ -3,6 +3,9 @@ package array_test import ( "fmt" "log" + "math" + "reflect" + "testing" "github.com/ncruces/go-sqlite3" "github.com/ncruces/go-sqlite3/driver" @@ -47,3 +50,43 @@ func Example() { // geopoly_contains_point // geopoly_within } + +func Test_cursor_Column(t *testing.T) { + db, err := driver.Open(":memory:", func(c *sqlite3.Conn) error { + array.Register(c) + return nil + }) + if err != nil { + t.Fatal(err) + } + defer db.Close() + + rows, err := db.Query(` + SELECT rowid, value FROM array(?)`, + sqlite3.Pointer(&[...]any{nil, true, 1, uint(2), math.Pi, "text", []byte{1, 2, 3}})) + if err != nil { + t.Fatal(err) + } + defer rows.Close() + + want := []string{"nil", "int64", "int64", "int64", "float64", "string", "[]uint8"} + + for rows.Next() { + var id, val any + err := rows.Scan(&id, &val) + if err != nil { + t.Fatal(err) + } + if want := want[0]; val == nil { + if want != "nil" { + t.Errorf("got nil, want %s", want) + } + } else if got := reflect.TypeOf(val).String(); got != want { + t.Errorf("got %s, want %s", got, want) + } + want = want[1:] + } + if err := rows.Err(); err != nil { + log.Fatal(err) + } +} diff --git a/ext/csv/csv.go b/ext/csv/csv.go index 1b66832..1b4a214 100644 --- a/ext/csv/csv.go +++ b/ext/csv/csv.go @@ -11,19 +11,24 @@ import ( "fmt" "io" "math" + "os" "strings" "github.com/ncruces/go-sqlite3" ) // Register registers the CSV virtual table. -// +// If a filename is specified, `os.Open` is used to read it from disk. +func Register(db *sqlite3.Conn) { + RegisterOpen(db, func(name string) (io.ReaderAt, error) { + return os.Open(name) + }) +} + +// RegisterOpen registers the CSV virtual table. // If a filename is specified, open is used to open the file. -// To open the file from disk, use: -// -// csv.Register(c, os.Open) -func Register[T io.ReaderAt](db *sqlite3.Conn, open func(name string) (T, error)) { - declare := func(db *sqlite3.Conn, arg ...string) (*table, error) { +func RegisterOpen(db *sqlite3.Conn, open func(name string) (io.ReaderAt, error)) { + declare := func(db *sqlite3.Conn, arg ...string) (_ *table, err error) { var ( filename string data string @@ -31,8 +36,8 @@ func Register[T io.ReaderAt](db *sqlite3.Conn, open func(name string) (T, error) header bool columns int = -1 comma rune = ',' - err error - done = map[string]struct{}{} + + done = map[string]struct{}{} ) for _, arg := range arg[3:] { @@ -81,19 +86,30 @@ func Register[T io.ReaderAt](db *sqlite3.Conn, open func(name string) (T, error) comma: comma, header: header, } + defer func() { + if err != nil { + table.Close() + } + }() if schema == "" && (header || columns < 0) { csv := table.newReader() row, err := csv.Read() if err != nil { - table.Close() return nil, err } schema = getSchema(header, columns, row) } err = db.DeclareVtab(schema) - return table, err + if err != nil { + return nil, err + } + err = db.VtabConfig(sqlite3.VTAB_DIRECTONLY) + if err != nil { + return nil, err + } + return table, nil } sqlite3.CreateModule(db, "csv", declare, declare) @@ -123,6 +139,17 @@ func (t *table) Open() (sqlite3.VTabCursor, error) { return &cursor{table: t}, nil } +func (t *table) Rename(new string) error { + return nil +} + +func (t *table) Integrity(schema, table string, flags int) (err error) { + if flags&1 == 0 { + _, err = t.newReader().ReadAll() + } + return err +} + func (t *table) newReader() *csv.Reader { csv := csv.NewReader(io.NewSectionReader(t.r, 0, math.MaxInt64)) csv.ReuseRecord = true diff --git a/ext/csv/csv_test.go b/ext/csv/csv_test.go index 2f5ab5a..aa8763a 100644 --- a/ext/csv/csv_test.go +++ b/ext/csv/csv_test.go @@ -3,7 +3,7 @@ package csv_test import ( "fmt" "log" - "os" + "testing" "github.com/ncruces/go-sqlite3" _ "github.com/ncruces/go-sqlite3/embed" @@ -17,12 +17,13 @@ func Example() { } defer db.Close() - csv.Register(db, os.Open) + csv.Register(db) err = db.Exec(` CREATE VIRTUAL TABLE IF NOT EXISTS eurofxref USING csv( - filename = 'eurofxref.csv', + filename = 'testdata/eurofxref.csv', header = YES, + columns = 42, )`) if err != nil { log.Fatal(err) @@ -48,3 +49,103 @@ func Example() { // Output: // On Twosday, 1€ = $1.1342 } + +func TestRegister(t *testing.T) { + db, err := sqlite3.Open(":memory:") + if err != nil { + t.Fatal(err) + } + defer db.Close() + + csv.Register(db) + + data := ` +"Rob" "Pike" rob +"Ken" Thompson ken +Robert "Griesemer" "gri"` + err = db.Exec(` + CREATE VIRTUAL TABLE temp.users USING csv( + data = ` + sqlite3.Quote(data) + `, + schema = 'CREATE TABLE x(first_name, last_name, username)', + comma = '\t' + )`) + if err != nil { + t.Fatal(err) + } + + stmt, _, err := db.Prepare(`SELECT * FROM temp.users WHERE rowid = 1 ORDER BY username`) + if err != nil { + t.Fatal(err) + } + defer stmt.Close() + + if !stmt.Step() { + t.Fatal("no rows") + } + if got := stmt.ColumnText(1); got != "Pike" { + t.Errorf("got %q want Pike", got) + } + if stmt.Step() { + t.Fatal("more rows") + } + + err = db.Exec(`ALTER TABLE temp.users RENAME TO csv`) + if err != nil { + t.Fatal(err) + } + + err = db.Exec(`PRAGMA integrity_check`) + if err != nil { + t.Fatal(err) + } + + err = db.Exec(`DROP TABLE temp.csv`) + if err != nil { + log.Fatal(err) + } +} + +func TestRegister_errors(t *testing.T) { + db, err := sqlite3.Open(":memory:") + if err != nil { + t.Fatal(err) + } + defer db.Close() + + csv.Register(db) + + err = db.Exec(`CREATE VIRTUAL TABLE temp.users USING csv()`) + if err == nil { + t.Fatal(err) + } else { + t.Log(err) + } + + err = db.Exec(`CREATE VIRTUAL TABLE temp.users USING csv(data='abc', data='abc')`) + if err == nil { + t.Fatal(err) + } else { + t.Log(err) + } + + err = db.Exec(`CREATE VIRTUAL TABLE temp.users USING csv(data='abc', xpto='abc')`) + if err == nil { + t.Fatal(err) + } else { + t.Log(err) + } + + err = db.Exec(`CREATE VIRTUAL TABLE temp.users USING csv(data='abc', comma='"')`) + if err == nil { + t.Fatal(err) + } else { + t.Log(err) + } + + err = db.Exec(`CREATE VIRTUAL TABLE temp.users USING csv(data='abc', header=tru)`) + if err == nil { + t.Fatal(err) + } else { + t.Log(err) + } +} diff --git a/ext/csv/params.go b/ext/csv/params.go index 35f59f6..7dbbd72 100644 --- a/ext/csv/params.go +++ b/ext/csv/params.go @@ -22,13 +22,13 @@ func uintParam(key, val string) (int, error) { } func boolParam(key, val string) (bool, error) { - if val == "" || val == "0" || + if val == "" || val == "1" || strings.EqualFold(val, "true") || strings.EqualFold(val, "yes") || strings.EqualFold(val, "on") { return true, nil } - if val == "1" || + if val == "0" || strings.EqualFold(val, "false") || strings.EqualFold(val, "no") || strings.EqualFold(val, "off") { diff --git a/ext/csv/params_test.go b/ext/csv/params_test.go new file mode 100644 index 0000000..97a6b4b --- /dev/null +++ b/ext/csv/params_test.go @@ -0,0 +1,102 @@ +package csv + +import "testing" + +func Test_uintParam(t *testing.T) { + tests := []struct { + arg string + key string + val int + err bool + }{ + {"columns 1", "columns 1", 0, true}, + {"columns = 1", "columns", 1, false}, + {"columns\t= 2", "columns", 2, false}, + {" columns = 3", "columns", 3, false}, + {" columns = -1", "columns", 0, true}, + {" columns = 32768", "columns", 0, true}, + } + for _, tt := range tests { + t.Run(tt.arg, func(t *testing.T) { + key, val := getParam(tt.arg) + if key != tt.key { + t.Errorf("getParam() %v, want err %v", key, tt.key) + } + got, err := uintParam(key, val) + if (err != nil) != tt.err { + t.Fatalf("uintParam() error = %v, want err %v", err, tt.err) + } + if got != tt.val { + t.Errorf("uintParam() = %v, want %v", got, tt.val) + } + }) + } +} + +func Test_boolParam(t *testing.T) { + tests := []struct { + arg string + key string + val bool + err bool + }{ + {"header", "header", true, false}, + {"header\t= 1", "header", true, false}, + {" header = 0", "header", false, false}, + {" header = TrUe", "header", true, false}, + {" header = FaLsE", "header", false, false}, + {" header = Yes", "header", true, false}, + {" header = nO", "header", false, false}, + {" header = On", "header", true, false}, + {" header = Off", "header", false, false}, + {" header = T", "header", false, true}, + {" header = f", "header", false, true}, + } + for _, tt := range tests { + t.Run(tt.arg, func(t *testing.T) { + key, val := getParam(tt.arg) + if key != tt.key { + t.Errorf("getParam() %v, want err %v", key, tt.key) + } + got, err := boolParam(key, val) + if (err != nil) != tt.err { + t.Fatalf("boolParam() error = %v, want err %v", err, tt.err) + } + if got != tt.val { + t.Errorf("boolParam() = %v, want %v", got, tt.val) + } + }) + } +} + +func Test_runeParam(t *testing.T) { + tests := []struct { + arg string + key string + val rune + err bool + }{ + {"comma", "comma", 0, true}, + {"comma\t= ,", "comma", ',', false}, + {" comma = ;", "comma", ';', false}, + {" comma = ;;", "comma", 0, true}, + {` comma = '\t`, "comma", 0, true}, + {` comma = '\t'`, "comma", '\t', false}, + {` comma = "\t"`, "comma", '\t', false}, + } + for _, tt := range tests { + t.Run(tt.arg, func(t *testing.T) { + key, val := getParam(tt.arg) + if key != tt.key { + t.Errorf("getParam() %v, want err %v", key, tt.key) + } + got, err := runeParam(key, val) + if (err != nil) != tt.err { + t.Fatalf("runeParam() error = %v, want err %v", err, tt.err) + } + if got != tt.val { + t.Errorf("runeParam() = %v, want %v", got, tt.val) + } + }) + } +} diff --git a/ext/csv/schema_test.go b/ext/csv/schema_test.go new file mode 100644 index 0000000..fa2cfdc --- /dev/null +++ b/ext/csv/schema_test.go @@ -0,0 +1,24 @@ +package csv + +import "testing" + +func Test_getSchema(t *testing.T) { + tests := []struct { + header bool + columns int + row []string + want string + }{ + {true, 2, nil, `CREATE TABLE x(c1,c2)`}, + {false, 2, nil, `CREATE TABLE x(c1,c2)`}, + {true, 3, []string{"abc", ""}, `CREATE TABLE x("abc",c2,c3)`}, + {true, 1, []string{"abc", "def"}, `CREATE TABLE x("abc")`}, + } + for _, tt := range tests { + t.Run(tt.want, func(t *testing.T) { + if got := getSchema(tt.header, tt.columns, tt.row); got != tt.want { + t.Errorf("getSchema() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/ext/csv/eurofxref.csv b/ext/csv/testdata/eurofxref.csv similarity index 100% rename from ext/csv/eurofxref.csv rename to ext/csv/testdata/eurofxref.csv diff --git a/sqlite.go b/sqlite.go index 4f667ca..4ee8400 100644 --- a/sqlite.go +++ b/sqlite.go @@ -185,6 +185,7 @@ func instantiateSQLite() (sqlt *sqlite, err error) { resultErrorBig: getFun("sqlite3_result_error_toobig"), createModule: getFun("sqlite3_create_module_go"), declareVTab: getFun("sqlite3_declare_vtab"), + vtabConfig: getFun("sqlite3_vtab_config_go"), vtabRHSValue: getFun("sqlite3_vtab_rhs_value"), } if err != nil { @@ -412,6 +413,7 @@ type sqliteAPI struct { resultErrorBig api.Function createModule api.Function declareVTab api.Function + vtabConfig api.Function vtabRHSValue api.Function destructor uint32 } diff --git a/vtab.go b/vtab.go index 788cbde..a242fb0 100644 --- a/vtab.go +++ b/vtab.go @@ -66,6 +66,9 @@ func implements[T any](typ reflect.Type) bool { return typ.Implements(reflect.TypeOf(ptr).Elem()) } +// DeclareVtab declares the schema of a virtual table. +// +// https://sqlite.org/c3ref/declare_vtab.html func (c *Conn) DeclareVtab(sql string) error { // defer c.arena.reset() sqlPtr := c.arena.string(sql) @@ -73,6 +76,32 @@ func (c *Conn) DeclareVtab(sql string) error { return c.error(r) } +// IndexConstraintOp is a virtual table constraint operator code. +// +// https://sqlite.org/c3ref/c_vtab_constraint_support.html +type VtabConfigOption uint8 + +const ( + VTAB_CONSTRAINT_SUPPORT VtabConfigOption = 1 + VTAB_INNOCUOUS VtabConfigOption = 2 + VTAB_DIRECTONLY VtabConfigOption = 3 + VTAB_USES_ALL_SCHEMAS VtabConfigOption = 4 +) + +// VtabConfig configures various facets of the virtual table interface. +// +// https://sqlite.org/c3ref/vtab_config.html +func (c *Conn) VtabConfig(op VtabConfigOption, args ...any) error { + var i uint64 + if op == VTAB_CONSTRAINT_SUPPORT && len(args) > 0 { + if b, ok := args[0].(bool); ok && b { + i = 1 + } + } + r := c.call(c.api.vtabConfig, uint64(c.handle), uint64(op), i) + return c.error(r) +} + // VTabConstructor is a virtual table constructor function. type VTabConstructor[T VTab] func(db *Conn, arg ...string) (T, error) diff --git a/vtab_test.go b/vtab_test.go index 51412e8..6cf0b24 100644 --- a/vtab_test.go +++ b/vtab_test.go @@ -56,6 +56,8 @@ func (seriesTable) BestIndex(idx *sqlite3.IndexInfo) error { } } } + idx.IdxNum = 1 + idx.IdxStr = "idx" return nil } @@ -71,6 +73,9 @@ type seriesCursor struct { } func (cur *seriesCursor) Filter(idxNum int, idxStr string, arg ...sqlite3.Value) error { + if idxNum != 1 || idxStr != "idx" { + return nil + } cur.start = 0 cur.stop = 1000 cur.step = 1