Towards JSON.

This commit is contained in:
Nuno Cruces
2023-10-13 17:06:05 +01:00
parent f6d77f3cf4
commit eec45ea684
12 changed files with 272 additions and 10 deletions

View File

@@ -86,6 +86,7 @@ Performance is tested by running
- [x] nested transactions - [x] nested transactions
- [x] incremental BLOB I/O - [x] incremental BLOB I/O
- [x] online backup - [x] online backup
- [ ] JSON support
- [ ] session extension - [ ] session extension
- [ ] custom VFSes - [ ] custom VFSes
- [x] custom VFS API - [x] custom VFS API

View File

@@ -1,6 +1,7 @@
package sqlite3 package sqlite3
import ( import (
"encoding/json"
"errors" "errors"
"math" "math"
"time" "time"
@@ -138,6 +139,31 @@ func (c Context) resultRFC3339Nano(value time.Time) {
uint64(c.api.destructor), _UTF8) uint64(c.api.destructor), _UTF8)
} }
// ResultJSON sets the result of the function to the JSON encoding of value.
//
// https://www.sqlite.org/c3ref/result_blob.html
func (c Context) ResultJSON(value any) {
data, err := json.Marshal(value)
if err != nil {
c.ResultError(err)
}
ptr := c.newBytes(data)
c.call(c.api.resultText,
uint64(c.handle), uint64(ptr), uint64(len(data)),
uint64(c.api.destructor))
}
// ResultValue sets the result of the function a copy of [Value].
//
// https://www.sqlite.org/c3ref/result_blob.html
func (c Context) ResultValue(value Value) {
if value.sqlite != c.sqlite {
c.ResultError(MISUSE)
}
c.call(c.api.resultValue,
uint64(c.handle), uint64(value.handle))
}
// ResultError sets the result of the function an error. // ResultError sets the result of the function an error.
// //
// https://www.sqlite.org/c3ref/result_blob.html // https://www.sqlite.org/c3ref/result_blob.html

View File

@@ -68,6 +68,7 @@ sqlite3_result_double
sqlite3_result_text64 sqlite3_result_text64
sqlite3_result_blob64 sqlite3_result_blob64
sqlite3_result_zeroblob64 sqlite3_result_zeroblob64
sqlite3_result_value
sqlite3_result_error sqlite3_result_error
sqlite3_result_error_code sqlite3_result_error_code
sqlite3_result_error_nomem sqlite3_result_error_nomem

Binary file not shown.

View File

@@ -235,6 +235,7 @@ func (d *ddl) removeConstraint(name string) bool {
return false return false
} }
//lint:ignore U1000 ignore unused code.
func (d *ddl) hasConstraint(name string) bool { func (d *ddl) hasConstraint(name string) bool {
reg := regexp.MustCompile("^CONSTRAINT [\"`]?" + regexp.QuoteMeta(name) + "[\"` ]") reg := regexp.MustCompile("^CONSTRAINT [\"`]?" + regexp.QuoteMeta(name) + "[\"` ]")

View File

@@ -14,6 +14,9 @@ func View(mod api.Module, ptr uint32, size uint64) []byte {
if size > math.MaxUint32 { if size > math.MaxUint32 {
panic(RangeErr) panic(RangeErr)
} }
if size == 0 {
return nil
}
buf, ok := mod.Memory().Read(ptr, uint32(size)) buf, ok := mod.Memory().Read(ptr, uint32(size))
if !ok { if !ok {
panic(RangeErr) panic(RangeErr)

View File

@@ -170,6 +170,7 @@ func instantiateSQLite() (sqlt *sqlite, err error) {
resultText: getFun("sqlite3_result_text64"), resultText: getFun("sqlite3_result_text64"),
resultBlob: getFun("sqlite3_result_blob64"), resultBlob: getFun("sqlite3_result_blob64"),
resultZeroBlob: getFun("sqlite3_result_zeroblob64"), resultZeroBlob: getFun("sqlite3_result_zeroblob64"),
resultValue: getFun("sqlite3_result_value"),
resultError: getFun("sqlite3_result_error"), resultError: getFun("sqlite3_result_error"),
resultErrorCode: getFun("sqlite3_result_error_code"), resultErrorCode: getFun("sqlite3_result_error_code"),
resultErrorMem: getFun("sqlite3_result_error_nomem"), resultErrorMem: getFun("sqlite3_result_error_nomem"),
@@ -200,13 +201,15 @@ func (sqlt *sqlite) error(rc uint64, handle uint32, sql ...string) error {
err.str = util.ReadString(sqlt.mod, uint32(r), _MAX_STRING) err.str = util.ReadString(sqlt.mod, uint32(r), _MAX_STRING)
} }
if r := sqlt.call(sqlt.api.errmsg, uint64(handle)); r != 0 { if handle != 0 {
err.msg = util.ReadString(sqlt.mod, uint32(r), _MAX_STRING) if r := sqlt.call(sqlt.api.errmsg, uint64(handle)); r != 0 {
} err.msg = util.ReadString(sqlt.mod, uint32(r), _MAX_STRING)
}
if sql != nil { if sql != nil {
if r := sqlt.call(sqlt.api.erroff, uint64(handle)); r != math.MaxUint32 { if r := sqlt.call(sqlt.api.erroff, uint64(handle)); r != math.MaxUint32 {
err.sql = sql[0][r:] err.sql = sql[0][r:]
}
} }
} }
@@ -245,7 +248,7 @@ func (sqlt *sqlite) new(size uint64) uint32 {
} }
func (sqlt *sqlite) newBytes(b []byte) uint32 { func (sqlt *sqlite) newBytes(b []byte) uint32 {
if b == nil { if (*[0]byte)(b) == nil {
return 0 return 0
} }
ptr := sqlt.new(uint64(len(b))) ptr := sqlt.new(uint64(len(b)))
@@ -386,6 +389,7 @@ type sqliteAPI struct {
resultText api.Function resultText api.Function
resultBlob api.Function resultBlob api.Function
resultZeroBlob api.Function resultZeroBlob api.Function
resultValue api.Function
resultError api.Function resultError api.Function
resultErrorCode api.Function resultErrorCode api.Function
resultErrorMem api.Function resultErrorMem api.Function

View File

@@ -132,6 +132,15 @@ func Test_sqlite_newBytes(t *testing.T) {
if got := util.View(sqlite.mod, ptr, uint64(len(want))); !bytes.Equal(got, want) { if got := util.View(sqlite.mod, ptr, uint64(len(want))); !bytes.Equal(got, want) {
t.Errorf("got %q, want %q", got, want) t.Errorf("got %q, want %q", got, want)
} }
ptr = sqlite.newBytes(buf[:0])
if ptr == 0 {
t.Fatal("got nullptr, want a pointer")
}
if got := util.View(sqlite.mod, ptr, 0); got != nil {
t.Errorf("got %q, want nil", got)
}
} }
func Test_sqlite_newString(t *testing.T) { func Test_sqlite_newString(t *testing.T) {

41
stmt.go
View File

@@ -1,7 +1,9 @@
package sqlite3 package sqlite3
import ( import (
"encoding/json"
"math" "math"
"strconv"
"time" "time"
"github.com/ncruces/go-sqlite3/internal/util" "github.com/ncruces/go-sqlite3/internal/util"
@@ -248,6 +250,23 @@ func (s *Stmt) bindRFC3339Nano(param int, value time.Time) error {
return s.c.error(r) return s.c.error(r)
} }
// BindJSON binds the JSON encoding of value to the prepared statement.
// The leftmost SQL parameter has an index of 1.
//
// https://www.sqlite.org/c3ref/bind_blob.html
func (s *Stmt) BindJSON(param int, value any) error {
data, err := json.Marshal(value)
if err != nil {
return err
}
ptr := s.c.newBytes(data)
r := s.c.call(s.c.api.bindText,
uint64(s.handle), uint64(param),
uint64(ptr), uint64(len(data)),
uint64(s.c.api.destructor))
return s.c.error(r)
}
// ColumnCount returns the number of columns in a result set. // ColumnCount returns the number of columns in a result set.
// //
// https://www.sqlite.org/c3ref/column_count.html // https://www.sqlite.org/c3ref/column_count.html
@@ -402,6 +421,28 @@ func (s *Stmt) columnRawBytes(col int, ptr uint32) []byte {
return util.View(s.c.mod, ptr, r) return util.View(s.c.mod, ptr, r)
} }
// ColumnJSON parses the JSON-encoded value of the result column
// and stores it in the value pointed to by ptr.
// The leftmost column of the result set has the index 0.
//
// https://www.sqlite.org/c3ref/column_blob.html
func (s *Stmt) ColumnJSON(col int, ptr any) error {
var data []byte
switch s.ColumnType(col) {
case NULL:
data = []byte("null")
case TEXT, BLOB:
data = s.ColumnRawBlob(col)
case INTEGER:
data = strconv.AppendInt(nil, s.ColumnInt64(col), 10)
case FLOAT:
data = strconv.AppendFloat(nil, s.ColumnFloat(col), 'g', -1, 64)
default:
panic(util.AssertErr())
}
return json.Unmarshal(data, ptr)
}
// Return true if stmt is an empty SQL statement. // Return true if stmt is an empty SQL statement.
// This is used as an optimization. // This is used as an optimization.
// It's OK to always return false here. // It's OK to always return false here.

View File

@@ -36,8 +36,17 @@ func TestCreateFunction(t *testing.T) {
case 7: case 7:
ctx.ResultTime(arg.Time(sqlite3.TimeFormatUnix), sqlite3.TimeFormatDefault) ctx.ResultTime(arg.Time(sqlite3.TimeFormatUnix), sqlite3.TimeFormatDefault)
case 8: case 8:
ctx.ResultNull() var v any
if err := arg.JSON(&v); err != nil {
ctx.ResultError(err)
} else {
ctx.ResultJSON(v)
}
case 9: case 9:
ctx.ResultValue(arg)
case 10:
ctx.ResultNull()
case 11:
ctx.ResultError(sqlite3.FULL) ctx.ResultError(sqlite3.FULL)
} }
}) })
@@ -45,7 +54,7 @@ func TestCreateFunction(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
stmt, _, err := db.Prepare(`SELECT test(value) FROM generate_series(0, 9)`) stmt, _, err := db.Prepare(`SELECT test(value) FROM generate_series(0)`)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
@@ -123,6 +132,27 @@ func TestCreateFunction(t *testing.T) {
} }
} }
if stmt.Step() {
if got := stmt.ColumnType(0); got != sqlite3.TEXT {
t.Errorf("got %v, want TEXT", got)
}
var got int
if err := stmt.ColumnJSON(0, &got); err != nil {
t.Error(err)
} else if got != 8 {
t.Errorf("got %v, want 8", got)
}
}
if stmt.Step() {
if got := stmt.ColumnType(0); got != sqlite3.INTEGER {
t.Errorf("got %v, want INTEGER", got)
}
if got := stmt.ColumnInt64(0); got != 9 {
t.Errorf("got %v, want 9", got)
}
}
if stmt.Step() { if stmt.Step() {
if got := stmt.ColumnType(0); got != sqlite3.NULL { if got := stmt.ColumnType(0); got != sqlite3.NULL {
t.Errorf("got %v, want NULL", got) t.Errorf("got %v, want NULL", got)

View File

@@ -1,6 +1,7 @@
package tests package tests
import ( import (
"encoding/json"
"math" "math"
"testing" "testing"
"time" "time"
@@ -81,6 +82,13 @@ func TestStmt(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
if err := stmt.BindBlob(1, []byte("")); err != nil {
t.Fatal(err)
}
if err := stmt.Exec(); err != nil {
t.Fatal(err)
}
if err := stmt.BindBlob(1, []byte("blob")); err != nil { if err := stmt.BindBlob(1, []byte("blob")); err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -102,6 +110,13 @@ func TestStmt(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
if err := stmt.BindJSON(1, true); err != nil {
t.Fatal(err)
}
if err := stmt.Exec(); err != nil {
t.Fatal(err)
}
if err := stmt.ClearBindings(); err != nil { if err := stmt.ClearBindings(); err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -114,7 +129,7 @@ func TestStmt(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
// The table should have: 0, 1, 2, π, NULL, "", "text", "blob", NULL, "\0\0\0\0", NULL // The table should have: 0, 1, 2, π, NULL, "", "text", "", "blob", NULL, "\0\0\0\0", "true", NULL
stmt, _, err = db.Prepare(`SELECT col FROM test`) stmt, _, err = db.Prepare(`SELECT col FROM test`)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@@ -140,6 +155,12 @@ func TestStmt(t *testing.T) {
if got := stmt.ColumnBlob(0, nil); string(got) != "0" { if got := stmt.ColumnBlob(0, nil); string(got) != "0" {
t.Errorf("got %q, want zero", got) t.Errorf("got %q, want zero", got)
} }
var got int
if err := stmt.ColumnJSON(0, &got); err != nil {
t.Error(err)
} else if got != 0 {
t.Errorf("got %v, want zero", got)
}
} }
if stmt.Step() { if stmt.Step() {
@@ -161,6 +182,12 @@ func TestStmt(t *testing.T) {
if got := stmt.ColumnBlob(0, nil); string(got) != "1" { if got := stmt.ColumnBlob(0, nil); string(got) != "1" {
t.Errorf("got %q, want one", got) t.Errorf("got %q, want one", got)
} }
var got float32
if err := stmt.ColumnJSON(0, &got); err != nil {
t.Error(err)
} else if got != 1 {
t.Errorf("got %v, want one", got)
}
} }
if stmt.Step() { if stmt.Step() {
@@ -182,6 +209,12 @@ func TestStmt(t *testing.T) {
if got := stmt.ColumnBlob(0, nil); string(got) != "2" { if got := stmt.ColumnBlob(0, nil); string(got) != "2" {
t.Errorf("got %q, want two", got) t.Errorf("got %q, want two", got)
} }
var got json.Number
if err := stmt.ColumnJSON(0, &got); err != nil {
t.Error(err)
} else if got != "2" {
t.Errorf("got %v, want two", got)
}
} }
if stmt.Step() { if stmt.Step() {
@@ -203,6 +236,12 @@ func TestStmt(t *testing.T) {
if got := stmt.ColumnBlob(0, nil); string(got) != "3.14159265358979" { if got := stmt.ColumnBlob(0, nil); string(got) != "3.14159265358979" {
t.Errorf("got %q, want π", got) t.Errorf("got %q, want π", got)
} }
var got float64
if err := stmt.ColumnJSON(0, &got); err != nil {
t.Error(err)
} else if got != math.Pi {
t.Errorf("got %v, want π", got)
}
} }
if stmt.Step() { if stmt.Step() {
@@ -224,6 +263,12 @@ func TestStmt(t *testing.T) {
if got := stmt.ColumnBlob(0, nil); got != nil { if got := stmt.ColumnBlob(0, nil); got != nil {
t.Errorf("got %q, want nil", got) t.Errorf("got %q, want nil", got)
} }
var got any = 1
if err := stmt.ColumnJSON(0, &got); err != nil {
t.Error(err)
} else if got != nil {
t.Errorf("got %v, want NULL", got)
}
} }
if stmt.Step() { if stmt.Step() {
@@ -245,6 +290,10 @@ func TestStmt(t *testing.T) {
if got := stmt.ColumnBlob(0, nil); got != nil { if got := stmt.ColumnBlob(0, nil); got != nil {
t.Errorf("got %q, want nil", got) t.Errorf("got %q, want nil", got)
} }
var got any
if err := stmt.ColumnJSON(0, &got); err == nil {
t.Errorf("got %v, want error", got)
}
} }
if stmt.Step() { if stmt.Step() {
@@ -266,6 +315,35 @@ func TestStmt(t *testing.T) {
if got := stmt.ColumnBlob(0, nil); string(got) != "text" { if got := stmt.ColumnBlob(0, nil); string(got) != "text" {
t.Errorf(`got %q, want "text"`, got) t.Errorf(`got %q, want "text"`, got)
} }
var got any
if err := stmt.ColumnJSON(0, &got); err == nil {
t.Errorf("got %v, want error", got)
}
}
if stmt.Step() {
if got := stmt.ColumnType(0); got != sqlite3.BLOB {
t.Errorf("got %v, want BLOB", got)
}
if got := stmt.ColumnBool(0); got != false {
t.Errorf("got %v, want false", got)
}
if got := stmt.ColumnInt(0); got != 0 {
t.Errorf("got %v, want zero", got)
}
if got := stmt.ColumnFloat(0); got != 0 {
t.Errorf("got %v, want zero", got)
}
if got := stmt.ColumnText(0); got != "" {
t.Errorf("got %q, want empty", got)
}
if got := stmt.ColumnBlob(0, nil); got != nil {
t.Errorf("got %q, want nil", got)
}
var got any
if err := stmt.ColumnJSON(0, &got); err == nil {
t.Errorf("got %v, want error", got)
}
} }
if stmt.Step() { if stmt.Step() {
@@ -287,6 +365,10 @@ func TestStmt(t *testing.T) {
if got := stmt.ColumnBlob(0, nil); string(got) != "blob" { if got := stmt.ColumnBlob(0, nil); string(got) != "blob" {
t.Errorf(`got %q, want "blob"`, got) t.Errorf(`got %q, want "blob"`, got)
} }
var got any
if err := stmt.ColumnJSON(0, &got); err == nil {
t.Errorf("got %v, want error", got)
}
} }
if stmt.Step() { if stmt.Step() {
@@ -308,6 +390,12 @@ func TestStmt(t *testing.T) {
if got := stmt.ColumnBlob(0, nil); got != nil { if got := stmt.ColumnBlob(0, nil); got != nil {
t.Errorf("got %q, want nil", got) t.Errorf("got %q, want nil", got)
} }
var got any = 1
if err := stmt.ColumnJSON(0, &got); err != nil {
t.Error(err)
} else if got != nil {
t.Errorf("got %v, want NULL", got)
}
} }
if stmt.Step() { if stmt.Step() {
@@ -329,6 +417,37 @@ func TestStmt(t *testing.T) {
if got := stmt.ColumnBlob(0, nil); string(got) != "\x00\x00\x00\x00" { if got := stmt.ColumnBlob(0, nil); string(got) != "\x00\x00\x00\x00" {
t.Errorf(`got %q, want "\x00\x00\x00\x00"`, got) t.Errorf(`got %q, want "\x00\x00\x00\x00"`, got)
} }
var got any
if err := stmt.ColumnJSON(0, &got); err == nil {
t.Errorf("got %v, want error", got)
}
}
if stmt.Step() {
if got := stmt.ColumnType(0); got != sqlite3.TEXT {
t.Errorf("got %v, want TEXT", got)
}
if got := stmt.ColumnBool(0); got != false {
t.Errorf("got %v, want false", got)
}
if got := stmt.ColumnInt(0); got != 0 {
t.Errorf("got %v, want zero", got)
}
if got := stmt.ColumnFloat(0); got != 0 {
t.Errorf("got %v, want zero", got)
}
if got := stmt.ColumnText(0); got != "true" {
t.Errorf("got %q, want true", got)
}
if got := stmt.ColumnBlob(0, nil); string(got) != "true" {
t.Errorf("got %q, want true", got)
}
var got any = 1
if err := stmt.ColumnJSON(0, &got); err != nil {
t.Error(err)
} else if got != true {
t.Errorf("got %v, want true", got)
}
} }
if stmt.Step() { if stmt.Step() {
@@ -350,6 +469,12 @@ func TestStmt(t *testing.T) {
if got := stmt.ColumnBlob(0, nil); got != nil { if got := stmt.ColumnBlob(0, nil); got != nil {
t.Errorf("got %q, want nil", got) t.Errorf("got %q, want nil", got)
} }
var got any = 1
if err := stmt.ColumnJSON(0, &got); err != nil {
t.Error(err)
} else if got != nil {
t.Errorf("got %v, want NULL", got)
}
} }
if err := stmt.Close(); err != nil { if err := stmt.Close(); err != nil {

View File

@@ -1,7 +1,9 @@
package sqlite3 package sqlite3
import ( import (
"encoding/json"
"math" "math"
"strconv"
"time" "time"
"github.com/ncruces/go-sqlite3/internal/util" "github.com/ncruces/go-sqlite3/internal/util"
@@ -123,3 +125,22 @@ func (v Value) rawBytes(ptr uint32) []byte {
r := v.call(v.api.valueBytes, uint64(v.handle)) r := v.call(v.api.valueBytes, uint64(v.handle))
return util.View(v.mod, ptr, r) return util.View(v.mod, ptr, r)
} }
// JSON parses a JSON-encoded value
// and stores the result in the value pointed to by ptr.
func (v Value) JSON(ptr any) error {
var data []byte
switch v.Type() {
case NULL:
data = []byte("null")
case TEXT, BLOB:
data = v.RawBlob()
case INTEGER:
data = strconv.AppendInt(nil, v.Int64(), 10)
case FLOAT:
data = strconv.AppendFloat(nil, v.Float(), 'g', -1, 64)
default:
panic(util.AssertErr())
}
return json.Unmarshal(data, ptr)
}