This commit is contained in:
Nuno Cruces
2023-11-23 09:54:18 +00:00
parent 9bb01d1f8b
commit f2d6bdb8b7
11 changed files with 349 additions and 16 deletions

View File

@@ -97,7 +97,7 @@ func (c *cursor) Column(ctx *sqlite3.Context, n int) error {
case k == reflect.String:
ctx.ResultText(v.String())
case (k == reflect.Slice || k == reflect.Array) &&
case (k == reflect.Slice || k == reflect.Array && v.CanAddr()) &&
v.Type().Elem().Kind() == reflect.Uint8:
ctx.ResultBlob(v.Bytes())

View File

@@ -3,6 +3,9 @@ package array_test
import (
"fmt"
"log"
"math"
"reflect"
"testing"
"github.com/ncruces/go-sqlite3"
"github.com/ncruces/go-sqlite3/driver"
@@ -47,3 +50,43 @@ func Example() {
// geopoly_contains_point
// geopoly_within
}
func Test_cursor_Column(t *testing.T) {
db, err := driver.Open(":memory:", func(c *sqlite3.Conn) error {
array.Register(c)
return nil
})
if err != nil {
t.Fatal(err)
}
defer db.Close()
rows, err := db.Query(`
SELECT rowid, value FROM array(?)`,
sqlite3.Pointer(&[...]any{nil, true, 1, uint(2), math.Pi, "text", []byte{1, 2, 3}}))
if err != nil {
t.Fatal(err)
}
defer rows.Close()
want := []string{"nil", "int64", "int64", "int64", "float64", "string", "[]uint8"}
for rows.Next() {
var id, val any
err := rows.Scan(&id, &val)
if err != nil {
t.Fatal(err)
}
if want := want[0]; val == nil {
if want != "nil" {
t.Errorf("got nil, want %s", want)
}
} else if got := reflect.TypeOf(val).String(); got != want {
t.Errorf("got %s, want %s", got, want)
}
want = want[1:]
}
if err := rows.Err(); err != nil {
log.Fatal(err)
}
}

View File

@@ -11,19 +11,24 @@ import (
"fmt"
"io"
"math"
"os"
"strings"
"github.com/ncruces/go-sqlite3"
)
// Register registers the CSV virtual table.
//
// If a filename is specified, `os.Open` is used to read it from disk.
func Register(db *sqlite3.Conn) {
RegisterOpen(db, func(name string) (io.ReaderAt, error) {
return os.Open(name)
})
}
// RegisterOpen registers the CSV virtual table.
// If a filename is specified, open is used to open the file.
// To open the file from disk, use:
//
// csv.Register(c, os.Open)
func Register[T io.ReaderAt](db *sqlite3.Conn, open func(name string) (T, error)) {
declare := func(db *sqlite3.Conn, arg ...string) (*table, error) {
func RegisterOpen(db *sqlite3.Conn, open func(name string) (io.ReaderAt, error)) {
declare := func(db *sqlite3.Conn, arg ...string) (_ *table, err error) {
var (
filename string
data string
@@ -31,8 +36,8 @@ func Register[T io.ReaderAt](db *sqlite3.Conn, open func(name string) (T, error)
header bool
columns int = -1
comma rune = ','
err error
done = map[string]struct{}{}
done = map[string]struct{}{}
)
for _, arg := range arg[3:] {
@@ -81,19 +86,30 @@ func Register[T io.ReaderAt](db *sqlite3.Conn, open func(name string) (T, error)
comma: comma,
header: header,
}
defer func() {
if err != nil {
table.Close()
}
}()
if schema == "" && (header || columns < 0) {
csv := table.newReader()
row, err := csv.Read()
if err != nil {
table.Close()
return nil, err
}
schema = getSchema(header, columns, row)
}
err = db.DeclareVtab(schema)
return table, err
if err != nil {
return nil, err
}
err = db.VtabConfig(sqlite3.VTAB_DIRECTONLY)
if err != nil {
return nil, err
}
return table, nil
}
sqlite3.CreateModule(db, "csv", declare, declare)
@@ -123,6 +139,17 @@ func (t *table) Open() (sqlite3.VTabCursor, error) {
return &cursor{table: t}, nil
}
func (t *table) Rename(new string) error {
return nil
}
func (t *table) Integrity(schema, table string, flags int) (err error) {
if flags&1 == 0 {
_, err = t.newReader().ReadAll()
}
return err
}
func (t *table) newReader() *csv.Reader {
csv := csv.NewReader(io.NewSectionReader(t.r, 0, math.MaxInt64))
csv.ReuseRecord = true

View File

@@ -3,7 +3,7 @@ package csv_test
import (
"fmt"
"log"
"os"
"testing"
"github.com/ncruces/go-sqlite3"
_ "github.com/ncruces/go-sqlite3/embed"
@@ -17,12 +17,13 @@ func Example() {
}
defer db.Close()
csv.Register(db, os.Open)
csv.Register(db)
err = db.Exec(`
CREATE VIRTUAL TABLE IF NOT EXISTS eurofxref USING csv(
filename = 'eurofxref.csv',
filename = 'testdata/eurofxref.csv',
header = YES,
columns = 42,
)`)
if err != nil {
log.Fatal(err)
@@ -48,3 +49,103 @@ func Example() {
// Output:
// On Twosday, 1€ = $1.1342
}
func TestRegister(t *testing.T) {
db, err := sqlite3.Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
csv.Register(db)
data := `
"Rob" "Pike" rob
"Ken" Thompson ken
Robert "Griesemer" "gri"`
err = db.Exec(`
CREATE VIRTUAL TABLE temp.users USING csv(
data = ` + sqlite3.Quote(data) + `,
schema = 'CREATE TABLE x(first_name, last_name, username)',
comma = '\t'
)`)
if err != nil {
t.Fatal(err)
}
stmt, _, err := db.Prepare(`SELECT * FROM temp.users WHERE rowid = 1 ORDER BY username`)
if err != nil {
t.Fatal(err)
}
defer stmt.Close()
if !stmt.Step() {
t.Fatal("no rows")
}
if got := stmt.ColumnText(1); got != "Pike" {
t.Errorf("got %q want Pike", got)
}
if stmt.Step() {
t.Fatal("more rows")
}
err = db.Exec(`ALTER TABLE temp.users RENAME TO csv`)
if err != nil {
t.Fatal(err)
}
err = db.Exec(`PRAGMA integrity_check`)
if err != nil {
t.Fatal(err)
}
err = db.Exec(`DROP TABLE temp.csv`)
if err != nil {
log.Fatal(err)
}
}
func TestRegister_errors(t *testing.T) {
db, err := sqlite3.Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
csv.Register(db)
err = db.Exec(`CREATE VIRTUAL TABLE temp.users USING csv()`)
if err == nil {
t.Fatal(err)
} else {
t.Log(err)
}
err = db.Exec(`CREATE VIRTUAL TABLE temp.users USING csv(data='abc', data='abc')`)
if err == nil {
t.Fatal(err)
} else {
t.Log(err)
}
err = db.Exec(`CREATE VIRTUAL TABLE temp.users USING csv(data='abc', xpto='abc')`)
if err == nil {
t.Fatal(err)
} else {
t.Log(err)
}
err = db.Exec(`CREATE VIRTUAL TABLE temp.users USING csv(data='abc', comma='"')`)
if err == nil {
t.Fatal(err)
} else {
t.Log(err)
}
err = db.Exec(`CREATE VIRTUAL TABLE temp.users USING csv(data='abc', header=tru)`)
if err == nil {
t.Fatal(err)
} else {
t.Log(err)
}
}

View File

@@ -22,13 +22,13 @@ func uintParam(key, val string) (int, error) {
}
func boolParam(key, val string) (bool, error) {
if val == "" || val == "0" ||
if val == "" || val == "1" ||
strings.EqualFold(val, "true") ||
strings.EqualFold(val, "yes") ||
strings.EqualFold(val, "on") {
return true, nil
}
if val == "1" ||
if val == "0" ||
strings.EqualFold(val, "false") ||
strings.EqualFold(val, "no") ||
strings.EqualFold(val, "off") {

102
ext/csv/params_test.go Normal file
View File

@@ -0,0 +1,102 @@
package csv
import "testing"
func Test_uintParam(t *testing.T) {
tests := []struct {
arg string
key string
val int
err bool
}{
{"columns 1", "columns 1", 0, true},
{"columns = 1", "columns", 1, false},
{"columns\t= 2", "columns", 2, false},
{" columns = 3", "columns", 3, false},
{" columns = -1", "columns", 0, true},
{" columns = 32768", "columns", 0, true},
}
for _, tt := range tests {
t.Run(tt.arg, func(t *testing.T) {
key, val := getParam(tt.arg)
if key != tt.key {
t.Errorf("getParam() %v, want err %v", key, tt.key)
}
got, err := uintParam(key, val)
if (err != nil) != tt.err {
t.Fatalf("uintParam() error = %v, want err %v", err, tt.err)
}
if got != tt.val {
t.Errorf("uintParam() = %v, want %v", got, tt.val)
}
})
}
}
func Test_boolParam(t *testing.T) {
tests := []struct {
arg string
key string
val bool
err bool
}{
{"header", "header", true, false},
{"header\t= 1", "header", true, false},
{" header = 0", "header", false, false},
{" header = TrUe", "header", true, false},
{" header = FaLsE", "header", false, false},
{" header = Yes", "header", true, false},
{" header = nO", "header", false, false},
{" header = On", "header", true, false},
{" header = Off", "header", false, false},
{" header = T", "header", false, true},
{" header = f", "header", false, true},
}
for _, tt := range tests {
t.Run(tt.arg, func(t *testing.T) {
key, val := getParam(tt.arg)
if key != tt.key {
t.Errorf("getParam() %v, want err %v", key, tt.key)
}
got, err := boolParam(key, val)
if (err != nil) != tt.err {
t.Fatalf("boolParam() error = %v, want err %v", err, tt.err)
}
if got != tt.val {
t.Errorf("boolParam() = %v, want %v", got, tt.val)
}
})
}
}
func Test_runeParam(t *testing.T) {
tests := []struct {
arg string
key string
val rune
err bool
}{
{"comma", "comma", 0, true},
{"comma\t= ,", "comma", ',', false},
{" comma = ;", "comma", ';', false},
{" comma = ;;", "comma", 0, true},
{` comma = '\t`, "comma", 0, true},
{` comma = '\t'`, "comma", '\t', false},
{` comma = "\t"`, "comma", '\t', false},
}
for _, tt := range tests {
t.Run(tt.arg, func(t *testing.T) {
key, val := getParam(tt.arg)
if key != tt.key {
t.Errorf("getParam() %v, want err %v", key, tt.key)
}
got, err := runeParam(key, val)
if (err != nil) != tt.err {
t.Fatalf("runeParam() error = %v, want err %v", err, tt.err)
}
if got != tt.val {
t.Errorf("runeParam() = %v, want %v", got, tt.val)
}
})
}
}

24
ext/csv/schema_test.go Normal file
View File

@@ -0,0 +1,24 @@
package csv
import "testing"
func Test_getSchema(t *testing.T) {
tests := []struct {
header bool
columns int
row []string
want string
}{
{true, 2, nil, `CREATE TABLE x(c1,c2)`},
{false, 2, nil, `CREATE TABLE x(c1,c2)`},
{true, 3, []string{"abc", ""}, `CREATE TABLE x("abc",c2,c3)`},
{true, 1, []string{"abc", "def"}, `CREATE TABLE x("abc")`},
}
for _, tt := range tests {
t.Run(tt.want, func(t *testing.T) {
if got := getSchema(tt.header, tt.columns, tt.row); got != tt.want {
t.Errorf("getSchema() = %v, want %v", got, tt.want)
}
})
}
}

View File

@@ -185,6 +185,7 @@ func instantiateSQLite() (sqlt *sqlite, err error) {
resultErrorBig: getFun("sqlite3_result_error_toobig"),
createModule: getFun("sqlite3_create_module_go"),
declareVTab: getFun("sqlite3_declare_vtab"),
vtabConfig: getFun("sqlite3_vtab_config_go"),
vtabRHSValue: getFun("sqlite3_vtab_rhs_value"),
}
if err != nil {
@@ -412,6 +413,7 @@ type sqliteAPI struct {
resultErrorBig api.Function
createModule api.Function
declareVTab api.Function
vtabConfig api.Function
vtabRHSValue api.Function
destructor uint32
}

29
vtab.go
View File

@@ -66,6 +66,9 @@ func implements[T any](typ reflect.Type) bool {
return typ.Implements(reflect.TypeOf(ptr).Elem())
}
// DeclareVtab declares the schema of a virtual table.
//
// https://sqlite.org/c3ref/declare_vtab.html
func (c *Conn) DeclareVtab(sql string) error {
// defer c.arena.reset()
sqlPtr := c.arena.string(sql)
@@ -73,6 +76,32 @@ func (c *Conn) DeclareVtab(sql string) error {
return c.error(r)
}
// IndexConstraintOp is a virtual table constraint operator code.
//
// https://sqlite.org/c3ref/c_vtab_constraint_support.html
type VtabConfigOption uint8
const (
VTAB_CONSTRAINT_SUPPORT VtabConfigOption = 1
VTAB_INNOCUOUS VtabConfigOption = 2
VTAB_DIRECTONLY VtabConfigOption = 3
VTAB_USES_ALL_SCHEMAS VtabConfigOption = 4
)
// VtabConfig configures various facets of the virtual table interface.
//
// https://sqlite.org/c3ref/vtab_config.html
func (c *Conn) VtabConfig(op VtabConfigOption, args ...any) error {
var i uint64
if op == VTAB_CONSTRAINT_SUPPORT && len(args) > 0 {
if b, ok := args[0].(bool); ok && b {
i = 1
}
}
r := c.call(c.api.vtabConfig, uint64(c.handle), uint64(op), i)
return c.error(r)
}
// VTabConstructor is a virtual table constructor function.
type VTabConstructor[T VTab] func(db *Conn, arg ...string) (T, error)

View File

@@ -56,6 +56,8 @@ func (seriesTable) BestIndex(idx *sqlite3.IndexInfo) error {
}
}
}
idx.IdxNum = 1
idx.IdxStr = "idx"
return nil
}
@@ -71,6 +73,9 @@ type seriesCursor struct {
}
func (cur *seriesCursor) Filter(idxNum int, idxStr string, arg ...sqlite3.Value) error {
if idxNum != 1 || idxStr != "idx" {
return nil
}
cur.start = 0
cur.stop = 1000
cur.step = 1