JSON support.

This commit is contained in:
Nuno Cruces
2023-10-18 23:14:46 +01:00
parent 61da30f44a
commit d60fceac92
6 changed files with 134 additions and 5 deletions

View File

@@ -82,7 +82,7 @@ Performance is tested by running
- [x] nested transactions - [x] nested transactions
- [x] incremental BLOB I/O - [x] incremental BLOB I/O
- [x] online backup - [x] online backup
- [ ] JSON support - [x] JSON support
- [ ] session extension - [ ] session extension
- [ ] custom VFSes - [ ] custom VFSes
- [x] custom VFS API - [x] custom VFS API

View File

@@ -30,6 +30,7 @@ import (
"context" "context"
"database/sql" "database/sql"
"database/sql/driver" "database/sql/driver"
"encoding/json"
"fmt" "fmt"
"io" "io"
"net/url" "net/url"
@@ -272,6 +273,10 @@ func (c *conn) ExecContext(ctx context.Context, query string, args []driver.Name
return newResult(c.Conn), nil return newResult(c.Conn), nil
} }
func (*conn) CheckNamedValue(arg *driver.NamedValue) error {
return nil
}
type stmt struct { type stmt struct {
Stmt *sqlite3.Stmt Stmt *sqlite3.Stmt
Conn *sqlite3.Conn Conn *sqlite3.Conn
@@ -370,6 +375,8 @@ func (s *stmt) setupBindings(args []driver.NamedValue) error {
err = s.Stmt.BindZeroBlob(id, int64(a)) err = s.Stmt.BindZeroBlob(id, int64(a))
case time.Time: case time.Time:
err = s.Stmt.BindTime(id, a, sqlite3.TimeFormatDefault) err = s.Stmt.BindTime(id, a, sqlite3.TimeFormatDefault)
case json.Marshaler:
err = s.Stmt.BindJSON(id, a)
case nil: case nil:
err = s.Stmt.BindNull(id) err = s.Stmt.BindNull(id)
default: default:
@@ -386,7 +393,7 @@ func (s *stmt) setupBindings(args []driver.NamedValue) error {
func (s *stmt) CheckNamedValue(arg *driver.NamedValue) error { func (s *stmt) CheckNamedValue(arg *driver.NamedValue) error {
switch arg.Value.(type) { switch arg.Value.(type) {
case bool, int, int64, float64, string, []byte, case bool, int, int64, float64, string, []byte,
sqlite3.ZeroBlob, time.Time, nil: sqlite3.ZeroBlob, time.Time, json.Marshaler, nil:
return nil return nil
default: default:
return driver.ErrSkip return driver.ErrSkip

53
json.go Normal file
View File

@@ -0,0 +1,53 @@
package sqlite3
import (
"encoding/json"
"strconv"
"time"
"unsafe"
)
// JSON returns:
// a [json.Marshaler] that can be used as an argument to
// [database/sql.DB.Exec] and similar methods to
// store value as JSON; and
// a [database/sql.Scanner] that can be used as an argument to
// [database/sql.Row.Scan] and similar methods to
// decode JSON into value.
func JSON(value any) any {
return jsonValue{value}
}
type jsonValue struct{ any }
func (j jsonValue) MarshalJSON() ([]byte, error) {
return json.Marshal(j.any)
}
func (j jsonValue) UnmarshalJSON(data []byte) error {
return json.Unmarshal(data, j.any)
}
func (j jsonValue) Scan(value any) error {
var mem [40]byte
buf := mem[:0]
switch v := value.(type) {
case []byte:
buf = v
case string:
buf = unsafe.Slice(unsafe.StringData(v), len(v))
case int64:
buf = strconv.AppendInt(nil, v, 10)
case float64:
buf = strconv.AppendFloat(nil, v, 'g', -1, 64)
case time.Time:
buf = append(buf, '"')
buf = v.AppendFormat(buf, time.RFC3339Nano)
buf = append(buf, '"')
case nil:
buf = append(buf, "null"...)
}
return j.UnmarshalJSON(buf)
}

View File

@@ -263,7 +263,7 @@ func (s *Stmt) BindJSON(param int, value any) error {
r := s.c.call(s.c.api.bindText, r := s.c.call(s.c.api.bindText,
uint64(s.handle), uint64(param), uint64(s.handle), uint64(param),
uint64(ptr), uint64(len(data)), uint64(ptr), uint64(len(data)),
uint64(s.c.api.destructor)) uint64(s.c.api.destructor), _UTF8)
return s.c.error(r) return s.c.error(r)
} }

68
tests/json_test.go Normal file
View File

@@ -0,0 +1,68 @@
package tests
import (
"encoding/json"
"math"
"testing"
"time"
"github.com/ncruces/go-sqlite3"
"github.com/ncruces/go-sqlite3/driver"
"github.com/ncruces/julianday"
)
func TestJSON(t *testing.T) {
t.Parallel()
db, err := driver.Open(":memory:", nil)
if err != nil {
t.Fatal(err)
}
defer db.Close()
_, err = db.Exec(`CREATE TABLE IF NOT EXISTS test (col)`)
if err != nil {
t.Fatal(err)
}
reference := time.Date(2013, 10, 7, 4, 23, 19, 120_000_000, time.FixedZone("", -4*3600))
_, err = db.Exec(
`INSERT INTO test (col) VALUES (?), (?), (?), (?)`,
nil, 1, math.Pi, reference,
)
if err != nil {
t.Fatal(err)
}
_, err = db.Exec(
`INSERT INTO test (col) VALUES (?), (?), (?), (?)`,
sqlite3.JSON(math.Pi), sqlite3.JSON(false),
julianday.Format(reference), sqlite3.JSON([]string{}))
if err != nil {
t.Fatal(err)
}
rows, err := db.Query("SELECT * FROM test")
if err != nil {
t.Fatal(err)
}
want := []string{
"null", "1", "3.141592653589793",
`"2013-10-07T04:23:19.12-04:00"`,
"3.141592653589793", "false",
"2456572.849526851851852", "[]",
}
for rows.Next() {
var got json.RawMessage
err = rows.Scan(sqlite3.JSON(&got))
if err != nil {
t.Fatal(err)
}
if string(got) != want[0] {
t.Errorf("got %q, want %q", got, want[0])
}
want = want[1:]
}
}

View File

@@ -339,8 +339,9 @@ func (f TimeFormat) parseRelaxed(s string) (time.Time, error) {
return t, nil return t, nil
} }
// Scanner returns a [database/sql.Scanner] that // Scanner returns a [database/sql.Scanner] that can be used as an argument to
// decodes a time value into dest using this format. // [database/sql.Row.Scan] and similar methods to
// decode a time value into dest using this format.
func (f TimeFormat) Scanner(dest *time.Time) interface{ Scan(any) error } { func (f TimeFormat) Scanner(dest *time.Time) interface{ Scan(any) error } {
return timeScanner{dest, f} return timeScanner{dest, f}
} }