diff --git a/driver/driver.go b/driver/driver.go index c1d96bb..6408e68 100644 --- a/driver/driver.go +++ b/driver/driver.go @@ -417,10 +417,10 @@ func (s *stmt) setupBindings(args []driver.NamedValue) error { err = s.Stmt.BindZeroBlob(id, int64(a)) case time.Time: err = s.Stmt.BindTime(id, a, s.tmWrite) - case interface{ Pointer() any }: - err = s.Stmt.BindPointer(id, a.Pointer()) - case interface{ JSON() any }: - err = s.Stmt.BindJSON(id, a.JSON()) + case util.JSON: + err = s.Stmt.BindJSON(id, a.Value) + case util.PointerUnwrap: + err = s.Stmt.BindPointer(id, util.UnwrapPointer(a)) case nil: err = s.Stmt.BindNull(id) default: @@ -437,9 +437,8 @@ func (s *stmt) setupBindings(args []driver.NamedValue) error { func (s *stmt) CheckNamedValue(arg *driver.NamedValue) error { switch arg.Value.(type) { case bool, int, int64, float64, string, []byte, - sqlite3.ZeroBlob, time.Time, - interface{ Pointer() any }, - interface{ JSON() any }, + time.Time, sqlite3.ZeroBlob, + util.JSON, util.PointerUnwrap, nil: return nil default: diff --git a/internal/util/json.go b/internal/util/json.go new file mode 100644 index 0000000..c0ba38c --- /dev/null +++ b/internal/util/json.go @@ -0,0 +1,35 @@ +package util + +import ( + "encoding/json" + "strconv" + "time" + "unsafe" +) + +type JSON struct{ Value any } + +func (j JSON) Scan(value any) error { + var buf []byte + + switch v := value.(type) { + case []byte: + buf = v + case string: + buf = unsafe.Slice(unsafe.StringData(v), len(v)) + case int64: + buf = strconv.AppendInt(nil, v, 10) + case float64: + buf = strconv.AppendFloat(nil, v, 'g', -1, 64) + case time.Time: + buf = append(buf, '"') + buf = v.AppendFormat(buf, time.RFC3339Nano) + buf = append(buf, '"') + case nil: + buf = append(buf, "null"...) + default: + panic(AssertErr()) + } + + return json.Unmarshal(buf, j.Value) +} diff --git a/internal/util/pointer.go b/internal/util/pointer.go new file mode 100644 index 0000000..eae4dae --- /dev/null +++ b/internal/util/pointer.go @@ -0,0 +1,11 @@ +package util + +type Pointer[T any] struct{ Value T } + +func (p Pointer[T]) unwrap() any { return p.Value } + +type PointerUnwrap interface{ unwrap() any } + +func UnwrapPointer(p PointerUnwrap) any { + return p.unwrap() +} diff --git a/json.go b/json.go index 37039fc..9b2565e 100644 --- a/json.go +++ b/json.go @@ -1,47 +1,11 @@ package sqlite3 -import ( - "encoding/json" - "strconv" - "time" - "unsafe" - - "github.com/ncruces/go-sqlite3/internal/util" -) +import "github.com/ncruces/go-sqlite3/internal/util" // JSON returns a value that can be used as an argument to // [database/sql.DB.Exec], [database/sql.Row.Scan] and similar methods to // store value as JSON, or decode JSON into value. // JSON should NOT be used with [BindJSON] or [ResultJSON]. func JSON(value any) any { - return jsonValue{value} -} - -type jsonValue struct{ any } - -func (j jsonValue) JSON() any { return j.any } - -func (j jsonValue) Scan(value any) error { - var buf []byte - - switch v := value.(type) { - case []byte: - buf = v - case string: - buf = unsafe.Slice(unsafe.StringData(v), len(v)) - case int64: - buf = strconv.AppendInt(nil, v, 10) - case float64: - buf = strconv.AppendFloat(nil, v, 'g', -1, 64) - case time.Time: - buf = append(buf, '"') - buf = v.AppendFormat(buf, time.RFC3339Nano) - buf = append(buf, '"') - case nil: - buf = append(buf, "null"...) - default: - panic(util.AssertErr()) - } - - return json.Unmarshal(buf, j.any) + return util.JSON{Value: value} } diff --git a/pointer.go b/pointer.go index 9647f67..611c152 100644 --- a/pointer.go +++ b/pointer.go @@ -1,14 +1,12 @@ package sqlite3 +import "github.com/ncruces/go-sqlite3/internal/util" + // Pointer returns a pointer to a value that can be used as an argument to // [database/sql.DB.Exec] and similar methods. // Pointer should NOT be used with [BindPointer] or [ResultPointer]. // // https://sqlite.org/bindptr.html -func Pointer[T any](val T) any { - return pointer[T]{val} +func Pointer[T any](value T) any { + return util.Pointer[T]{Value: value} } - -type pointer[T any] struct{ val T } - -func (p pointer[T]) Pointer() any { return p.val }