This commit is contained in:
Nuno Cruces
2023-02-08 00:00:53 +00:00
parent 770ab8a073
commit 4597acc49d
8 changed files with 394 additions and 8 deletions

View File

@@ -20,6 +20,7 @@ jobs:
uses: actions/setup-go@v3
with:
go-version: stable
cache: true
- name: Build
run: go build -v ./...

View File

@@ -1,5 +1,7 @@
package sqlite3
import "strconv"
const (
_OK = 0 /* Successful result */
_ROW = 100 /* sqlite3_step() has another row ready */
@@ -175,3 +177,20 @@ const (
BLOB Datatype = 4
NULL Datatype = 5
)
func (t Datatype) String() string {
const name = "INTEGERFLOATTEXTBLOBNULL"
switch t {
case INTEGER:
return name[0:7]
case FLOAT:
return name[7:12]
case TEXT:
return name[12:16]
case BLOB:
return name[16:20]
case NULL:
return name[20:24]
}
return strconv.FormatUint(uint64(t), 10)
}

24
const_test.go Normal file
View File

@@ -0,0 +1,24 @@
package sqlite3
import "testing"
func TestDatatype_String(t *testing.T) {
tests := []struct {
data Datatype
want string
}{
{INTEGER, "INTEGER"},
{FLOAT, "FLOAT"},
{TEXT, "TEXT"},
{BLOB, "BLOB"},
{NULL, "NULL"},
{10, "10"},
}
for _, tt := range tests {
t.Run(tt.want, func(t *testing.T) {
if got := tt.data.String(); got != tt.want {
t.Errorf("got %v, want %v", got, tt.want)
}
})
}
}

3
go.mod
View File

@@ -5,6 +5,5 @@ go 1.19
require (
github.com/ncruces/julianday v0.1.4
github.com/tetratelabs/wazero v1.0.0-pre.8
golang.org/x/sync v0.1.0
)
require golang.org/x/sync v0.1.0

23
stmt.go
View File

@@ -49,6 +49,16 @@ func (s *Stmt) Err() error {
return s.err
}
func (s *Stmt) Exec() error {
for s.Step() {
}
err := s.Err()
if rerr := s.Reset(); err == nil {
err = rerr
}
return err
}
func (s *Stmt) BindBool(param int, value bool) error {
if value {
return s.BindInt64(param, 1)
@@ -111,6 +121,15 @@ func (s *Stmt) BindNull(param int) error {
return s.c.error(r[0])
}
func (s *Stmt) ColumnType(col int) Datatype {
r, err := s.c.api.columnType.Call(s.c.ctx,
uint64(s.handle), uint64(col))
if err != nil {
panic(err)
}
return Datatype(r[0])
}
func (s *Stmt) ColumnBool(col int) bool {
if i := s.ColumnInt64(col); i != 0 {
return true
@@ -132,7 +151,7 @@ func (s *Stmt) ColumnInt64(col int) int64 {
}
func (s *Stmt) ColumnFloat(col int) float64 {
r, err := s.c.api.columnInteger.Call(s.c.ctx,
r, err := s.c.api.columnFloat.Call(s.c.ctx,
uint64(s.handle), uint64(col))
if err != nil {
panic(err)
@@ -181,7 +200,7 @@ func (s *Stmt) ColumnBlob(col int, buf []byte) []byte {
panic(err)
}
s.err = s.c.error(r[0])
return nil
return buf[0:0]
}
r, err = s.c.api.columnBytes.Call(s.c.ctx,

325
stmt_test.go Normal file
View File

@@ -0,0 +1,325 @@
package sqlite3
import (
"math"
"testing"
)
func TestStmt(t *testing.T) {
db, err := Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
err = db.Exec(`CREATE TABLE IF NOT EXISTS test (col)`)
if err != nil {
t.Fatal(err)
}
stmt, _, err := db.Prepare(`INSERT INTO test(col) VALUES(?)`)
if err != nil {
t.Fatal(err)
}
defer stmt.Close()
err = stmt.BindBool(1, false)
if err != nil {
t.Fatal(err)
}
err = stmt.Exec()
if err != nil {
t.Fatal(err)
}
err = stmt.BindBool(1, true)
if err != nil {
t.Fatal(err)
}
err = stmt.Exec()
if err != nil {
t.Fatal(err)
}
err = stmt.BindInt(1, 2)
if err != nil {
t.Fatal(err)
}
err = stmt.Exec()
if err != nil {
t.Fatal(err)
}
err = stmt.BindFloat(1, math.Pi)
if err != nil {
t.Fatal(err)
}
err = stmt.Exec()
if err != nil {
t.Fatal(err)
}
err = stmt.BindNull(1)
if err != nil {
t.Fatal(err)
}
err = stmt.Exec()
if err != nil {
t.Fatal(err)
}
err = stmt.BindText(1, "")
if err != nil {
t.Fatal(err)
}
err = stmt.Exec()
if err != nil {
t.Fatal(err)
}
err = stmt.BindText(1, "text")
if err != nil {
t.Fatal(err)
}
err = stmt.Exec()
if err != nil {
t.Fatal(err)
}
err = stmt.BindBlob(1, []byte("blob"))
if err != nil {
t.Fatal(err)
}
err = stmt.Exec()
if err != nil {
t.Fatal(err)
}
err = stmt.BindBlob(1, nil)
if err != nil {
t.Fatal(err)
}
err = stmt.Exec()
if err != nil {
t.Fatal(err)
}
err = stmt.Close()
if err != nil {
t.Fatal(err)
}
// The table should have: 0, 1, 2, π, NULL, "", "text", `blob`, NULL
stmt, _, err = db.Prepare(`SELECT col FROM test`)
if err != nil {
t.Fatal(err)
}
if stmt.Step() {
if got := stmt.ColumnType(0); got != INTEGER {
t.Errorf("got %v, want INTEGER", got)
}
if got := stmt.ColumnBool(0); got != false {
t.Errorf("got %v, want false", got)
}
if got := stmt.ColumnInt(0); got != 0 {
t.Errorf("got %v, want zero", got)
}
if got := stmt.ColumnFloat(0); got != 0 {
t.Errorf("got %v, want zero", got)
}
if got := stmt.ColumnText(0); got != "0" {
t.Errorf("got %q, want zero", got)
}
if got := stmt.ColumnBlob(0, nil); string(got) != "0" {
t.Errorf("got %q, want zero", got)
}
}
if stmt.Step() {
if got := stmt.ColumnType(0); got != INTEGER {
t.Errorf("got %v, want INTEGER", got)
}
if got := stmt.ColumnBool(0); got != true {
t.Errorf("got %v, want true", got)
}
if got := stmt.ColumnInt(0); got != 1 {
t.Errorf("got %v, want one", got)
}
if got := stmt.ColumnFloat(0); got != 1 {
t.Errorf("got %v, want one", got)
}
if got := stmt.ColumnText(0); got != "1" {
t.Errorf("got %q, want one", got)
}
if got := stmt.ColumnBlob(0, nil); string(got) != "1" {
t.Errorf("got %q, want one", got)
}
}
if stmt.Step() {
if got := stmt.ColumnType(0); got != INTEGER {
t.Errorf("got %v, want INTEGER", got)
}
if got := stmt.ColumnBool(0); got != true {
t.Errorf("got %v, want true", got)
}
if got := stmt.ColumnInt(0); got != 2 {
t.Errorf("got %v, want two", got)
}
if got := stmt.ColumnFloat(0); got != 2 {
t.Errorf("got %v, want two", got)
}
if got := stmt.ColumnText(0); got != "2" {
t.Errorf("got %q, want two", got)
}
if got := stmt.ColumnBlob(0, nil); string(got) != "2" {
t.Errorf("got %q, want two", got)
}
}
if stmt.Step() {
if got := stmt.ColumnType(0); got != FLOAT {
t.Errorf("got %v, want FLOAT", got)
}
if got := stmt.ColumnBool(0); got != true {
t.Errorf("got %v, want true", got)
}
if got := stmt.ColumnInt(0); got != 3 {
t.Errorf("got %v, want three", got)
}
if got := stmt.ColumnFloat(0); got != math.Pi {
t.Errorf("got %v, want π", got)
}
if got := stmt.ColumnText(0); got != "3.14159265358979" {
t.Errorf("got %q, want π", got)
}
if got := stmt.ColumnBlob(0, nil); string(got) != "3.14159265358979" {
t.Errorf("got %q, want π", got)
}
}
if stmt.Step() {
if got := stmt.ColumnType(0); got != NULL {
t.Errorf("got %v, want NULL", got)
}
if got := stmt.ColumnBool(0); got != false {
t.Errorf("got %v, want false", got)
}
if got := stmt.ColumnInt(0); got != 0 {
t.Errorf("got %v, want zero", got)
}
if got := stmt.ColumnFloat(0); got != 0 {
t.Errorf("got %v, want zero", got)
}
if got := stmt.ColumnText(0); got != "" {
t.Errorf("got %q, want empty", got)
}
if got := stmt.ColumnBlob(0, nil); got != nil {
t.Errorf("got %q, want nil", got)
}
}
if stmt.Step() {
if got := stmt.ColumnType(0); got != TEXT {
t.Errorf("got %v, want TEXT", got)
}
if got := stmt.ColumnBool(0); got != false {
t.Errorf("got %v, want false", got)
}
if got := stmt.ColumnInt(0); got != 0 {
t.Errorf("got %v, want zero", got)
}
if got := stmt.ColumnFloat(0); got != 0 {
t.Errorf("got %v, want zero", got)
}
if got := stmt.ColumnText(0); got != "" {
t.Errorf("got %q, want empty", got)
}
if got := stmt.ColumnBlob(0, nil); got != nil {
t.Errorf("got %q, want nil", got)
}
}
if stmt.Step() {
if got := stmt.ColumnType(0); got != TEXT {
t.Errorf("got %v, want TEXT", got)
}
if got := stmt.ColumnBool(0); got != false {
t.Errorf("got %v, want false", got)
}
if got := stmt.ColumnInt(0); got != 0 {
t.Errorf("got %v, want zero", got)
}
if got := stmt.ColumnFloat(0); got != 0 {
t.Errorf("got %v, want zero", got)
}
if got := stmt.ColumnText(0); got != "text" {
t.Errorf(`got %q, want "text"`, got)
}
if got := stmt.ColumnBlob(0, nil); string(got) != "text" {
t.Errorf(`got %q, want "text"`, got)
}
}
if stmt.Step() {
if got := stmt.ColumnType(0); got != BLOB {
t.Errorf("got %v, want BLOB", got)
}
if got := stmt.ColumnBool(0); got != false {
t.Errorf("got %v, want false", got)
}
if got := stmt.ColumnInt(0); got != 0 {
t.Errorf("got %v, want zero", got)
}
if got := stmt.ColumnFloat(0); got != 0 {
t.Errorf("got %v, want zero", got)
}
if got := stmt.ColumnText(0); got != "blob" {
t.Errorf(`got %q, want "blob"`, got)
}
if got := stmt.ColumnBlob(0, nil); string(got) != "blob" {
t.Errorf(`got %q, want "blob"`, got)
}
}
if stmt.Step() {
if got := stmt.ColumnType(0); got != NULL {
t.Errorf("got %v, want NULL", got)
}
if got := stmt.ColumnBool(0); got != false {
t.Errorf("got %v, want false", got)
}
if got := stmt.ColumnInt(0); got != 0 {
t.Errorf("got %v, want zero", got)
}
if got := stmt.ColumnFloat(0); got != 0 {
t.Errorf("got %v, want zero", got)
}
if got := stmt.ColumnText(0); got != "" {
t.Errorf("got %q, want empty", got)
}
if got := stmt.ColumnBlob(0, nil); got != nil {
t.Errorf("got %q, want nil", got)
}
}
err = stmt.Close()
if err != nil {
t.Fatal(err)
}
err = db.Close()
if err != nil {
t.Fatal(err)
}
}

View File

@@ -5,6 +5,7 @@ import (
"testing"
"github.com/ncruces/go-sqlite3"
_ "github.com/ncruces/go-sqlite3/embed"
)
func TestDir(t *testing.T) {

View File

@@ -5,7 +5,6 @@ import (
"path/filepath"
"runtime"
"testing"
"time"
"golang.org/x/sync/errgroup"
@@ -41,12 +40,12 @@ func TestParallel(t *testing.T) {
err = db.Exec(`CREATE TABLE IF NOT EXISTS users (id INT, name VARCHAR(10))`)
if err != nil {
t.Fatal(err)
return err
}
err = db.Exec(`INSERT INTO users(id, name) VALUES(0, 'go'), (1, 'zig'), (2, 'whatever')`)
if err != nil {
t.Fatal(err)
return err
}
return db.Close()
@@ -104,7 +103,6 @@ func TestParallel(t *testing.T) {
} else {
group.Go(writer)
}
time.Sleep(time.Microsecond)
}
err = group.Wait()
if err != nil {