diff --git a/README.md b/README.md index 975f60d..1219ec9 100644 --- a/README.md +++ b/README.md @@ -86,6 +86,7 @@ Performance is tested by running - [x] nested transactions - [x] incremental BLOB I/O - [x] online backup + - [ ] JSON support - [ ] session extension - [ ] custom VFSes - [x] custom VFS API diff --git a/context.go b/context.go index 3e511b5..719ff8e 100644 --- a/context.go +++ b/context.go @@ -1,6 +1,7 @@ package sqlite3 import ( + "encoding/json" "errors" "math" "time" @@ -138,6 +139,31 @@ func (c Context) resultRFC3339Nano(value time.Time) { uint64(c.api.destructor), _UTF8) } +// ResultJSON sets the result of the function to the JSON encoding of value. +// +// https://www.sqlite.org/c3ref/result_blob.html +func (c Context) ResultJSON(value any) { + data, err := json.Marshal(value) + if err != nil { + c.ResultError(err) + } + ptr := c.newBytes(data) + c.call(c.api.resultText, + uint64(c.handle), uint64(ptr), uint64(len(data)), + uint64(c.api.destructor)) +} + +// ResultValue sets the result of the function a copy of [Value]. +// +// https://www.sqlite.org/c3ref/result_blob.html +func (c Context) ResultValue(value Value) { + if value.sqlite != c.sqlite { + c.ResultError(MISUSE) + } + c.call(c.api.resultValue, + uint64(c.handle), uint64(value.handle)) +} + // ResultError sets the result of the function an error. // // https://www.sqlite.org/c3ref/result_blob.html diff --git a/embed/exports.txt b/embed/exports.txt index 94bfa66..4f10313 100644 --- a/embed/exports.txt +++ b/embed/exports.txt @@ -68,6 +68,7 @@ sqlite3_result_double sqlite3_result_text64 sqlite3_result_blob64 sqlite3_result_zeroblob64 +sqlite3_result_value sqlite3_result_error sqlite3_result_error_code sqlite3_result_error_nomem diff --git a/embed/sqlite3.wasm b/embed/sqlite3.wasm index a0e7afa..6db8758 100755 Binary files a/embed/sqlite3.wasm and b/embed/sqlite3.wasm differ diff --git a/gormlite/ddlmod.go b/gormlite/ddlmod.go index e024b7a..e0cd3fd 100644 --- a/gormlite/ddlmod.go +++ b/gormlite/ddlmod.go @@ -235,6 +235,7 @@ func (d *ddl) removeConstraint(name string) bool { return false } +//lint:ignore U1000 ignore unused code. func (d *ddl) hasConstraint(name string) bool { reg := regexp.MustCompile("^CONSTRAINT [\"`]?" + regexp.QuoteMeta(name) + "[\"` ]") diff --git a/internal/util/mem.go b/internal/util/mem.go index 0b35647..11f3735 100644 --- a/internal/util/mem.go +++ b/internal/util/mem.go @@ -14,6 +14,9 @@ func View(mod api.Module, ptr uint32, size uint64) []byte { if size > math.MaxUint32 { panic(RangeErr) } + if size == 0 { + return nil + } buf, ok := mod.Memory().Read(ptr, uint32(size)) if !ok { panic(RangeErr) diff --git a/sqlite.go b/sqlite.go index f442a72..8a9a65d 100644 --- a/sqlite.go +++ b/sqlite.go @@ -170,6 +170,7 @@ func instantiateSQLite() (sqlt *sqlite, err error) { resultText: getFun("sqlite3_result_text64"), resultBlob: getFun("sqlite3_result_blob64"), resultZeroBlob: getFun("sqlite3_result_zeroblob64"), + resultValue: getFun("sqlite3_result_value"), resultError: getFun("sqlite3_result_error"), resultErrorCode: getFun("sqlite3_result_error_code"), resultErrorMem: getFun("sqlite3_result_error_nomem"), @@ -200,13 +201,15 @@ func (sqlt *sqlite) error(rc uint64, handle uint32, sql ...string) error { err.str = util.ReadString(sqlt.mod, uint32(r), _MAX_STRING) } - if r := sqlt.call(sqlt.api.errmsg, uint64(handle)); r != 0 { - err.msg = util.ReadString(sqlt.mod, uint32(r), _MAX_STRING) - } + if handle != 0 { + if r := sqlt.call(sqlt.api.errmsg, uint64(handle)); r != 0 { + err.msg = util.ReadString(sqlt.mod, uint32(r), _MAX_STRING) + } - if sql != nil { - if r := sqlt.call(sqlt.api.erroff, uint64(handle)); r != math.MaxUint32 { - err.sql = sql[0][r:] + if sql != nil { + if r := sqlt.call(sqlt.api.erroff, uint64(handle)); r != math.MaxUint32 { + err.sql = sql[0][r:] + } } } @@ -245,7 +248,7 @@ func (sqlt *sqlite) new(size uint64) uint32 { } func (sqlt *sqlite) newBytes(b []byte) uint32 { - if b == nil { + if (*[0]byte)(b) == nil { return 0 } ptr := sqlt.new(uint64(len(b))) @@ -386,6 +389,7 @@ type sqliteAPI struct { resultText api.Function resultBlob api.Function resultZeroBlob api.Function + resultValue api.Function resultError api.Function resultErrorCode api.Function resultErrorMem api.Function diff --git a/sqlite_test.go b/sqlite_test.go index 8057b03..2f2b0cf 100644 --- a/sqlite_test.go +++ b/sqlite_test.go @@ -132,6 +132,15 @@ func Test_sqlite_newBytes(t *testing.T) { if got := util.View(sqlite.mod, ptr, uint64(len(want))); !bytes.Equal(got, want) { t.Errorf("got %q, want %q", got, want) } + + ptr = sqlite.newBytes(buf[:0]) + if ptr == 0 { + t.Fatal("got nullptr, want a pointer") + } + + if got := util.View(sqlite.mod, ptr, 0); got != nil { + t.Errorf("got %q, want nil", got) + } } func Test_sqlite_newString(t *testing.T) { diff --git a/stmt.go b/stmt.go index b944e5d..23b9b41 100644 --- a/stmt.go +++ b/stmt.go @@ -1,7 +1,9 @@ package sqlite3 import ( + "encoding/json" "math" + "strconv" "time" "github.com/ncruces/go-sqlite3/internal/util" @@ -248,6 +250,23 @@ func (s *Stmt) bindRFC3339Nano(param int, value time.Time) error { return s.c.error(r) } +// BindJSON binds the JSON encoding of value to the prepared statement. +// The leftmost SQL parameter has an index of 1. +// +// https://www.sqlite.org/c3ref/bind_blob.html +func (s *Stmt) BindJSON(param int, value any) error { + data, err := json.Marshal(value) + if err != nil { + return err + } + ptr := s.c.newBytes(data) + r := s.c.call(s.c.api.bindText, + uint64(s.handle), uint64(param), + uint64(ptr), uint64(len(data)), + uint64(s.c.api.destructor)) + return s.c.error(r) +} + // ColumnCount returns the number of columns in a result set. // // https://www.sqlite.org/c3ref/column_count.html @@ -402,6 +421,28 @@ func (s *Stmt) columnRawBytes(col int, ptr uint32) []byte { return util.View(s.c.mod, ptr, r) } +// ColumnJSON parses the JSON-encoded value of the result column +// and stores it in the value pointed to by ptr. +// The leftmost column of the result set has the index 0. +// +// https://www.sqlite.org/c3ref/column_blob.html +func (s *Stmt) ColumnJSON(col int, ptr any) error { + var data []byte + switch s.ColumnType(col) { + case NULL: + data = []byte("null") + case TEXT, BLOB: + data = s.ColumnRawBlob(col) + case INTEGER: + data = strconv.AppendInt(nil, s.ColumnInt64(col), 10) + case FLOAT: + data = strconv.AppendFloat(nil, s.ColumnFloat(col), 'g', -1, 64) + default: + panic(util.AssertErr()) + } + return json.Unmarshal(data, ptr) +} + // Return true if stmt is an empty SQL statement. // This is used as an optimization. // It's OK to always return false here. diff --git a/tests/func_test.go b/tests/func_test.go index b2e6768..efeb129 100644 --- a/tests/func_test.go +++ b/tests/func_test.go @@ -36,8 +36,17 @@ func TestCreateFunction(t *testing.T) { case 7: ctx.ResultTime(arg.Time(sqlite3.TimeFormatUnix), sqlite3.TimeFormatDefault) case 8: - ctx.ResultNull() + var v any + if err := arg.JSON(&v); err != nil { + ctx.ResultError(err) + } else { + ctx.ResultJSON(v) + } case 9: + ctx.ResultValue(arg) + case 10: + ctx.ResultNull() + case 11: ctx.ResultError(sqlite3.FULL) } }) @@ -45,7 +54,7 @@ func TestCreateFunction(t *testing.T) { t.Fatal(err) } - stmt, _, err := db.Prepare(`SELECT test(value) FROM generate_series(0, 9)`) + stmt, _, err := db.Prepare(`SELECT test(value) FROM generate_series(0)`) if err != nil { t.Error(err) } @@ -123,6 +132,27 @@ func TestCreateFunction(t *testing.T) { } } + if stmt.Step() { + if got := stmt.ColumnType(0); got != sqlite3.TEXT { + t.Errorf("got %v, want TEXT", got) + } + var got int + if err := stmt.ColumnJSON(0, &got); err != nil { + t.Error(err) + } else if got != 8 { + t.Errorf("got %v, want 8", got) + } + } + + if stmt.Step() { + if got := stmt.ColumnType(0); got != sqlite3.INTEGER { + t.Errorf("got %v, want INTEGER", got) + } + if got := stmt.ColumnInt64(0); got != 9 { + t.Errorf("got %v, want 9", got) + } + } + if stmt.Step() { if got := stmt.ColumnType(0); got != sqlite3.NULL { t.Errorf("got %v, want NULL", got) diff --git a/tests/stmt_test.go b/tests/stmt_test.go index 3c2cb65..766472c 100644 --- a/tests/stmt_test.go +++ b/tests/stmt_test.go @@ -1,6 +1,7 @@ package tests import ( + "encoding/json" "math" "testing" "time" @@ -81,6 +82,13 @@ func TestStmt(t *testing.T) { t.Fatal(err) } + if err := stmt.BindBlob(1, []byte("")); err != nil { + t.Fatal(err) + } + if err := stmt.Exec(); err != nil { + t.Fatal(err) + } + if err := stmt.BindBlob(1, []byte("blob")); err != nil { t.Fatal(err) } @@ -102,6 +110,13 @@ func TestStmt(t *testing.T) { t.Fatal(err) } + if err := stmt.BindJSON(1, true); err != nil { + t.Fatal(err) + } + if err := stmt.Exec(); err != nil { + t.Fatal(err) + } + if err := stmt.ClearBindings(); err != nil { t.Fatal(err) } @@ -114,7 +129,7 @@ func TestStmt(t *testing.T) { t.Fatal(err) } - // The table should have: 0, 1, 2, π, NULL, "", "text", "blob", NULL, "\0\0\0\0", NULL + // The table should have: 0, 1, 2, π, NULL, "", "text", "", "blob", NULL, "\0\0\0\0", "true", NULL stmt, _, err = db.Prepare(`SELECT col FROM test`) if err != nil { t.Fatal(err) @@ -140,6 +155,12 @@ func TestStmt(t *testing.T) { if got := stmt.ColumnBlob(0, nil); string(got) != "0" { t.Errorf("got %q, want zero", got) } + var got int + if err := stmt.ColumnJSON(0, &got); err != nil { + t.Error(err) + } else if got != 0 { + t.Errorf("got %v, want zero", got) + } } if stmt.Step() { @@ -161,6 +182,12 @@ func TestStmt(t *testing.T) { if got := stmt.ColumnBlob(0, nil); string(got) != "1" { t.Errorf("got %q, want one", got) } + var got float32 + if err := stmt.ColumnJSON(0, &got); err != nil { + t.Error(err) + } else if got != 1 { + t.Errorf("got %v, want one", got) + } } if stmt.Step() { @@ -182,6 +209,12 @@ func TestStmt(t *testing.T) { if got := stmt.ColumnBlob(0, nil); string(got) != "2" { t.Errorf("got %q, want two", got) } + var got json.Number + if err := stmt.ColumnJSON(0, &got); err != nil { + t.Error(err) + } else if got != "2" { + t.Errorf("got %v, want two", got) + } } if stmt.Step() { @@ -203,6 +236,12 @@ func TestStmt(t *testing.T) { if got := stmt.ColumnBlob(0, nil); string(got) != "3.14159265358979" { t.Errorf("got %q, want π", got) } + var got float64 + if err := stmt.ColumnJSON(0, &got); err != nil { + t.Error(err) + } else if got != math.Pi { + t.Errorf("got %v, want π", got) + } } if stmt.Step() { @@ -224,6 +263,12 @@ func TestStmt(t *testing.T) { if got := stmt.ColumnBlob(0, nil); got != nil { t.Errorf("got %q, want nil", got) } + var got any = 1 + if err := stmt.ColumnJSON(0, &got); err != nil { + t.Error(err) + } else if got != nil { + t.Errorf("got %v, want NULL", got) + } } if stmt.Step() { @@ -245,6 +290,10 @@ func TestStmt(t *testing.T) { if got := stmt.ColumnBlob(0, nil); got != nil { t.Errorf("got %q, want nil", got) } + var got any + if err := stmt.ColumnJSON(0, &got); err == nil { + t.Errorf("got %v, want error", got) + } } if stmt.Step() { @@ -266,6 +315,35 @@ func TestStmt(t *testing.T) { if got := stmt.ColumnBlob(0, nil); string(got) != "text" { t.Errorf(`got %q, want "text"`, got) } + var got any + if err := stmt.ColumnJSON(0, &got); err == nil { + t.Errorf("got %v, want error", got) + } + } + + if stmt.Step() { + if got := stmt.ColumnType(0); got != sqlite3.BLOB { + t.Errorf("got %v, want BLOB", got) + } + if got := stmt.ColumnBool(0); got != false { + t.Errorf("got %v, want false", got) + } + if got := stmt.ColumnInt(0); got != 0 { + t.Errorf("got %v, want zero", got) + } + if got := stmt.ColumnFloat(0); got != 0 { + t.Errorf("got %v, want zero", got) + } + if got := stmt.ColumnText(0); got != "" { + t.Errorf("got %q, want empty", got) + } + if got := stmt.ColumnBlob(0, nil); got != nil { + t.Errorf("got %q, want nil", got) + } + var got any + if err := stmt.ColumnJSON(0, &got); err == nil { + t.Errorf("got %v, want error", got) + } } if stmt.Step() { @@ -287,6 +365,10 @@ func TestStmt(t *testing.T) { if got := stmt.ColumnBlob(0, nil); string(got) != "blob" { t.Errorf(`got %q, want "blob"`, got) } + var got any + if err := stmt.ColumnJSON(0, &got); err == nil { + t.Errorf("got %v, want error", got) + } } if stmt.Step() { @@ -308,6 +390,12 @@ func TestStmt(t *testing.T) { if got := stmt.ColumnBlob(0, nil); got != nil { t.Errorf("got %q, want nil", got) } + var got any = 1 + if err := stmt.ColumnJSON(0, &got); err != nil { + t.Error(err) + } else if got != nil { + t.Errorf("got %v, want NULL", got) + } } if stmt.Step() { @@ -329,6 +417,37 @@ func TestStmt(t *testing.T) { if got := stmt.ColumnBlob(0, nil); string(got) != "\x00\x00\x00\x00" { t.Errorf(`got %q, want "\x00\x00\x00\x00"`, got) } + var got any + if err := stmt.ColumnJSON(0, &got); err == nil { + t.Errorf("got %v, want error", got) + } + } + + if stmt.Step() { + if got := stmt.ColumnType(0); got != sqlite3.TEXT { + t.Errorf("got %v, want TEXT", got) + } + if got := stmt.ColumnBool(0); got != false { + t.Errorf("got %v, want false", got) + } + if got := stmt.ColumnInt(0); got != 0 { + t.Errorf("got %v, want zero", got) + } + if got := stmt.ColumnFloat(0); got != 0 { + t.Errorf("got %v, want zero", got) + } + if got := stmt.ColumnText(0); got != "true" { + t.Errorf("got %q, want true", got) + } + if got := stmt.ColumnBlob(0, nil); string(got) != "true" { + t.Errorf("got %q, want true", got) + } + var got any = 1 + if err := stmt.ColumnJSON(0, &got); err != nil { + t.Error(err) + } else if got != true { + t.Errorf("got %v, want true", got) + } } if stmt.Step() { @@ -350,6 +469,12 @@ func TestStmt(t *testing.T) { if got := stmt.ColumnBlob(0, nil); got != nil { t.Errorf("got %q, want nil", got) } + var got any = 1 + if err := stmt.ColumnJSON(0, &got); err != nil { + t.Error(err) + } else if got != nil { + t.Errorf("got %v, want NULL", got) + } } if err := stmt.Close(); err != nil { diff --git a/value.go b/value.go index aed1056..6db8d15 100644 --- a/value.go +++ b/value.go @@ -1,7 +1,9 @@ package sqlite3 import ( + "encoding/json" "math" + "strconv" "time" "github.com/ncruces/go-sqlite3/internal/util" @@ -123,3 +125,22 @@ func (v Value) rawBytes(ptr uint32) []byte { r := v.call(v.api.valueBytes, uint64(v.handle)) return util.View(v.mod, ptr, r) } + +// JSON parses a JSON-encoded value +// and stores the result in the value pointed to by ptr. +func (v Value) JSON(ptr any) error { + var data []byte + switch v.Type() { + case NULL: + data = []byte("null") + case TEXT, BLOB: + data = v.RawBlob() + case INTEGER: + data = strconv.AppendInt(nil, v.Int64(), 10) + case FLOAT: + data = strconv.AppendFloat(nil, v.Float(), 'g', -1, 64) + default: + panic(util.AssertErr()) + } + return json.Unmarshal(data, ptr) +}