Avoid escaping times (#256)

This commit is contained in:
Nuno Cruces
2025-03-31 13:02:41 +01:00
committed by GitHub
parent 41dc46af7e
commit 1f5d8bf7df
8 changed files with 247 additions and 116 deletions

View File

@@ -20,22 +20,45 @@
// - a [serializable] transaction is always "immediate"; // - a [serializable] transaction is always "immediate";
// - a [read-only] transaction is always "deferred". // - a [read-only] transaction is always "deferred".
// //
// # Datatypes In SQLite
//
// SQLite is dynamically typed.
// Columns can mostly hold any value regardless of their declared type.
// SQLite supports most [driver.Value] types out of the box,
// but bool and [time.Time] require special care.
//
// Booleans can be stored on any column type and scanned back to a *bool.
// However, if scanned to a *any, booleans may become an
// int64, string or bool, depending on the declared type of the column.
// If you use BOOLEAN for your column type,
// 1 and 0 will always scan as true and false.
//
// # Working with time // # Working with time
// //
// Time values can similarly be stored on any column type.
// The time encoding/decoding format can be specified using "_timefmt": // The time encoding/decoding format can be specified using "_timefmt":
// //
// sql.Open("sqlite3", "file:demo.db?_timefmt=sqlite") // sql.Open("sqlite3", "file:demo.db?_timefmt=sqlite")
// //
// Possible values are: "auto" (the default), "sqlite", "rfc3339"; // Special values are: "auto" (the default), "sqlite", "rfc3339";
// - "auto" encodes as RFC 3339 and decodes any [format] supported by SQLite; // - "auto" encodes as RFC 3339 and decodes any [format] supported by SQLite;
// - "sqlite" encodes as SQLite and decodes any [format] supported by SQLite; // - "sqlite" encodes as SQLite and decodes any [format] supported by SQLite;
// - "rfc3339" encodes and decodes RFC 3339 only. // - "rfc3339" encodes and decodes RFC 3339 only.
// //
// If you encode as RFC 3339 (the default), // You can also set "_timefmt" to an arbitrary [sqlite3.TimeFormat] or [time.Layout].
// consider using the TIME [collating sequence] to produce a time-ordered sequence.
// //
// To scan values in other formats, [sqlite3.TimeFormat.Scanner] may be helpful. // If you encode as RFC 3339 (the default),
// To bind values in other formats, [sqlite3.TimeFormat.Encode] them before binding. // consider using the TIME [collating sequence] to produce time-ordered sequences.
//
// If you encode as RFC 3339 (the default),
// time values will scan back to a *time.Time unless your column type is TEXT.
// Otherwise, if scanned to a *any, time values may become an
// int64, float64 or string, depending on the time format and declared type of the column.
// If you use DATE, TIME, DATETIME, or TIMESTAMP for your column type,
// "_timefmt" will be used to decode values.
//
// To scan values in custom formats, [sqlite3.TimeFormat.Scanner] may be helpful.
// To bind values in custom formats, [sqlite3.TimeFormat.Encode] them before binding.
// //
// When using a custom time struct, you'll have to implement // When using a custom time struct, you'll have to implement
// [database/sql/driver.Valuer] and [database/sql.Scanner]. // [database/sql/driver.Valuer] and [database/sql.Scanner].
@@ -48,7 +71,7 @@
// The Scan method needs to take into account that the value it receives can be of differing types. // The Scan method needs to take into account that the value it receives can be of differing types.
// It can already be a [time.Time], if the driver decoded the value according to "_timefmt" rules. // It can already be a [time.Time], if the driver decoded the value according to "_timefmt" rules.
// Or it can be a: string, int64, float64, []byte, or nil, // Or it can be a: string, int64, float64, []byte, or nil,
// depending on the column type and what whoever wrote the value. // depending on the column type and whoever wrote the value.
// [sqlite3.TimeFormat.Decode] may help. // [sqlite3.TimeFormat.Decode] may help.
// //
// # Setting PRAGMAs // # Setting PRAGMAs
@@ -595,6 +618,28 @@ const (
_TIME _TIME
) )
// scanFromDecl maps a declared column type name to a scantype hint.
// The comparison is exact, so callers are expected to pass an
// upper-cased name; anything unrecognized maps to _ANY.
func scanFromDecl(decl string) scantype {
	// Declared types only matter before any row is available,
	// and afterwards they serve as type hints.
	// INT, INTEGER, REAL, TEXT and BLOB keep STRICT tables strictly typed;
	// BOOLEAN and the date/time names are hints for booleans and time.
	scan := scantype(_ANY)
	switch decl {
	case "INT", "INTEGER":
		scan = _INT
	case "REAL":
		scan = _REAL
	case "TEXT":
		scan = _TEXT
	case "BLOB":
		scan = _BLOB
	case "BOOLEAN":
		scan = _BOOL
	case "DATE", "TIME", "DATETIME", "TIMESTAMP":
		scan = _TIME
	}
	return scan
}
var ( var (
// Ensure these interfaces are implemented: // Ensure these interfaces are implemented:
_ driver.RowsColumnTypeDatabaseTypeName = &rows{} _ driver.RowsColumnTypeDatabaseTypeName = &rows{}
@@ -619,6 +664,18 @@ func (r *rows) Columns() []string {
return r.names return r.names
} }
// scanType returns the scantype hint for the column at index.
// Hints for all columns are derived from the upper-cased declared types
// on first use, and cached in r.scans for subsequent calls.
func (r *rows) scanType(index int) scantype {
	if r.scans != nil {
		return r.scans[index]
	}

	hints := make([]scantype, r.Stmt.ColumnCount())
	for col := range hints {
		decl := r.Stmt.ColumnDeclType(col)
		hints[col] = scanFromDecl(strings.ToUpper(decl))
	}
	r.scans = hints
	return hints[index]
}
func (r *rows) loadColumnMetadata() { func (r *rows) loadColumnMetadata() {
if r.nulls == nil { if r.nulls == nil {
count := r.Stmt.ColumnCount() count := r.Stmt.ColumnCount()
@@ -632,24 +689,7 @@ func (r *rows) loadColumnMetadata() {
r.Stmt.ColumnTableName(i), r.Stmt.ColumnTableName(i),
col) col)
types[i] = strings.ToUpper(types[i]) types[i] = strings.ToUpper(types[i])
// These types are only used before we have rows, scans[i] = scanFromDecl(types[i])
// and otherwise as type hints.
// The first few ensure STRICT tables are strictly typed.
// The other two are type hints for booleans and time.
switch types[i] {
case "INT", "INTEGER":
scans[i] = _INT
case "REAL":
scans[i] = _REAL
case "TEXT":
scans[i] = _TEXT
case "BLOB":
scans[i] = _BLOB
case "BOOLEAN":
scans[i] = _BOOL
case "DATE", "TIME", "DATETIME", "TIMESTAMP":
scans[i] = _TIME
}
} }
} }
r.nulls = nulls r.nulls = nulls
@@ -658,27 +698,15 @@ func (r *rows) loadColumnMetadata() {
} }
} }
// declType returns the upper-cased declared type of the column at index.
// The declared types of all columns are read from the statement on first use,
// and cached in r.types for subsequent calls.
func (r *rows) declType(index int) string {
	if r.types != nil {
		return r.types[index]
	}

	decls := make([]string, r.Stmt.ColumnCount())
	for col := range decls {
		decls[col] = strings.ToUpper(r.Stmt.ColumnDeclType(col))
	}
	r.types = decls
	return decls[index]
}
func (r *rows) ColumnTypeDatabaseTypeName(index int) string { func (r *rows) ColumnTypeDatabaseTypeName(index int) string {
r.loadColumnMetadata() r.loadColumnMetadata()
decltype := r.types[index] decl := r.types[index]
if len := len(decltype); len > 0 && decltype[len-1] == ')' { if len := len(decl); len > 0 && decl[len-1] == ')' {
if i := strings.LastIndexByte(decltype, '('); i >= 0 { if i := strings.LastIndexByte(decl, '('); i >= 0 {
decltype = decltype[:i] decl = decl[:i]
} }
} }
return strings.TrimSpace(decltype) return strings.TrimSpace(decl)
} }
func (r *rows) ColumnTypeNullable(index int) (nullable, ok bool) { func (r *rows) ColumnTypeNullable(index int) (nullable, ok bool) {
@@ -745,36 +773,49 @@ func (r *rows) Next(dest []driver.Value) error {
} }
data := unsafe.Slice((*any)(unsafe.SliceData(dest)), len(dest)) data := unsafe.Slice((*any)(unsafe.SliceData(dest)), len(dest))
err := r.Stmt.Columns(data...) if err := r.Stmt.ColumnsRaw(data...); err != nil {
return err
}
for i := range dest { for i := range dest {
if t, ok := r.decodeTime(i, dest[i]); ok { scan := r.scanType(i)
dest[i] = t switch v := dest[i].(type) {
} case int64:
} if scan == _BOOL {
return err switch v {
} case 1:
dest[i] = true
func (r *rows) decodeTime(i int, v any) (_ time.Time, ok bool) { case 0:
switch v := v.(type) { dest[i] = false
case int64, float64: }
// could be a time value continue
case string: }
if r.tmWrite != "" && r.tmWrite != time.RFC3339 && r.tmWrite != time.RFC3339Nano { case []byte:
if len(v) == cap(v) { // a BLOB
continue
}
if scan != _TEXT {
switch r.tmWrite {
case "", time.RFC3339, time.RFC3339Nano:
t, ok := maybeTime(v)
if ok {
dest[i] = t
continue
}
}
}
dest[i] = string(v)
case float64:
break break
default:
continue
} }
t, ok := maybeTime(v) if scan == _TIME {
if ok { t, err := r.tmRead.Decode(dest[i])
return t, true if err == nil {
dest[i] = t
continue
}
} }
default:
return
} }
switch r.declType(i) { return nil
case "DATE", "TIME", "DATETIME", "TIMESTAMP":
// could be a time value
default:
return
}
t, err := r.tmRead.Decode(v)
return t, err == nil
} }

View File

@@ -1,9 +1,5 @@
//go:build linux || darwin || windows || freebsd || openbsd || netbsd || dragonfly || illumos || sqlite3_flock || sqlite3_dotlk
package driver_test package driver_test
// Adapted from: https://go.dev/doc/tutorial/database-access
import ( import (
"database/sql" "database/sql"
"database/sql/driver" "database/sql/driver"
@@ -27,7 +23,7 @@ func Example_customTime() {
_, err = db.Exec(` _, err = db.Exec(`
CREATE TABLE data ( CREATE TABLE data (
id INTEGER PRIMARY KEY, id INTEGER PRIMARY KEY,
date_time TEXT date_time ANY
) STRICT; ) STRICT;
`) `)
if err != nil { if err != nil {

View File

@@ -1,12 +1,15 @@
package driver package driver
import "time" import (
"bytes"
"time"
)
// Convert a string in [time.RFC3339Nano] format into a [time.Time] // Convert a string in [time.RFC3339Nano] format into a [time.Time]
// if it roundtrips back to the same string. // if it roundtrips back to the same string.
// This way times can be persisted to, and recovered from, the database, // This way times can be persisted to, and recovered from, the database,
// but if a string is needed, [database/sql] will recover the same string. // but if a string is needed, [database/sql] will recover the same string.
func maybeTime(text string) (_ time.Time, _ bool) { func maybeTime(text []byte) (_ time.Time, _ bool) {
// Weed out (some) values that can't possibly be // Weed out (some) values that can't possibly be
// [time.RFC3339Nano] timestamps. // [time.RFC3339Nano] timestamps.
if len(text) < len("2006-01-02T15:04:05Z") { if len(text) < len("2006-01-02T15:04:05Z") {
@@ -21,8 +24,8 @@ func maybeTime(text string) (_ time.Time, _ bool) {
// Slow path. // Slow path.
var buf [len(time.RFC3339Nano)]byte var buf [len(time.RFC3339Nano)]byte
date, err := time.Parse(time.RFC3339Nano, text) date, err := time.Parse(time.RFC3339Nano, string(text))
if err == nil && text == string(date.AppendFormat(buf[:0], time.RFC3339Nano)) { if err == nil && bytes.Equal(text, date.AppendFormat(buf[:0], time.RFC3339Nano)) {
return date, true return date, true
} }
return return

View File

@@ -22,7 +22,7 @@ func Fuzz_stringOrTime_1(f *testing.F) {
f.Add("Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.") f.Add("Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.")
f.Fuzz(func(t *testing.T, str string) { f.Fuzz(func(t *testing.T, str string) {
v, ok := maybeTime(str) v, ok := maybeTime([]byte(str))
if ok { if ok {
// Make sure times round-trip to the same string: // Make sure times round-trip to the same string:
// https://pkg.go.dev/database/sql#Rows.Scan // https://pkg.go.dev/database/sql#Rows.Scan
@@ -51,7 +51,7 @@ func Fuzz_stringOrTime_2(f *testing.F) {
f.Add(int64(-763421161058), int64(222_222_222)) // twosday, year 22222BC f.Add(int64(-763421161058), int64(222_222_222)) // twosday, year 22222BC
checkTime := func(t testing.TB, date time.Time) { checkTime := func(t testing.TB, date time.Time) {
v, ok := maybeTime(date.Format(time.RFC3339Nano)) v, ok := maybeTime(date.AppendFormat(nil, time.RFC3339Nano))
if ok { if ok {
// Make sure times round-trip to the same time: // Make sure times round-trip to the same time:
if !v.Equal(date) { if !v.Equal(date) {

119
stmt.go
View File

@@ -571,7 +571,7 @@ func (s *Stmt) ColumnBlob(col int, buf []byte) []byte {
func (s *Stmt) ColumnRawText(col int) []byte { func (s *Stmt) ColumnRawText(col int) []byte {
ptr := ptr_t(s.c.call("sqlite3_column_text", ptr := ptr_t(s.c.call("sqlite3_column_text",
stk_t(s.handle), stk_t(col))) stk_t(s.handle), stk_t(col)))
return s.columnRawBytes(col, ptr) return s.columnRawBytes(col, ptr, 1)
} }
// ColumnRawBlob returns the value of the result column as a []byte. // ColumnRawBlob returns the value of the result column as a []byte.
@@ -583,10 +583,10 @@ func (s *Stmt) ColumnRawText(col int) []byte {
func (s *Stmt) ColumnRawBlob(col int) []byte { func (s *Stmt) ColumnRawBlob(col int) []byte {
ptr := ptr_t(s.c.call("sqlite3_column_blob", ptr := ptr_t(s.c.call("sqlite3_column_blob",
stk_t(s.handle), stk_t(col))) stk_t(s.handle), stk_t(col)))
return s.columnRawBytes(col, ptr) return s.columnRawBytes(col, ptr, 0)
} }
func (s *Stmt) columnRawBytes(col int, ptr ptr_t) []byte { func (s *Stmt) columnRawBytes(col int, ptr ptr_t, nul int32) []byte {
if ptr == 0 { if ptr == 0 {
rc := res_t(s.c.call("sqlite3_errcode", stk_t(s.c.handle))) rc := res_t(s.c.call("sqlite3_errcode", stk_t(s.c.handle)))
if rc != _ROW && rc != _DONE { if rc != _ROW && rc != _DONE {
@@ -597,7 +597,7 @@ func (s *Stmt) columnRawBytes(col int, ptr ptr_t) []byte {
n := int32(s.c.call("sqlite3_column_bytes", n := int32(s.c.call("sqlite3_column_bytes",
stk_t(s.handle), stk_t(col))) stk_t(s.handle), stk_t(col)))
return util.View(s.c.mod, ptr, int64(n)) return util.View(s.c.mod, ptr, int64(n+nul))[:n]
} }
// ColumnJSON parses the JSON-encoded value of the result column // ColumnJSON parses the JSON-encoded value of the result column
@@ -644,22 +644,12 @@ func (s *Stmt) ColumnValue(col int) Value {
// [INTEGER] columns will be retrieved as int64 values, // [INTEGER] columns will be retrieved as int64 values,
// [FLOAT] as float64, [NULL] as nil, // [FLOAT] as float64, [NULL] as nil,
// [TEXT] as string, and [BLOB] as []byte. // [TEXT] as string, and [BLOB] as []byte.
// Any []byte are owned by SQLite and may be invalidated by
// subsequent calls to [Stmt] methods.
func (s *Stmt) Columns(dest ...any) error { func (s *Stmt) Columns(dest ...any) error {
defer s.c.arena.mark()() types, ptr, err := s.columns(int64(len(dest)))
count := int64(len(dest)) if err != nil {
typePtr := s.c.arena.new(count)
dataPtr := s.c.arena.new(count * 8)
rc := res_t(s.c.call("sqlite3_columns_go",
stk_t(s.handle), stk_t(count), stk_t(typePtr), stk_t(dataPtr)))
if err := s.c.error(rc); err != nil {
return err return err
} }
types := util.View(s.c.mod, typePtr, count)
// Avoid bounds checks on types below. // Avoid bounds checks on types below.
if len(types) != len(dest) { if len(types) != len(dest) {
panic(util.AssertErr()) panic(util.AssertErr())
@@ -668,30 +658,95 @@ func (s *Stmt) Columns(dest ...any) error {
for i := range dest { for i := range dest {
switch types[i] { switch types[i] {
case byte(INTEGER): case byte(INTEGER):
dest[i] = util.Read64[int64](s.c.mod, dataPtr) dest[i] = util.Read64[int64](s.c.mod, ptr)
case byte(FLOAT): case byte(FLOAT):
dest[i] = util.ReadFloat64(s.c.mod, dataPtr) dest[i] = util.ReadFloat64(s.c.mod, ptr)
case byte(NULL): case byte(NULL):
dest[i] = nil dest[i] = nil
default: case byte(TEXT):
len := util.Read32[int32](s.c.mod, dataPtr+4) len := util.Read32[int32](s.c.mod, ptr+4)
if len != 0 { if len != 0 {
ptr := util.Read32[ptr_t](s.c.mod, dataPtr) ptr := util.Read32[ptr_t](s.c.mod, ptr)
buf := util.View(s.c.mod, ptr, int64(len)) buf := util.View(s.c.mod, ptr, int64(len))
if types[i] == byte(TEXT) { dest[i] = string(buf)
dest[i] = string(buf)
} else {
dest[i] = buf
}
} else { } else {
if types[i] == byte(TEXT) { dest[i] = ""
dest[i] = "" }
} else { case byte(BLOB):
dest[i] = []byte{} len := util.Read32[int32](s.c.mod, ptr+4)
} if len != 0 {
ptr := util.Read32[ptr_t](s.c.mod, ptr)
buf := util.View(s.c.mod, ptr, int64(len))
tmp, _ := dest[i].([]byte)
dest[i] = append(tmp[:0], buf...)
} else {
dest[i], _ = dest[i].([]byte)
} }
} }
dataPtr += 8 ptr += 8
} }
return nil return nil
} }
// ColumnsRaw populates result columns into the provided slice.
// The slice must have [Stmt.ColumnCount] length.
//
// [INTEGER] columns will be retrieved as int64 values,
// [FLOAT] as float64, [NULL] as nil,
// [TEXT] and [BLOB] as []byte.
// Any []byte are owned by SQLite and may be invalidated by
// subsequent calls to [Stmt] methods.
func (s *Stmt) ColumnsRaw(dest ...any) error {
	types, ptr, err := s.columns(int64(len(dest)))
	if err != nil {
		return err
	}

	// Avoid bounds checks on types below.
	if len(types) != len(dest) {
		panic(util.AssertErr())
	}

	// Each column occupies an 8 byte slot starting at ptr.
	for i := range dest {
		typ := types[i]
		switch typ {
		case byte(INTEGER):
			dest[i] = util.Read64[int64](s.c.mod, ptr)
		case byte(FLOAT):
			dest[i] = util.ReadFloat64(s.c.mod, ptr)
		case byte(NULL):
			dest[i] = nil
		default:
			// The slot holds an address at offset 0 and a length at offset 4.
			n := util.Read32[int32](s.c.mod, ptr+4)
			if typ != byte(BLOB) || n != 0 {
				// TEXT is viewed with one extra byte and then resliced to n,
				// leaving spare capacity (presumably the NUL terminator — confirm).
				size := n
				if typ == byte(TEXT) {
					size++
				}
				addr := util.Read32[ptr_t](s.c.mod, ptr)
				dest[i] = util.View(s.c.mod, addr, int64(size))[:n]
			} else {
				// Empty BLOB: the address is never read.
				dest[i] = []byte{}
			}
		}
		ptr += 8
	}
	return nil
}
// columns calls sqlite3_columns_go for count result columns.
// It returns a view of count type bytes (one per column),
// and the address of the data area (8 bytes per column).
// A MISUSE result is returned directly as the error;
// other results are converted by s.c.error.
//
// NOTE(review): both buffers come from the connection arena, whose mark is
// released when this function returns, yet callers read them afterwards;
// presumably safe while nothing else uses the arena in between — confirm.
func (s *Stmt) columns(count int64) ([]byte, ptr_t, error) {
	defer s.c.arena.mark()()
	kinds := s.c.arena.new(count)
	values := s.c.arena.new(count * 8)

	res := s.c.call("sqlite3_columns_go",
		stk_t(s.handle), stk_t(count), stk_t(kinds), stk_t(values))
	rc := res_t(res)
	if rc == res_t(MISUSE) {
		return nil, 0, MISUSE
	}
	if err := s.c.error(rc); err != nil {
		return nil, 0, err
	}

	types := util.View(s.c.mod, kinds, count)
	return types, values, nil
}

View File

@@ -664,6 +664,42 @@ func TestStmt_ColumnValue(t *testing.T) {
} }
} }
// TestStmt_Columns checks that Stmt.Columns retrieves an integer, a float,
// a text, a blob and a NULL column as int64, float64, string, []byte and nil.
func TestStmt_Columns(t *testing.T) {
	db, err := sqlite3.Open(":memory:")
	if err != nil {
		t.Fatal(err)
	}
	defer db.Close()

	stmt, _, err := db.Prepare(`SELECT 0, 0.5, 'abc', x'cafe', NULL`)
	if err != nil {
		t.Fatal(err)
	}
	defer stmt.Close()

	// A SELECT of constants must produce a row;
	// fail instead of silently skipping every check below.
	if !stmt.Step() {
		t.Fatal("expected a row")
	}

	var dest [5]any
	if err := stmt.Columns(dest[:]...); err != nil {
		t.Fatal(err)
	}

	if got := dest[0]; got != int64(0) {
		t.Errorf("got %d, want 0", got)
	}
	if got := dest[1]; got != float64(0.5) {
		t.Errorf("got %f, want 0.5", got)
	}
	if got := dest[2]; got != "abc" {
		t.Errorf("got %q, want 'abc'", got)
	}
	// Checked assertion: a wrong dynamic type is reported, not a panic.
	if got, ok := dest[3].([]byte); !ok || string(got) != "\xCA\xFE" {
		t.Errorf("got %q, want x'cafe'", dest[3])
	}
	if got := dest[4]; got != nil {
		t.Errorf("got %q, want nil", got)
	}
}
func TestStmt_Error(t *testing.T) { func TestStmt_Error(t *testing.T) {
if testing.Short() { if testing.Short() {
t.Skip("skipping in short mode") t.Skip("skipping in short mode")

2
txn.go
View File

@@ -20,7 +20,7 @@ type Txn struct {
} }
// Begin starts a deferred transaction. // Begin starts a deferred transaction.
// Panics if a transaction is already in-progress. // It panics if a transaction is in-progress.
// For nested transactions, use [Conn.Savepoint]. // For nested transactions, use [Conn.Savepoint].
// //
// https://sqlite.org/lang_transaction.html // https://sqlite.org/lang_transaction.html

View File

@@ -139,7 +139,7 @@ func (v Value) Blob(buf []byte) []byte {
// https://sqlite.org/c3ref/value_blob.html // https://sqlite.org/c3ref/value_blob.html
func (v Value) RawText() []byte { func (v Value) RawText() []byte {
ptr := ptr_t(v.c.call("sqlite3_value_text", v.protected())) ptr := ptr_t(v.c.call("sqlite3_value_text", v.protected()))
return v.rawBytes(ptr) return v.rawBytes(ptr, 1)
} }
// RawBlob returns the value as a []byte. // RawBlob returns the value as a []byte.
@@ -149,16 +149,16 @@ func (v Value) RawText() []byte {
// https://sqlite.org/c3ref/value_blob.html // https://sqlite.org/c3ref/value_blob.html
func (v Value) RawBlob() []byte { func (v Value) RawBlob() []byte {
ptr := ptr_t(v.c.call("sqlite3_value_blob", v.protected())) ptr := ptr_t(v.c.call("sqlite3_value_blob", v.protected()))
return v.rawBytes(ptr) return v.rawBytes(ptr, 0)
} }
func (v Value) rawBytes(ptr ptr_t) []byte { func (v Value) rawBytes(ptr ptr_t, nul int32) []byte {
if ptr == 0 { if ptr == 0 {
return nil return nil
} }
n := int32(v.c.call("sqlite3_value_bytes", v.protected())) n := int32(v.c.call("sqlite3_value_bytes", v.protected()))
return util.View(v.c.mod, ptr, int64(n)) return util.View(v.c.mod, ptr, int64(n+nul))[:n]
} }
// Pointer gets the pointer associated with this value, // Pointer gets the pointer associated with this value,