Towards JSON.

This commit is contained in:
Nuno Cruces
2023-10-13 17:06:05 +01:00
parent f6d77f3cf4
commit eec45ea684
12 changed files with 272 additions and 10 deletions

View File

@@ -86,6 +86,7 @@ Performance is tested by running
- [x] nested transactions
- [x] incremental BLOB I/O
- [x] online backup
- [ ] JSON support
- [ ] session extension
- [ ] custom VFSes
- [x] custom VFS API

View File

@@ -1,6 +1,7 @@
package sqlite3
import (
"encoding/json"
"errors"
"math"
"time"
@@ -138,6 +139,31 @@ func (c Context) resultRFC3339Nano(value time.Time) {
uint64(c.api.destructor), _UTF8)
}
// ResultJSON sets the result of the function to the JSON encoding of value.
//
// https://www.sqlite.org/c3ref/result_blob.html
func (c Context) ResultJSON(value any) {
data, err := json.Marshal(value)
if err != nil {
c.ResultError(err)
}
ptr := c.newBytes(data)
c.call(c.api.resultText,
uint64(c.handle), uint64(ptr), uint64(len(data)),
uint64(c.api.destructor))
}
// ResultValue sets the result of the function a copy of [Value].
//
// https://www.sqlite.org/c3ref/result_blob.html
func (c Context) ResultValue(value Value) {
if value.sqlite != c.sqlite {
c.ResultError(MISUSE)
}
c.call(c.api.resultValue,
uint64(c.handle), uint64(value.handle))
}
// ResultError sets the result of the function an error.
//
// https://www.sqlite.org/c3ref/result_blob.html

View File

@@ -68,6 +68,7 @@ sqlite3_result_double
sqlite3_result_text64
sqlite3_result_blob64
sqlite3_result_zeroblob64
sqlite3_result_value
sqlite3_result_error
sqlite3_result_error_code
sqlite3_result_error_nomem

Binary file not shown.

View File

@@ -235,6 +235,7 @@ func (d *ddl) removeConstraint(name string) bool {
return false
}
//lint:ignore U1000 ignore unused code.
func (d *ddl) hasConstraint(name string) bool {
reg := regexp.MustCompile("^CONSTRAINT [\"`]?" + regexp.QuoteMeta(name) + "[\"` ]")

View File

@@ -14,6 +14,9 @@ func View(mod api.Module, ptr uint32, size uint64) []byte {
if size > math.MaxUint32 {
panic(RangeErr)
}
if size == 0 {
return nil
}
buf, ok := mod.Memory().Read(ptr, uint32(size))
if !ok {
panic(RangeErr)

View File

@@ -170,6 +170,7 @@ func instantiateSQLite() (sqlt *sqlite, err error) {
resultText: getFun("sqlite3_result_text64"),
resultBlob: getFun("sqlite3_result_blob64"),
resultZeroBlob: getFun("sqlite3_result_zeroblob64"),
resultValue: getFun("sqlite3_result_value"),
resultError: getFun("sqlite3_result_error"),
resultErrorCode: getFun("sqlite3_result_error_code"),
resultErrorMem: getFun("sqlite3_result_error_nomem"),
@@ -200,13 +201,15 @@ func (sqlt *sqlite) error(rc uint64, handle uint32, sql ...string) error {
err.str = util.ReadString(sqlt.mod, uint32(r), _MAX_STRING)
}
if r := sqlt.call(sqlt.api.errmsg, uint64(handle)); r != 0 {
err.msg = util.ReadString(sqlt.mod, uint32(r), _MAX_STRING)
}
if handle != 0 {
if r := sqlt.call(sqlt.api.errmsg, uint64(handle)); r != 0 {
err.msg = util.ReadString(sqlt.mod, uint32(r), _MAX_STRING)
}
if sql != nil {
if r := sqlt.call(sqlt.api.erroff, uint64(handle)); r != math.MaxUint32 {
err.sql = sql[0][r:]
if sql != nil {
if r := sqlt.call(sqlt.api.erroff, uint64(handle)); r != math.MaxUint32 {
err.sql = sql[0][r:]
}
}
}
@@ -245,7 +248,7 @@ func (sqlt *sqlite) new(size uint64) uint32 {
}
func (sqlt *sqlite) newBytes(b []byte) uint32 {
if b == nil {
if (*[0]byte)(b) == nil {
return 0
}
ptr := sqlt.new(uint64(len(b)))
@@ -386,6 +389,7 @@ type sqliteAPI struct {
resultText api.Function
resultBlob api.Function
resultZeroBlob api.Function
resultValue api.Function
resultError api.Function
resultErrorCode api.Function
resultErrorMem api.Function

View File

@@ -132,6 +132,15 @@ func Test_sqlite_newBytes(t *testing.T) {
if got := util.View(sqlite.mod, ptr, uint64(len(want))); !bytes.Equal(got, want) {
t.Errorf("got %q, want %q", got, want)
}
ptr = sqlite.newBytes(buf[:0])
if ptr == 0 {
t.Fatal("got nullptr, want a pointer")
}
if got := util.View(sqlite.mod, ptr, 0); got != nil {
t.Errorf("got %q, want nil", got)
}
}
func Test_sqlite_newString(t *testing.T) {

41
stmt.go
View File

@@ -1,7 +1,9 @@
package sqlite3
import (
"encoding/json"
"math"
"strconv"
"time"
"github.com/ncruces/go-sqlite3/internal/util"
@@ -248,6 +250,23 @@ func (s *Stmt) bindRFC3339Nano(param int, value time.Time) error {
return s.c.error(r)
}
// BindJSON binds the JSON encoding of value to the prepared statement.
// The leftmost SQL parameter has an index of 1.
//
// https://www.sqlite.org/c3ref/bind_blob.html
func (s *Stmt) BindJSON(param int, value any) error {
data, err := json.Marshal(value)
if err != nil {
return err
}
ptr := s.c.newBytes(data)
r := s.c.call(s.c.api.bindText,
uint64(s.handle), uint64(param),
uint64(ptr), uint64(len(data)),
uint64(s.c.api.destructor))
return s.c.error(r)
}
// ColumnCount returns the number of columns in a result set.
//
// https://www.sqlite.org/c3ref/column_count.html
@@ -402,6 +421,28 @@ func (s *Stmt) columnRawBytes(col int, ptr uint32) []byte {
return util.View(s.c.mod, ptr, r)
}
// ColumnJSON parses the JSON-encoded value of the result column
// and stores it in the value pointed to by ptr.
// The leftmost column of the result set has the index 0.
//
// https://www.sqlite.org/c3ref/column_blob.html
func (s *Stmt) ColumnJSON(col int, ptr any) error {
var data []byte
switch s.ColumnType(col) {
case NULL:
data = []byte("null")
case TEXT, BLOB:
data = s.ColumnRawBlob(col)
case INTEGER:
data = strconv.AppendInt(nil, s.ColumnInt64(col), 10)
case FLOAT:
data = strconv.AppendFloat(nil, s.ColumnFloat(col), 'g', -1, 64)
default:
panic(util.AssertErr())
}
return json.Unmarshal(data, ptr)
}
// Return true if stmt is an empty SQL statement.
// This is used as an optimization.
// It's OK to always return false here.

View File

@@ -36,8 +36,17 @@ func TestCreateFunction(t *testing.T) {
case 7:
ctx.ResultTime(arg.Time(sqlite3.TimeFormatUnix), sqlite3.TimeFormatDefault)
case 8:
ctx.ResultNull()
var v any
if err := arg.JSON(&v); err != nil {
ctx.ResultError(err)
} else {
ctx.ResultJSON(v)
}
case 9:
ctx.ResultValue(arg)
case 10:
ctx.ResultNull()
case 11:
ctx.ResultError(sqlite3.FULL)
}
})
@@ -45,7 +54,7 @@ func TestCreateFunction(t *testing.T) {
t.Fatal(err)
}
stmt, _, err := db.Prepare(`SELECT test(value) FROM generate_series(0, 9)`)
stmt, _, err := db.Prepare(`SELECT test(value) FROM generate_series(0)`)
if err != nil {
t.Error(err)
}
@@ -123,6 +132,27 @@ func TestCreateFunction(t *testing.T) {
}
}
if stmt.Step() {
if got := stmt.ColumnType(0); got != sqlite3.TEXT {
t.Errorf("got %v, want TEXT", got)
}
var got int
if err := stmt.ColumnJSON(0, &got); err != nil {
t.Error(err)
} else if got != 8 {
t.Errorf("got %v, want 8", got)
}
}
if stmt.Step() {
if got := stmt.ColumnType(0); got != sqlite3.INTEGER {
t.Errorf("got %v, want INTEGER", got)
}
if got := stmt.ColumnInt64(0); got != 9 {
t.Errorf("got %v, want 9", got)
}
}
if stmt.Step() {
if got := stmt.ColumnType(0); got != sqlite3.NULL {
t.Errorf("got %v, want NULL", got)

View File

@@ -1,6 +1,7 @@
package tests
import (
"encoding/json"
"math"
"testing"
"time"
@@ -81,6 +82,13 @@ func TestStmt(t *testing.T) {
t.Fatal(err)
}
if err := stmt.BindBlob(1, []byte("")); err != nil {
t.Fatal(err)
}
if err := stmt.Exec(); err != nil {
t.Fatal(err)
}
if err := stmt.BindBlob(1, []byte("blob")); err != nil {
t.Fatal(err)
}
@@ -102,6 +110,13 @@ func TestStmt(t *testing.T) {
t.Fatal(err)
}
if err := stmt.BindJSON(1, true); err != nil {
t.Fatal(err)
}
if err := stmt.Exec(); err != nil {
t.Fatal(err)
}
if err := stmt.ClearBindings(); err != nil {
t.Fatal(err)
}
@@ -114,7 +129,7 @@ func TestStmt(t *testing.T) {
t.Fatal(err)
}
// The table should have: 0, 1, 2, π, NULL, "", "text", "blob", NULL, "\0\0\0\0", NULL
// The table should have: 0, 1, 2, π, NULL, "", "text", "", "blob", NULL, "\0\0\0\0", "true", NULL
stmt, _, err = db.Prepare(`SELECT col FROM test`)
if err != nil {
t.Fatal(err)
@@ -140,6 +155,12 @@ func TestStmt(t *testing.T) {
if got := stmt.ColumnBlob(0, nil); string(got) != "0" {
t.Errorf("got %q, want zero", got)
}
var got int
if err := stmt.ColumnJSON(0, &got); err != nil {
t.Error(err)
} else if got != 0 {
t.Errorf("got %v, want zero", got)
}
}
if stmt.Step() {
@@ -161,6 +182,12 @@ func TestStmt(t *testing.T) {
if got := stmt.ColumnBlob(0, nil); string(got) != "1" {
t.Errorf("got %q, want one", got)
}
var got float32
if err := stmt.ColumnJSON(0, &got); err != nil {
t.Error(err)
} else if got != 1 {
t.Errorf("got %v, want one", got)
}
}
if stmt.Step() {
@@ -182,6 +209,12 @@ func TestStmt(t *testing.T) {
if got := stmt.ColumnBlob(0, nil); string(got) != "2" {
t.Errorf("got %q, want two", got)
}
var got json.Number
if err := stmt.ColumnJSON(0, &got); err != nil {
t.Error(err)
} else if got != "2" {
t.Errorf("got %v, want two", got)
}
}
if stmt.Step() {
@@ -203,6 +236,12 @@ func TestStmt(t *testing.T) {
if got := stmt.ColumnBlob(0, nil); string(got) != "3.14159265358979" {
t.Errorf("got %q, want π", got)
}
var got float64
if err := stmt.ColumnJSON(0, &got); err != nil {
t.Error(err)
} else if got != math.Pi {
t.Errorf("got %v, want π", got)
}
}
if stmt.Step() {
@@ -224,6 +263,12 @@ func TestStmt(t *testing.T) {
if got := stmt.ColumnBlob(0, nil); got != nil {
t.Errorf("got %q, want nil", got)
}
var got any = 1
if err := stmt.ColumnJSON(0, &got); err != nil {
t.Error(err)
} else if got != nil {
t.Errorf("got %v, want NULL", got)
}
}
if stmt.Step() {
@@ -245,6 +290,10 @@ func TestStmt(t *testing.T) {
if got := stmt.ColumnBlob(0, nil); got != nil {
t.Errorf("got %q, want nil", got)
}
var got any
if err := stmt.ColumnJSON(0, &got); err == nil {
t.Errorf("got %v, want error", got)
}
}
if stmt.Step() {
@@ -266,6 +315,35 @@ func TestStmt(t *testing.T) {
if got := stmt.ColumnBlob(0, nil); string(got) != "text" {
t.Errorf(`got %q, want "text"`, got)
}
var got any
if err := stmt.ColumnJSON(0, &got); err == nil {
t.Errorf("got %v, want error", got)
}
}
if stmt.Step() {
if got := stmt.ColumnType(0); got != sqlite3.BLOB {
t.Errorf("got %v, want BLOB", got)
}
if got := stmt.ColumnBool(0); got != false {
t.Errorf("got %v, want false", got)
}
if got := stmt.ColumnInt(0); got != 0 {
t.Errorf("got %v, want zero", got)
}
if got := stmt.ColumnFloat(0); got != 0 {
t.Errorf("got %v, want zero", got)
}
if got := stmt.ColumnText(0); got != "" {
t.Errorf("got %q, want empty", got)
}
if got := stmt.ColumnBlob(0, nil); got != nil {
t.Errorf("got %q, want nil", got)
}
var got any
if err := stmt.ColumnJSON(0, &got); err == nil {
t.Errorf("got %v, want error", got)
}
}
if stmt.Step() {
@@ -287,6 +365,10 @@ func TestStmt(t *testing.T) {
if got := stmt.ColumnBlob(0, nil); string(got) != "blob" {
t.Errorf(`got %q, want "blob"`, got)
}
var got any
if err := stmt.ColumnJSON(0, &got); err == nil {
t.Errorf("got %v, want error", got)
}
}
if stmt.Step() {
@@ -308,6 +390,12 @@ func TestStmt(t *testing.T) {
if got := stmt.ColumnBlob(0, nil); got != nil {
t.Errorf("got %q, want nil", got)
}
var got any = 1
if err := stmt.ColumnJSON(0, &got); err != nil {
t.Error(err)
} else if got != nil {
t.Errorf("got %v, want NULL", got)
}
}
if stmt.Step() {
@@ -329,6 +417,37 @@ func TestStmt(t *testing.T) {
if got := stmt.ColumnBlob(0, nil); string(got) != "\x00\x00\x00\x00" {
t.Errorf(`got %q, want "\x00\x00\x00\x00"`, got)
}
var got any
if err := stmt.ColumnJSON(0, &got); err == nil {
t.Errorf("got %v, want error", got)
}
}
if stmt.Step() {
if got := stmt.ColumnType(0); got != sqlite3.TEXT {
t.Errorf("got %v, want TEXT", got)
}
if got := stmt.ColumnBool(0); got != false {
t.Errorf("got %v, want false", got)
}
if got := stmt.ColumnInt(0); got != 0 {
t.Errorf("got %v, want zero", got)
}
if got := stmt.ColumnFloat(0); got != 0 {
t.Errorf("got %v, want zero", got)
}
if got := stmt.ColumnText(0); got != "true" {
t.Errorf("got %q, want true", got)
}
if got := stmt.ColumnBlob(0, nil); string(got) != "true" {
t.Errorf("got %q, want true", got)
}
var got any = 1
if err := stmt.ColumnJSON(0, &got); err != nil {
t.Error(err)
} else if got != true {
t.Errorf("got %v, want true", got)
}
}
if stmt.Step() {
@@ -350,6 +469,12 @@ func TestStmt(t *testing.T) {
if got := stmt.ColumnBlob(0, nil); got != nil {
t.Errorf("got %q, want nil", got)
}
var got any = 1
if err := stmt.ColumnJSON(0, &got); err != nil {
t.Error(err)
} else if got != nil {
t.Errorf("got %v, want NULL", got)
}
}
if err := stmt.Close(); err != nil {

View File

@@ -1,7 +1,9 @@
package sqlite3
import (
"encoding/json"
"math"
"strconv"
"time"
"github.com/ncruces/go-sqlite3/internal/util"
@@ -123,3 +125,22 @@ func (v Value) rawBytes(ptr uint32) []byte {
r := v.call(v.api.valueBytes, uint64(v.handle))
return util.View(v.mod, ptr, r)
}
// JSON parses a JSON-encoded value
// and stores the result in the value pointed to by ptr.
func (v Value) JSON(ptr any) error {
var data []byte
switch v.Type() {
case NULL:
data = []byte("null")
case TEXT, BLOB:
data = v.RawBlob()
case INTEGER:
data = strconv.AppendInt(nil, v.Int64(), 10)
case FLOAT:
data = strconv.AppendFloat(nil, v.Float(), 'g', -1, 64)
default:
panic(util.AssertErr())
}
return json.Unmarshal(data, ptr)
}