diff --git a/driver/driver.go b/driver/driver.go index 871aa74..9250cf3 100644 --- a/driver/driver.go +++ b/driver/driver.go @@ -20,22 +20,45 @@ // - a [serializable] transaction is always "immediate"; // - a [read-only] transaction is always "deferred". // +// # Datatypes In SQLite +// +// SQLite is dynamically typed. +// Columns can mostly hold any value regardless of their declared type. +// SQLite supports most [driver.Value] types out of the box, +// but bool and [time.Time] require special care. +// +// Booleans can be stored on any column type and scanned back to a *bool. +// However, if scanned to a *any, booleans may either become an +// int64, string or bool, depending on the declared type of the column. +// If you use BOOLEAN for your column type, +// 1 and 0 will always scan as true and false. +// // # Working with time // +// Time values can similarly be stored on any column type. // The time encoding/decoding format can be specified using "_timefmt": // // sql.Open("sqlite3", "file:demo.db?_timefmt=sqlite") // -// Possible values are: "auto" (the default), "sqlite", "rfc3339"; +// Special values are: "auto" (the default), "sqlite", "rfc3339"; // - "auto" encodes as RFC 3339 and decodes any [format] supported by SQLite; // - "sqlite" encodes as SQLite and decodes any [format] supported by SQLite; // - "rfc3339" encodes and decodes RFC 3339 only. // -// If you encode as RFC 3339 (the default), -// consider using the TIME [collating sequence] to produce a time-ordered sequence. +// You can also set "_timefmt" to an arbitrary [sqlite3.TimeFormat] or [time.Layout]. // -// To scan values in other formats, [sqlite3.TimeFormat.Scanner] may be helpful. -// To bind values in other formats, [sqlite3.TimeFormat.Encode] them before binding. +// If you encode as RFC 3339 (the default), +// consider using the TIME [collating sequence] to produce time-ordered sequences. +// +// If you encode as RFC 3339 (the default), +// time values will scan back to a *time.Time unless your column type is TEXT. +// Otherwise, if scanned to a *any, time values may either become an +// int64, float64 or string, depending on the time format and declared type of the column. +// If you use DATE, TIME, DATETIME, or TIMESTAMP for your column type, +// "_timefmt" will be used to decode values. +// +// To scan values in custom formats, [sqlite3.TimeFormat.Scanner] may be helpful. +// To bind values in custom formats, [sqlite3.TimeFormat.Encode] them before binding. // // When using a custom time struct, you'll have to implement // [database/sql/driver.Valuer] and [database/sql.Scanner]. @@ -48,7 +71,7 @@ // The Scan method needs to take into account that the value it receives can be of differing types. // It can already be a [time.Time], if the driver decoded the value according to "_timefmt" rules. // Or it can be a: string, int64, float64, []byte, or nil, -// depending on the column type and what whoever wrote the value. +// depending on the column type and whoever wrote the value. // [sqlite3.TimeFormat.Decode] may help. // // # Setting PRAGMAs @@ -595,6 +618,28 @@ const ( _TIME ) +func scanFromDecl(decl string) scantype { + // These types are only used before we have rows, + // and otherwise as type hints. + // The first few ensure STRICT tables are strictly typed. + // The other two are type hints for booleans and time. + switch decl { + case "INT", "INTEGER": + return _INT + case "REAL": + return _REAL + case "TEXT": + return _TEXT + case "BLOB": + return _BLOB + case "BOOLEAN": + return _BOOL + case "DATE", "TIME", "DATETIME", "TIMESTAMP": + return _TIME + } + return _ANY +} + var ( // Ensure these interfaces are implemented: _ driver.RowsColumnTypeDatabaseTypeName = &rows{} @@ -619,6 +664,18 @@ func (r *rows) Columns() []string { return r.names } +func (r *rows) scanType(index int) scantype { + if r.scans == nil { + count := r.Stmt.ColumnCount() + scans := make([]scantype, count) + for i := range scans { + scans[i] = scanFromDecl(strings.ToUpper(r.Stmt.ColumnDeclType(i))) + } + r.scans = scans + } + return r.scans[index] +} + func (r *rows) loadColumnMetadata() { if r.nulls == nil { count := r.Stmt.ColumnCount() @@ -632,24 +689,7 @@ func (r *rows) loadColumnMetadata() { r.Stmt.ColumnTableName(i), col) types[i] = strings.ToUpper(types[i]) - // These types are only used before we have rows, - // and otherwise as type hints. - // The first few ensure STRICT tables are strictly typed. - // The other two are type hints for booleans and time. - switch types[i] { - case "INT", "INTEGER": - scans[i] = _INT - case "REAL": - scans[i] = _REAL - case "TEXT": - scans[i] = _TEXT - case "BLOB": - scans[i] = _BLOB - case "BOOLEAN": - scans[i] = _BOOL - case "DATE", "TIME", "DATETIME", "TIMESTAMP": - scans[i] = _TIME - } + scans[i] = scanFromDecl(types[i]) } } r.nulls = nulls @@ -658,27 +698,15 @@ func (r *rows) loadColumnMetadata() { } } -func (r *rows) declType(index int) string { - if r.types == nil { - count := r.Stmt.ColumnCount() - types := make([]string, count) - for i := range types { - types[i] = strings.ToUpper(r.Stmt.ColumnDeclType(i)) - } - r.types = types - } - return r.types[index] -} - func (r *rows) ColumnTypeDatabaseTypeName(index int) string { r.loadColumnMetadata() - decltype := r.types[index] - if len := len(decltype); len > 0 && decltype[len-1] == ')' { - if i := strings.LastIndexByte(decltype, '('); i >= 0 { - decltype = decltype[:i] + decl := r.types[index] + if len := len(decl); len > 0 && decl[len-1] == ')' { + if i := strings.LastIndexByte(decl, '('); i >= 0 { + decl = decl[:i] } } - return strings.TrimSpace(decltype) + return strings.TrimSpace(decl) } func (r *rows) ColumnTypeNullable(index int) (nullable, ok bool) { @@ -745,36 +773,49 @@ func (r *rows) Next(dest []driver.Value) error { } data := unsafe.Slice((*any)(unsafe.SliceData(dest)), len(dest)) - err := r.Stmt.Columns(data...) + if err := r.Stmt.ColumnsRaw(data...); err != nil { + return err + } for i := range dest { - if t, ok := r.decodeTime(i, dest[i]); ok { - dest[i] = t - } - } - return err -} - -func (r *rows) decodeTime(i int, v any) (_ time.Time, ok bool) { - switch v := v.(type) { - case int64, float64: - // could be a time value - case string: - if r.tmWrite != "" && r.tmWrite != time.RFC3339 && r.tmWrite != time.RFC3339Nano { + scan := r.scanType(i) + switch v := dest[i].(type) { + case int64: + if scan == _BOOL { + switch v { + case 1: + dest[i] = true + case 0: + dest[i] = false + } + continue + } + case []byte: + if len(v) == cap(v) { // a BLOB + continue + } + if scan != _TEXT { + switch r.tmWrite { + case "", time.RFC3339, time.RFC3339Nano: + t, ok := maybeTime(v) + if ok { + dest[i] = t + continue + } + } + } + dest[i] = string(v) + case float64: break + default: + continue } - t, ok := maybeTime(v) - if ok { - return t, true + if scan == _TIME { + t, err := r.tmRead.Decode(dest[i]) + if err == nil { + dest[i] = t + continue + } } - default: - return } - switch r.declType(i) { - case "DATE", "TIME", "DATETIME", "TIMESTAMP": - // could be a time value - default: - return - } - t, err := r.tmRead.Decode(v) - return t, err == nil + return nil } diff --git a/driver/example2_test.go b/driver/example2_test.go index 7e70530..d955963 100644 --- a/driver/example2_test.go +++ b/driver/example2_test.go @@ -1,9 +1,5 @@ -//go:build linux || darwin || windows || freebsd || openbsd || netbsd || dragonfly || illumos || sqlite3_flock || sqlite3_dotlk - package driver_test -// Adapted from: https://go.dev/doc/tutorial/database-access - import ( "database/sql" "database/sql/driver" @@ -27,7 +23,7 @@ func Example_customTime() { _, err = db.Exec(` CREATE TABLE data ( id INTEGER PRIMARY KEY, - date_time TEXT + date_time ANY ) STRICT; `) if err != nil { diff --git a/driver/time.go b/driver/time.go index b3ebdd2..4d48bd8 100644 --- a/driver/time.go +++ b/driver/time.go @@ -1,12 +1,15 @@ package driver -import "time" +import ( + "bytes" + "time" +) // Convert a string in [time.RFC3339Nano] format into a [time.Time] // if it roundtrips back to the same string. // This way times can be persisted to, and recovered from, the database, // but if a string is needed, [database/sql] will recover the same string. -func maybeTime(text string) (_ time.Time, _ bool) { +func maybeTime(text []byte) (_ time.Time, _ bool) { // Weed out (some) values that can't possibly be // [time.RFC3339Nano] timestamps. if len(text) < len("2006-01-02T15:04:05Z") { @@ -21,8 +24,8 @@ func maybeTime(text string) (_ time.Time, _ bool) { // Slow path. var buf [len(time.RFC3339Nano)]byte - date, err := time.Parse(time.RFC3339Nano, text) - if err == nil && text == string(date.AppendFormat(buf[:0], time.RFC3339Nano)) { + date, err := time.Parse(time.RFC3339Nano, string(text)) + if err == nil && bytes.Equal(text, date.AppendFormat(buf[:0], time.RFC3339Nano)) { return date, true } return diff --git a/driver/time_test.go b/driver/time_test.go index 0b56ba8..7a9ed6a 100644 --- a/driver/time_test.go +++ b/driver/time_test.go @@ -22,7 +22,7 @@ func Fuzz_stringOrTime_1(f *testing.F) { f.Add("Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.") f.Fuzz(func(t *testing.T, str string) { - v, ok := maybeTime(str) + v, ok := maybeTime([]byte(str)) if ok { // Make sure times round-trip to the same string: // https://pkg.go.dev/database/sql#Rows.Scan @@ -51,7 +51,7 @@ func Fuzz_stringOrTime_2(f *testing.F) { f.Add(int64(-763421161058), int64(222_222_222)) // twosday, year 22222BC checkTime := func(t testing.TB, date time.Time) { - v, ok := maybeTime(date.Format(time.RFC3339Nano)) + v, ok := maybeTime(date.AppendFormat(nil, time.RFC3339Nano)) if ok { // Make sure times round-trip to the same time: if !v.Equal(date) { diff --git a/stmt.go b/stmt.go index 8713f9d..c176102 100644 --- a/stmt.go +++ b/stmt.go @@ -571,7 +571,7 @@ func (s *Stmt) ColumnBlob(col int, buf []byte) []byte { func (s *Stmt) ColumnRawText(col int) []byte { ptr := ptr_t(s.c.call("sqlite3_column_text", stk_t(s.handle), stk_t(col))) - return s.columnRawBytes(col, ptr) + return s.columnRawBytes(col, ptr, 1) } // ColumnRawBlob returns the value of the result column as a []byte. @@ -583,10 +583,10 @@ func (s *Stmt) ColumnRawText(col int) []byte { func (s *Stmt) ColumnRawBlob(col int) []byte { ptr := ptr_t(s.c.call("sqlite3_column_blob", stk_t(s.handle), stk_t(col))) - return s.columnRawBytes(col, ptr) + return s.columnRawBytes(col, ptr, 0) } -func (s *Stmt) columnRawBytes(col int, ptr ptr_t) []byte { +func (s *Stmt) columnRawBytes(col int, ptr ptr_t, nul int32) []byte { if ptr == 0 { rc := res_t(s.c.call("sqlite3_errcode", stk_t(s.c.handle))) if rc != _ROW && rc != _DONE { @@ -597,7 +597,7 @@ func (s *Stmt) columnRawBytes(col int, ptr ptr_t) []byte { n := int32(s.c.call("sqlite3_column_bytes", stk_t(s.handle), stk_t(col))) - return util.View(s.c.mod, ptr, int64(n)) + return util.View(s.c.mod, ptr, int64(n+nul))[:n] } // ColumnJSON parses the JSON-encoded value of the result column @@ -644,22 +644,12 @@ func (s *Stmt) ColumnValue(col int) Value { // [INTEGER] columns will be retrieved as int64 values, // [FLOAT] as float64, [NULL] as nil, // [TEXT] as string, and [BLOB] as []byte. -// Any []byte are owned by SQLite and may be invalidated by -// subsequent calls to [Stmt] methods. func (s *Stmt) Columns(dest ...any) error { - defer s.c.arena.mark()() - count := int64(len(dest)) - typePtr := s.c.arena.new(count) - dataPtr := s.c.arena.new(count * 8) - - rc := res_t(s.c.call("sqlite3_columns_go", - stk_t(s.handle), stk_t(count), stk_t(typePtr), stk_t(dataPtr))) - if err := s.c.error(rc); err != nil { + types, ptr, err := s.columns(int64(len(dest))) + if err != nil { return err } - types := util.View(s.c.mod, typePtr, count) - // Avoid bounds checks on types below. if len(types) != len(dest) { panic(util.AssertErr()) @@ -668,30 +658,95 @@ func (s *Stmt) Columns(dest ...any) error { for i := range dest { switch types[i] { case byte(INTEGER): - dest[i] = util.Read64[int64](s.c.mod, dataPtr) + dest[i] = util.Read64[int64](s.c.mod, ptr) case byte(FLOAT): - dest[i] = util.ReadFloat64(s.c.mod, dataPtr) + dest[i] = util.ReadFloat64(s.c.mod, ptr) case byte(NULL): dest[i] = nil - default: - len := util.Read32[int32](s.c.mod, dataPtr+4) + case byte(TEXT): + len := util.Read32[int32](s.c.mod, ptr+4) if len != 0 { - ptr := util.Read32[ptr_t](s.c.mod, dataPtr) + ptr := util.Read32[ptr_t](s.c.mod, ptr) buf := util.View(s.c.mod, ptr, int64(len)) - if types[i] == byte(TEXT) { - dest[i] = string(buf) - } else { - dest[i] = buf - } + dest[i] = string(buf) } else { - if types[i] == byte(TEXT) { - dest[i] = "" - } else { - dest[i] = []byte{} - } + dest[i] = "" + } + case byte(BLOB): + len := util.Read32[int32](s.c.mod, ptr+4) + if len != 0 { + ptr := util.Read32[ptr_t](s.c.mod, ptr) + buf := util.View(s.c.mod, ptr, int64(len)) + tmp, _ := dest[i].([]byte) + dest[i] = append(tmp[:0], buf...) + } else { + dest[i], _ = dest[i].([]byte) } } - dataPtr += 8 + ptr += 8 } return nil } + +// ColumnsRaw populates result columns into the provided slice. +// The slice must have [Stmt.ColumnCount] length. +// +// [INTEGER] columns will be retrieved as int64 values, +// [FLOAT] as float64, [NULL] as nil, +// [TEXT] and [BLOB] as []byte. +// Any []byte are owned by SQLite and may be invalidated by +// subsequent calls to [Stmt] methods. +func (s *Stmt) ColumnsRaw(dest ...any) error { + types, ptr, err := s.columns(int64(len(dest))) + if err != nil { + return err + } + + // Avoid bounds checks on types below. + if len(types) != len(dest) { + panic(util.AssertErr()) + } + + for i := range dest { + switch types[i] { + case byte(INTEGER): + dest[i] = util.Read64[int64](s.c.mod, ptr) + case byte(FLOAT): + dest[i] = util.ReadFloat64(s.c.mod, ptr) + case byte(NULL): + dest[i] = nil + default: + len := util.Read32[int32](s.c.mod, ptr+4) + if len == 0 && types[i] == byte(BLOB) { + dest[i] = []byte{} + } else { + cap := len + if types[i] == byte(TEXT) { + cap++ + } + ptr := util.Read32[ptr_t](s.c.mod, ptr) + buf := util.View(s.c.mod, ptr, int64(cap))[:len] + dest[i] = buf + } + } + ptr += 8 + } + return nil +} + +func (s *Stmt) columns(count int64) ([]byte, ptr_t, error) { + defer s.c.arena.mark()() + typePtr := s.c.arena.new(count) + dataPtr := s.c.arena.new(count * 8) + + rc := res_t(s.c.call("sqlite3_columns_go", + stk_t(s.handle), stk_t(count), stk_t(typePtr), stk_t(dataPtr))) + if rc == res_t(MISUSE) { + return nil, 0, MISUSE + } + if err := s.c.error(rc); err != nil { + return nil, 0, err + } + + return util.View(s.c.mod, typePtr, count), dataPtr, nil +} diff --git a/tests/stmt_test.go b/tests/stmt_test.go index 0a3fe45..d057405 100644 --- a/tests/stmt_test.go +++ b/tests/stmt_test.go @@ -664,6 +664,42 @@ func TestStmt_ColumnValue(t *testing.T) { } } +func TestStmt_Columns(t *testing.T) { + db, err := sqlite3.Open(":memory:") + if err != nil { + t.Fatal(err) + } + defer db.Close() + + stmt, _, err := db.Prepare(`SELECT 0, 0.5, 'abc', x'cafe', NULL`) + if err != nil { + t.Fatal(err) + } + defer stmt.Close() + + if stmt.Step() { + var dest [5]any + if err := stmt.Columns(dest[:]...); err != nil { + t.Fatal(err) + } + if got := dest[0]; got != int64(0) { + t.Errorf("got %d, want 0", got) + } + if got := dest[1]; got != float64(0.5) { + t.Errorf("got %f, want 0.5", got) + } + if got := dest[2]; got != "abc" { + t.Errorf("got %q, want 'abc'", got) + } + if got := dest[3]; string(got.([]byte)) != "\xCA\xFE" { + t.Errorf("got %q, want x'cafe'", got) + } + if got := dest[4]; got != nil { + t.Errorf("got %q, want nil", got) + } + } +} + func TestStmt_Error(t *testing.T) { if testing.Short() { t.Skip("skipping in short mode") diff --git a/txn.go b/txn.go index 7a5e112..931b899 100644 --- a/txn.go +++ b/txn.go @@ -20,7 +20,7 @@ type Txn struct { } // Begin starts a deferred transaction. -// Panics if a transaction is already in-progress. +// It panics if a transaction is in-progress. // For nested transactions, use [Conn.Savepoint]. // // https://sqlite.org/lang_transaction.html diff --git a/value.go b/value.go index a2399fb..6753027 100644 --- a/value.go +++ b/value.go @@ -139,7 +139,7 @@ func (v Value) Blob(buf []byte) []byte { // https://sqlite.org/c3ref/value_blob.html func (v Value) RawText() []byte { ptr := ptr_t(v.c.call("sqlite3_value_text", v.protected())) - return v.rawBytes(ptr) + return v.rawBytes(ptr, 1) } // RawBlob returns the value as a []byte. @@ -149,16 +149,16 @@ func (v Value) RawText() []byte { // https://sqlite.org/c3ref/value_blob.html func (v Value) RawBlob() []byte { ptr := ptr_t(v.c.call("sqlite3_value_blob", v.protected())) - return v.rawBytes(ptr) + return v.rawBytes(ptr, 0) } -func (v Value) rawBytes(ptr ptr_t) []byte { +func (v Value) rawBytes(ptr ptr_t, nul int32) []byte { if ptr == 0 { return nil } n := int32(v.c.call("sqlite3_value_bytes", v.protected())) - return util.View(v.c.mod, ptr, int64(n)) + return util.View(v.c.mod, ptr, int64(n+nul))[:n] } // Pointer gets the pointer associated with this value,