diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml index d4fbfc1..0cda728 100644 --- a/.github/workflows/go.yml +++ b/.github/workflows/go.yml @@ -20,6 +20,7 @@ jobs: uses: actions/setup-go@v3 with: go-version: stable + cache: true - name: Build run: go build -v ./... diff --git a/const.go b/const.go index f437bdb..0904b87 100644 --- a/const.go +++ b/const.go @@ -1,5 +1,7 @@ package sqlite3 +import "strconv" + const ( _OK = 0 /* Successful result */ _ROW = 100 /* sqlite3_step() has another row ready */ @@ -175,3 +177,20 @@ const ( BLOB Datatype = 4 NULL Datatype = 5 ) + +func (t Datatype) String() string { + const name = "INTEGERFLOATTEXTBLOBNULL" + switch t { + case INTEGER: + return name[0:7] + case FLOAT: + return name[7:12] + case TEXT: + return name[12:16] + case BLOB: + return name[16:20] + case NULL: + return name[20:24] + } + return strconv.FormatUint(uint64(t), 10) +} diff --git a/const_test.go b/const_test.go new file mode 100644 index 0000000..5cf97bf --- /dev/null +++ b/const_test.go @@ -0,0 +1,24 @@ +package sqlite3 + +import "testing" + +func TestDatatype_String(t *testing.T) { + tests := []struct { + data Datatype + want string + }{ + {INTEGER, "INTEGER"}, + {FLOAT, "FLOAT"}, + {TEXT, "TEXT"}, + {BLOB, "BLOB"}, + {NULL, "NULL"}, + {10, "10"}, + } + for _, tt := range tests { + t.Run(tt.want, func(t *testing.T) { + if got := tt.data.String(); got != tt.want { + t.Errorf("got %v, want %v", got, tt.want) + } + }) + } +} diff --git a/go.mod b/go.mod index 4744b87..cf63a4c 100644 --- a/go.mod +++ b/go.mod @@ -5,6 +5,5 @@ go 1.19 require ( github.com/ncruces/julianday v0.1.4 github.com/tetratelabs/wazero v1.0.0-pre.8 + golang.org/x/sync v0.1.0 ) - -require golang.org/x/sync v0.1.0 diff --git a/stmt.go b/stmt.go index 4c9aff1..a7e4daf 100644 --- a/stmt.go +++ b/stmt.go @@ -49,6 +49,16 @@ func (s *Stmt) Err() error { return s.err } +func (s *Stmt) Exec() error { + for s.Step() { + } + err := s.Err() + if rerr := s.Reset(); err == nil { + err = rerr + } + return err +} + func (s *Stmt) BindBool(param int, value bool) error { if value { return s.BindInt64(param, 1) @@ -111,6 +121,15 @@ func (s *Stmt) BindNull(param int) error { return s.c.error(r[0]) } +func (s *Stmt) ColumnType(col int) Datatype { + r, err := s.c.api.columnType.Call(s.c.ctx, + uint64(s.handle), uint64(col)) + if err != nil { + panic(err) + } + return Datatype(r[0]) +} + func (s *Stmt) ColumnBool(col int) bool { if i := s.ColumnInt64(col); i != 0 { return true @@ -132,7 +151,7 @@ func (s *Stmt) ColumnInt64(col int) int64 { } func (s *Stmt) ColumnFloat(col int) float64 { - r, err := s.c.api.columnInteger.Call(s.c.ctx, + r, err := s.c.api.columnFloat.Call(s.c.ctx, uint64(s.handle), uint64(col)) if err != nil { panic(err) @@ -181,7 +200,7 @@ func (s *Stmt) ColumnBlob(col int, buf []byte) []byte { panic(err) } s.err = s.c.error(r[0]) - return nil + return buf[0:0] } r, err = s.c.api.columnBytes.Call(s.c.ctx, diff --git a/stmt_test.go b/stmt_test.go new file mode 100644 index 0000000..cf6dc2c --- /dev/null +++ b/stmt_test.go @@ -0,0 +1,325 @@ +package sqlite3 + +import ( + "math" + "testing" +) + +func TestStmt(t *testing.T) { + db, err := Open(":memory:") + if err != nil { + t.Fatal(err) + } + defer db.Close() + + err = db.Exec(`CREATE TABLE IF NOT EXISTS test (col)`) + if err != nil { + t.Fatal(err) + } + + stmt, _, err := db.Prepare(`INSERT INTO test(col) VALUES(?)`) + if err != nil { + t.Fatal(err) + } + defer stmt.Close() + + err = stmt.BindBool(1, false) + if err != nil { + t.Fatal(err) + } + + err = stmt.Exec() + if err != nil { + t.Fatal(err) + } + + err = stmt.BindBool(1, true) + if err != nil { + t.Fatal(err) + } + + err = stmt.Exec() + if err != nil { + t.Fatal(err) + } + + err = stmt.BindInt(1, 2) + if err != nil { + t.Fatal(err) + } + + err = stmt.Exec() + if err != nil { + t.Fatal(err) + } + + err = stmt.BindFloat(1, math.Pi) + if err != nil { + t.Fatal(err) + } + + err = stmt.Exec() + if err != nil { + t.Fatal(err) + } + + err = stmt.BindNull(1) + if err != nil { + t.Fatal(err) + } + + err = stmt.Exec() + if err != nil { + t.Fatal(err) + } + + err = stmt.BindText(1, "") + if err != nil { + t.Fatal(err) + } + + err = stmt.Exec() + if err != nil { + t.Fatal(err) + } + + err = stmt.BindText(1, "text") + if err != nil { + t.Fatal(err) + } + + err = stmt.Exec() + if err != nil { + t.Fatal(err) + } + + err = stmt.BindBlob(1, []byte("blob")) + if err != nil { + t.Fatal(err) + } + + err = stmt.Exec() + if err != nil { + t.Fatal(err) + } + + err = stmt.BindBlob(1, nil) + if err != nil { + t.Fatal(err) + } + + err = stmt.Exec() + if err != nil { + t.Fatal(err) + } + + err = stmt.Close() + if err != nil { + t.Fatal(err) + } + + // The table should have: 0, 1, 2, π, NULL, "", "text", `blob`, NULL + stmt, _, err = db.Prepare(`SELECT col FROM test`) + if err != nil { + t.Fatal(err) + } + + if stmt.Step() { + if got := stmt.ColumnType(0); got != INTEGER { + t.Errorf("got %v, want INTEGER", got) + } + if got := stmt.ColumnBool(0); got != false { + t.Errorf("got %v, want false", got) + } + if got := stmt.ColumnInt(0); got != 0 { + t.Errorf("got %v, want zero", got) + } + if got := stmt.ColumnFloat(0); got != 0 { + t.Errorf("got %v, want zero", got) + } + if got := stmt.ColumnText(0); got != "0" { + t.Errorf("got %q, want zero", got) + } + if got := stmt.ColumnBlob(0, nil); string(got) != "0" { + t.Errorf("got %q, want zero", got) + } + } + + if stmt.Step() { + if got := stmt.ColumnType(0); got != INTEGER { + t.Errorf("got %v, want INTEGER", got) + } + if got := stmt.ColumnBool(0); got != true { + t.Errorf("got %v, want true", got) + } + if got := stmt.ColumnInt(0); got != 1 { + t.Errorf("got %v, want one", got) + } + if got := stmt.ColumnFloat(0); got != 1 { + t.Errorf("got %v, want one", got) + } + if got := stmt.ColumnText(0); got != "1" { + t.Errorf("got %q, want one", got) + } + if got := stmt.ColumnBlob(0, nil); string(got) != "1" { + t.Errorf("got %q, want one", got) + } + } + + if stmt.Step() { + if got := stmt.ColumnType(0); got != INTEGER { + t.Errorf("got %v, want INTEGER", got) + } + if got := stmt.ColumnBool(0); got != true { + t.Errorf("got %v, want true", got) + } + if got := stmt.ColumnInt(0); got != 2 { + t.Errorf("got %v, want two", got) + } + if got := stmt.ColumnFloat(0); got != 2 { + t.Errorf("got %v, want two", got) + } + if got := stmt.ColumnText(0); got != "2" { + t.Errorf("got %q, want two", got) + } + if got := stmt.ColumnBlob(0, nil); string(got) != "2" { + t.Errorf("got %q, want two", got) + } + } + + if stmt.Step() { + if got := stmt.ColumnType(0); got != FLOAT { + t.Errorf("got %v, want FLOAT", got) + } + if got := stmt.ColumnBool(0); got != true { + t.Errorf("got %v, want true", got) + } + if got := stmt.ColumnInt(0); got != 3 { + t.Errorf("got %v, want three", got) + } + if got := stmt.ColumnFloat(0); got != math.Pi { + t.Errorf("got %v, want π", got) + } + if got := stmt.ColumnText(0); got != "3.14159265358979" { + t.Errorf("got %q, want π", got) + } + if got := stmt.ColumnBlob(0, nil); string(got) != "3.14159265358979" { + t.Errorf("got %q, want π", got) + } + } + + if stmt.Step() { + if got := stmt.ColumnType(0); got != NULL { + t.Errorf("got %v, want NULL", got) + } + if got := stmt.ColumnBool(0); got != false { + t.Errorf("got %v, want false", got) + } + if got := stmt.ColumnInt(0); got != 0 { + t.Errorf("got %v, want zero", got) + } + if got := stmt.ColumnFloat(0); got != 0 { + t.Errorf("got %v, want zero", got) + } + if got := stmt.ColumnText(0); got != "" { + t.Errorf("got %q, want empty", got) + } + if got := stmt.ColumnBlob(0, nil); got != nil { + t.Errorf("got %q, want nil", got) + } + } + + if stmt.Step() { + if got := stmt.ColumnType(0); got != TEXT { + t.Errorf("got %v, want TEXT", got) + } + if got := stmt.ColumnBool(0); got != false { + t.Errorf("got %v, want false", got) + } + if got := stmt.ColumnInt(0); got != 0 { + t.Errorf("got %v, want zero", got) + } + if got := stmt.ColumnFloat(0); got != 0 { + t.Errorf("got %v, want zero", got) + } + if got := stmt.ColumnText(0); got != "" { + t.Errorf("got %q, want empty", got) + } + if got := stmt.ColumnBlob(0, nil); got != nil { + t.Errorf("got %q, want nil", got) + } + } + + if stmt.Step() { + if got := stmt.ColumnType(0); got != TEXT { + t.Errorf("got %v, want TEXT", got) + } + if got := stmt.ColumnBool(0); got != false { + t.Errorf("got %v, want false", got) + } + if got := stmt.ColumnInt(0); got != 0 { + t.Errorf("got %v, want zero", got) + } + if got := stmt.ColumnFloat(0); got != 0 { + t.Errorf("got %v, want zero", got) + } + if got := stmt.ColumnText(0); got != "text" { + t.Errorf(`got %q, want "text"`, got) + } + if got := stmt.ColumnBlob(0, nil); string(got) != "text" { + t.Errorf(`got %q, want "text"`, got) + } + } + + if stmt.Step() { + if got := stmt.ColumnType(0); got != BLOB { + t.Errorf("got %v, want BLOB", got) + } + if got := stmt.ColumnBool(0); got != false { + t.Errorf("got %v, want false", got) + } + if got := stmt.ColumnInt(0); got != 0 { + t.Errorf("got %v, want zero", got) + } + if got := stmt.ColumnFloat(0); got != 0 { + t.Errorf("got %v, want zero", got) + } + if got := stmt.ColumnText(0); got != "blob" { + t.Errorf(`got %q, want "blob"`, got) + } + if got := stmt.ColumnBlob(0, nil); string(got) != "blob" { + t.Errorf(`got %q, want "blob"`, got) + } + } + + if stmt.Step() { + if got := stmt.ColumnType(0); got != NULL { + t.Errorf("got %v, want NULL", got) + } + if got := stmt.ColumnBool(0); got != false { + t.Errorf("got %v, want false", got) + } + if got := stmt.ColumnInt(0); got != 0 { + t.Errorf("got %v, want zero", got) + } + if got := stmt.ColumnFloat(0); got != 0 { + t.Errorf("got %v, want zero", got) + } + if got := stmt.ColumnText(0); got != "" { + t.Errorf("got %q, want empty", got) + } + if got := stmt.ColumnBlob(0, nil); got != nil { + t.Errorf("got %q, want nil", got) + } + } + + err = stmt.Close() + if err != nil { + t.Fatal(err) + } + + err = db.Close() + if err != nil { + t.Fatal(err) + } +} diff --git a/tests/dir_test.go b/tests/dir_test.go index dd9b71d..10ba8d3 100644 --- a/tests/dir_test.go +++ b/tests/dir_test.go @@ -5,6 +5,7 @@ import ( "testing" "github.com/ncruces/go-sqlite3" + _ "github.com/ncruces/go-sqlite3/embed" ) func TestDir(t *testing.T) { diff --git a/tests/parallel_test.go b/tests/parallel_test.go index e3c36f6..df6f077 100644 --- a/tests/parallel_test.go +++ b/tests/parallel_test.go @@ -5,7 +5,6 @@ import ( "path/filepath" "runtime" "testing" - "time" "golang.org/x/sync/errgroup" @@ -41,12 +40,12 @@ func TestParallel(t *testing.T) { err = db.Exec(`CREATE TABLE IF NOT EXISTS users (id INT, name VARCHAR(10))`) if err != nil { - t.Fatal(err) + return err } err = db.Exec(`INSERT INTO users(id, name) VALUES(0, 'go'), (1, 'zig'), (2, 'whatever')`) if err != nil { - t.Fatal(err) + return err } return db.Close() @@ -104,7 +103,6 @@ func TestParallel(t *testing.T) { } else { group.Go(writer) } - time.Sleep(time.Microsecond) } err = group.Wait() if err != nil {