From d60fceac9235cd77d1033b2f9972b34a08e40b8d Mon Sep 17 00:00:00 2001 From: Nuno Cruces Date: Wed, 18 Oct 2023 23:14:46 +0100 Subject: [PATCH] JSON support. --- README.md | 2 +- driver/driver.go | 9 +++++- json.go | 53 ++++++++++++++++++++++++++++++++++++ stmt.go | 2 +- tests/json_test.go | 68 ++++++++++++++++++++++++++++++++++++++++++++++ time.go | 5 ++-- 6 files changed, 134 insertions(+), 5 deletions(-) create mode 100644 json.go create mode 100644 tests/json_test.go diff --git a/README.md b/README.md index 8e0adfe..9f6ebdf 100644 --- a/README.md +++ b/README.md @@ -82,7 +82,7 @@ Performance is tested by running - [x] nested transactions - [x] incremental BLOB I/O - [x] online backup - - [ ] JSON support + - [x] JSON support - [ ] session extension - [ ] custom VFSes - [x] custom VFS API diff --git a/driver/driver.go b/driver/driver.go index 825c5b0..6bd4bb9 100644 --- a/driver/driver.go +++ b/driver/driver.go @@ -30,6 +30,7 @@ import ( "context" "database/sql" "database/sql/driver" + "encoding/json" "fmt" "io" "net/url" @@ -272,6 +273,10 @@ func (c *conn) ExecContext(ctx context.Context, query string, args []driver.Name return newResult(c.Conn), nil } +func (*conn) CheckNamedValue(arg *driver.NamedValue) error { + return nil +} + type stmt struct { Stmt *sqlite3.Stmt Conn *sqlite3.Conn @@ -370,6 +375,8 @@ func (s *stmt) setupBindings(args []driver.NamedValue) error { err = s.Stmt.BindZeroBlob(id, int64(a)) case time.Time: err = s.Stmt.BindTime(id, a, sqlite3.TimeFormatDefault) + case json.Marshaler: + err = s.Stmt.BindJSON(id, a) case nil: err = s.Stmt.BindNull(id) default: @@ -386,7 +393,7 @@ func (s *stmt) setupBindings(args []driver.NamedValue) error { func (s *stmt) CheckNamedValue(arg *driver.NamedValue) error { switch arg.Value.(type) { case bool, int, int64, float64, string, []byte, - sqlite3.ZeroBlob, time.Time, nil: + sqlite3.ZeroBlob, time.Time, json.Marshaler, nil: return nil default: return driver.ErrSkip diff --git a/json.go b/json.go new file mode 100644 index 0000000..1bd7a83 --- /dev/null +++ b/json.go @@ -0,0 +1,53 @@ +package sqlite3 + +import ( + "encoding/json" + "strconv" + "time" + "unsafe" +) + +// JSON returns: +// a [json.Marshaler] that can be used as an argument to +// [database/sql.DB.Exec] and similar methods to +// store value as JSON; and +// a [database/sql.Scanner] that can be used as an argument to +// [database/sql.Row.Scan] and similar methods to +// decode JSON into value. +func JSON(value any) any { + return jsonValue{value} +} + +type jsonValue struct{ any } + +func (j jsonValue) MarshalJSON() ([]byte, error) { + return json.Marshal(j.any) +} + +func (j jsonValue) UnmarshalJSON(data []byte) error { + return json.Unmarshal(data, j.any) +} + +func (j jsonValue) Scan(value any) error { + var mem [40]byte + buf := mem[:0] + + switch v := value.(type) { + case []byte: + buf = v + case string: + buf = unsafe.Slice(unsafe.StringData(v), len(v)) + case int64: + buf = strconv.AppendInt(nil, v, 10) + case float64: + buf = strconv.AppendFloat(nil, v, 'g', -1, 64) + case time.Time: + buf = append(buf, '"') + buf = v.AppendFormat(buf, time.RFC3339Nano) + buf = append(buf, '"') + case nil: + buf = append(buf, "null"...) + } + + return j.UnmarshalJSON(buf) +} diff --git a/stmt.go b/stmt.go index 23b9b41..34b638a 100644 --- a/stmt.go +++ b/stmt.go @@ -263,7 +263,7 @@ func (s *Stmt) BindJSON(param int, value any) error { r := s.c.call(s.c.api.bindText, uint64(s.handle), uint64(param), uint64(ptr), uint64(len(data)), - uint64(s.c.api.destructor)) + uint64(s.c.api.destructor), _UTF8) return s.c.error(r) } diff --git a/tests/json_test.go b/tests/json_test.go new file mode 100644 index 0000000..70ffd28 --- /dev/null +++ b/tests/json_test.go @@ -0,0 +1,68 @@ +package tests + +import ( + "encoding/json" + "math" + "testing" + "time" + + "github.com/ncruces/go-sqlite3" + "github.com/ncruces/go-sqlite3/driver" + "github.com/ncruces/julianday" +) + +func TestJSON(t *testing.T) { + t.Parallel() + + db, err := driver.Open(":memory:", nil) + if err != nil { + t.Fatal(err) + } + defer db.Close() + + _, err = db.Exec(`CREATE TABLE IF NOT EXISTS test (col)`) + if err != nil { + t.Fatal(err) + } + + reference := time.Date(2013, 10, 7, 4, 23, 19, 120_000_000, time.FixedZone("", -4*3600)) + + _, err = db.Exec( + `INSERT INTO test (col) VALUES (?), (?), (?), (?)`, + nil, 1, math.Pi, reference, + ) + if err != nil { + t.Fatal(err) + } + + _, err = db.Exec( + `INSERT INTO test (col) VALUES (?), (?), (?), (?)`, + sqlite3.JSON(math.Pi), sqlite3.JSON(false), + julianday.Format(reference), sqlite3.JSON([]string{})) + if err != nil { + t.Fatal(err) + } + + rows, err := db.Query("SELECT * FROM test") + if err != nil { + t.Fatal(err) + } + + want := []string{ + "null", "1", "3.141592653589793", + `"2013-10-07T04:23:19.12-04:00"`, + "3.141592653589793", "false", + "2456572.849526851851852", "[]", + } + for rows.Next() { + var got json.RawMessage + err = rows.Scan(sqlite3.JSON(&got)) + if err != nil { + t.Fatal(err) + } + if string(got) != want[0] { + t.Errorf("got %q, want %q", got, want[0]) + } + want = want[1:] + } +} diff --git a/time.go b/time.go index 5449bec..8475439 100644 --- a/time.go +++ b/time.go @@ -339,8 +339,9 @@ func (f TimeFormat) parseRelaxed(s string) (time.Time, error) { return t, nil } -// Scanner returns a [database/sql.Scanner] that -// decodes a time value into dest using this format. +// Scanner returns a [database/sql.Scanner] that can be used as an argument to +// [database/sql.Row.Scan] and similar methods to +// decode a time value into dest using this format. func (f TimeFormat) Scanner(dest *time.Time) interface{ Scan(any) error } { return timeScanner{dest, f} }