mirror of
https://github.com/ncruces/go-sqlite3.git
synced 2026-01-11 21:49:13 +00:00
Avoid escaping times (#256)
This commit is contained in:
177
driver/driver.go
177
driver/driver.go
@@ -20,22 +20,45 @@
|
|||||||
// - a [serializable] transaction is always "immediate";
|
// - a [serializable] transaction is always "immediate";
|
||||||
// - a [read-only] transaction is always "deferred".
|
// - a [read-only] transaction is always "deferred".
|
||||||
//
|
//
|
||||||
|
// # Datatypes In SQLite
|
||||||
|
//
|
||||||
|
// SQLite is dynamically typed.
|
||||||
|
// Columns can mostly hold any value regardless of their declared type.
|
||||||
|
// SQLite supports most [driver.Value] types out of the box,
|
||||||
|
// but bool and [time.Time] require special care.
|
||||||
|
//
|
||||||
|
// Booleans can be stored on any column type and scanned back to a *bool.
|
||||||
|
// However, if scanned to a *any, booleans may either become an
|
||||||
|
// int64, string or bool, depending on the declared type of the column.
|
||||||
|
// If you use BOOLEAN for your column type,
|
||||||
|
// 1 and 0 will always scan as true and false.
|
||||||
|
//
|
||||||
// # Working with time
|
// # Working with time
|
||||||
//
|
//
|
||||||
|
// Time values can similarly be stored on any column type.
|
||||||
// The time encoding/decoding format can be specified using "_timefmt":
|
// The time encoding/decoding format can be specified using "_timefmt":
|
||||||
//
|
//
|
||||||
// sql.Open("sqlite3", "file:demo.db?_timefmt=sqlite")
|
// sql.Open("sqlite3", "file:demo.db?_timefmt=sqlite")
|
||||||
//
|
//
|
||||||
// Possible values are: "auto" (the default), "sqlite", "rfc3339";
|
// Special values are: "auto" (the default), "sqlite", "rfc3339";
|
||||||
// - "auto" encodes as RFC 3339 and decodes any [format] supported by SQLite;
|
// - "auto" encodes as RFC 3339 and decodes any [format] supported by SQLite;
|
||||||
// - "sqlite" encodes as SQLite and decodes any [format] supported by SQLite;
|
// - "sqlite" encodes as SQLite and decodes any [format] supported by SQLite;
|
||||||
// - "rfc3339" encodes and decodes RFC 3339 only.
|
// - "rfc3339" encodes and decodes RFC 3339 only.
|
||||||
//
|
//
|
||||||
// If you encode as RFC 3339 (the default),
|
// You can also set "_timefmt" to an arbitrary [sqlite3.TimeFormat] or [time.Layout].
|
||||||
// consider using the TIME [collating sequence] to produce a time-ordered sequence.
|
|
||||||
//
|
//
|
||||||
// To scan values in other formats, [sqlite3.TimeFormat.Scanner] may be helpful.
|
// If you encode as RFC 3339 (the default),
|
||||||
// To bind values in other formats, [sqlite3.TimeFormat.Encode] them before binding.
|
// consider using the TIME [collating sequence] to produce time-ordered sequences.
|
||||||
|
//
|
||||||
|
// If you encode as RFC 3339 (the default),
|
||||||
|
// time values will scan back to a *time.Time unless your column type is TEXT.
|
||||||
|
// Otherwise, if scanned to a *any, time values may either become an
|
||||||
|
// int64, float64 or string, depending on the time format and declared type of the column.
|
||||||
|
// If you use DATE, TIME, DATETIME, or TIMESTAMP for your column type,
|
||||||
|
// "_timefmt" will be used to decode values.
|
||||||
|
//
|
||||||
|
// To scan values in custom formats, [sqlite3.TimeFormat.Scanner] may be helpful.
|
||||||
|
// To bind values in custom formats, [sqlite3.TimeFormat.Encode] them before binding.
|
||||||
//
|
//
|
||||||
// When using a custom time struct, you'll have to implement
|
// When using a custom time struct, you'll have to implement
|
||||||
// [database/sql/driver.Valuer] and [database/sql.Scanner].
|
// [database/sql/driver.Valuer] and [database/sql.Scanner].
|
||||||
@@ -48,7 +71,7 @@
|
|||||||
// The Scan method needs to take into account that the value it receives can be of differing types.
|
// The Scan method needs to take into account that the value it receives can be of differing types.
|
||||||
// It can already be a [time.Time], if the driver decoded the value according to "_timefmt" rules.
|
// It can already be a [time.Time], if the driver decoded the value according to "_timefmt" rules.
|
||||||
// Or it can be a: string, int64, float64, []byte, or nil,
|
// Or it can be a: string, int64, float64, []byte, or nil,
|
||||||
// depending on the column type and what whoever wrote the value.
|
// depending on the column type and whoever wrote the value.
|
||||||
// [sqlite3.TimeFormat.Decode] may help.
|
// [sqlite3.TimeFormat.Decode] may help.
|
||||||
//
|
//
|
||||||
// # Setting PRAGMAs
|
// # Setting PRAGMAs
|
||||||
@@ -595,6 +618,28 @@ const (
|
|||||||
_TIME
|
_TIME
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func scanFromDecl(decl string) scantype {
|
||||||
|
// These types are only used before we have rows,
|
||||||
|
// and otherwise as type hints.
|
||||||
|
// The first few ensure STRICT tables are strictly typed.
|
||||||
|
// The other two are type hints for booleans and time.
|
||||||
|
switch decl {
|
||||||
|
case "INT", "INTEGER":
|
||||||
|
return _INT
|
||||||
|
case "REAL":
|
||||||
|
return _REAL
|
||||||
|
case "TEXT":
|
||||||
|
return _TEXT
|
||||||
|
case "BLOB":
|
||||||
|
return _BLOB
|
||||||
|
case "BOOLEAN":
|
||||||
|
return _BOOL
|
||||||
|
case "DATE", "TIME", "DATETIME", "TIMESTAMP":
|
||||||
|
return _TIME
|
||||||
|
}
|
||||||
|
return _ANY
|
||||||
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
// Ensure these interfaces are implemented:
|
// Ensure these interfaces are implemented:
|
||||||
_ driver.RowsColumnTypeDatabaseTypeName = &rows{}
|
_ driver.RowsColumnTypeDatabaseTypeName = &rows{}
|
||||||
@@ -619,6 +664,18 @@ func (r *rows) Columns() []string {
|
|||||||
return r.names
|
return r.names
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *rows) scanType(index int) scantype {
|
||||||
|
if r.scans == nil {
|
||||||
|
count := r.Stmt.ColumnCount()
|
||||||
|
scans := make([]scantype, count)
|
||||||
|
for i := range scans {
|
||||||
|
scans[i] = scanFromDecl(strings.ToUpper(r.Stmt.ColumnDeclType(i)))
|
||||||
|
}
|
||||||
|
r.scans = scans
|
||||||
|
}
|
||||||
|
return r.scans[index]
|
||||||
|
}
|
||||||
|
|
||||||
func (r *rows) loadColumnMetadata() {
|
func (r *rows) loadColumnMetadata() {
|
||||||
if r.nulls == nil {
|
if r.nulls == nil {
|
||||||
count := r.Stmt.ColumnCount()
|
count := r.Stmt.ColumnCount()
|
||||||
@@ -632,24 +689,7 @@ func (r *rows) loadColumnMetadata() {
|
|||||||
r.Stmt.ColumnTableName(i),
|
r.Stmt.ColumnTableName(i),
|
||||||
col)
|
col)
|
||||||
types[i] = strings.ToUpper(types[i])
|
types[i] = strings.ToUpper(types[i])
|
||||||
// These types are only used before we have rows,
|
scans[i] = scanFromDecl(types[i])
|
||||||
// and otherwise as type hints.
|
|
||||||
// The first few ensure STRICT tables are strictly typed.
|
|
||||||
// The other two are type hints for booleans and time.
|
|
||||||
switch types[i] {
|
|
||||||
case "INT", "INTEGER":
|
|
||||||
scans[i] = _INT
|
|
||||||
case "REAL":
|
|
||||||
scans[i] = _REAL
|
|
||||||
case "TEXT":
|
|
||||||
scans[i] = _TEXT
|
|
||||||
case "BLOB":
|
|
||||||
scans[i] = _BLOB
|
|
||||||
case "BOOLEAN":
|
|
||||||
scans[i] = _BOOL
|
|
||||||
case "DATE", "TIME", "DATETIME", "TIMESTAMP":
|
|
||||||
scans[i] = _TIME
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
r.nulls = nulls
|
r.nulls = nulls
|
||||||
@@ -658,27 +698,15 @@ func (r *rows) loadColumnMetadata() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *rows) declType(index int) string {
|
|
||||||
if r.types == nil {
|
|
||||||
count := r.Stmt.ColumnCount()
|
|
||||||
types := make([]string, count)
|
|
||||||
for i := range types {
|
|
||||||
types[i] = strings.ToUpper(r.Stmt.ColumnDeclType(i))
|
|
||||||
}
|
|
||||||
r.types = types
|
|
||||||
}
|
|
||||||
return r.types[index]
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *rows) ColumnTypeDatabaseTypeName(index int) string {
|
func (r *rows) ColumnTypeDatabaseTypeName(index int) string {
|
||||||
r.loadColumnMetadata()
|
r.loadColumnMetadata()
|
||||||
decltype := r.types[index]
|
decl := r.types[index]
|
||||||
if len := len(decltype); len > 0 && decltype[len-1] == ')' {
|
if len := len(decl); len > 0 && decl[len-1] == ')' {
|
||||||
if i := strings.LastIndexByte(decltype, '('); i >= 0 {
|
if i := strings.LastIndexByte(decl, '('); i >= 0 {
|
||||||
decltype = decltype[:i]
|
decl = decl[:i]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return strings.TrimSpace(decltype)
|
return strings.TrimSpace(decl)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *rows) ColumnTypeNullable(index int) (nullable, ok bool) {
|
func (r *rows) ColumnTypeNullable(index int) (nullable, ok bool) {
|
||||||
@@ -745,36 +773,49 @@ func (r *rows) Next(dest []driver.Value) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
data := unsafe.Slice((*any)(unsafe.SliceData(dest)), len(dest))
|
data := unsafe.Slice((*any)(unsafe.SliceData(dest)), len(dest))
|
||||||
err := r.Stmt.Columns(data...)
|
if err := r.Stmt.ColumnsRaw(data...); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
for i := range dest {
|
for i := range dest {
|
||||||
if t, ok := r.decodeTime(i, dest[i]); ok {
|
scan := r.scanType(i)
|
||||||
dest[i] = t
|
switch v := dest[i].(type) {
|
||||||
}
|
case int64:
|
||||||
}
|
if scan == _BOOL {
|
||||||
return err
|
switch v {
|
||||||
}
|
case 1:
|
||||||
|
dest[i] = true
|
||||||
func (r *rows) decodeTime(i int, v any) (_ time.Time, ok bool) {
|
case 0:
|
||||||
switch v := v.(type) {
|
dest[i] = false
|
||||||
case int64, float64:
|
}
|
||||||
// could be a time value
|
continue
|
||||||
case string:
|
}
|
||||||
if r.tmWrite != "" && r.tmWrite != time.RFC3339 && r.tmWrite != time.RFC3339Nano {
|
case []byte:
|
||||||
|
if len(v) == cap(v) { // a BLOB
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if scan != _TEXT {
|
||||||
|
switch r.tmWrite {
|
||||||
|
case "", time.RFC3339, time.RFC3339Nano:
|
||||||
|
t, ok := maybeTime(v)
|
||||||
|
if ok {
|
||||||
|
dest[i] = t
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
dest[i] = string(v)
|
||||||
|
case float64:
|
||||||
break
|
break
|
||||||
|
default:
|
||||||
|
continue
|
||||||
}
|
}
|
||||||
t, ok := maybeTime(v)
|
if scan == _TIME {
|
||||||
if ok {
|
t, err := r.tmRead.Decode(dest[i])
|
||||||
return t, true
|
if err == nil {
|
||||||
|
dest[i] = t
|
||||||
|
continue
|
||||||
|
}
|
||||||
}
|
}
|
||||||
default:
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
switch r.declType(i) {
|
return nil
|
||||||
case "DATE", "TIME", "DATETIME", "TIMESTAMP":
|
|
||||||
// could be a time value
|
|
||||||
default:
|
|
||||||
return
|
|
||||||
}
|
|
||||||
t, err := r.tmRead.Decode(v)
|
|
||||||
return t, err == nil
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,9 +1,5 @@
|
|||||||
//go:build linux || darwin || windows || freebsd || openbsd || netbsd || dragonfly || illumos || sqlite3_flock || sqlite3_dotlk
|
|
||||||
|
|
||||||
package driver_test
|
package driver_test
|
||||||
|
|
||||||
// Adapted from: https://go.dev/doc/tutorial/database-access
|
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"database/sql/driver"
|
"database/sql/driver"
|
||||||
@@ -27,7 +23,7 @@ func Example_customTime() {
|
|||||||
_, err = db.Exec(`
|
_, err = db.Exec(`
|
||||||
CREATE TABLE data (
|
CREATE TABLE data (
|
||||||
id INTEGER PRIMARY KEY,
|
id INTEGER PRIMARY KEY,
|
||||||
date_time TEXT
|
date_time ANY
|
||||||
) STRICT;
|
) STRICT;
|
||||||
`)
|
`)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -1,12 +1,15 @@
|
|||||||
package driver
|
package driver
|
||||||
|
|
||||||
import "time"
|
import (
|
||||||
|
"bytes"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
// Convert a string in [time.RFC3339Nano] format into a [time.Time]
|
// Convert a string in [time.RFC3339Nano] format into a [time.Time]
|
||||||
// if it roundtrips back to the same string.
|
// if it roundtrips back to the same string.
|
||||||
// This way times can be persisted to, and recovered from, the database,
|
// This way times can be persisted to, and recovered from, the database,
|
||||||
// but if a string is needed, [database/sql] will recover the same string.
|
// but if a string is needed, [database/sql] will recover the same string.
|
||||||
func maybeTime(text string) (_ time.Time, _ bool) {
|
func maybeTime(text []byte) (_ time.Time, _ bool) {
|
||||||
// Weed out (some) values that can't possibly be
|
// Weed out (some) values that can't possibly be
|
||||||
// [time.RFC3339Nano] timestamps.
|
// [time.RFC3339Nano] timestamps.
|
||||||
if len(text) < len("2006-01-02T15:04:05Z") {
|
if len(text) < len("2006-01-02T15:04:05Z") {
|
||||||
@@ -21,8 +24,8 @@ func maybeTime(text string) (_ time.Time, _ bool) {
|
|||||||
|
|
||||||
// Slow path.
|
// Slow path.
|
||||||
var buf [len(time.RFC3339Nano)]byte
|
var buf [len(time.RFC3339Nano)]byte
|
||||||
date, err := time.Parse(time.RFC3339Nano, text)
|
date, err := time.Parse(time.RFC3339Nano, string(text))
|
||||||
if err == nil && text == string(date.AppendFormat(buf[:0], time.RFC3339Nano)) {
|
if err == nil && bytes.Equal(text, date.AppendFormat(buf[:0], time.RFC3339Nano)) {
|
||||||
return date, true
|
return date, true
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -22,7 +22,7 @@ func Fuzz_stringOrTime_1(f *testing.F) {
|
|||||||
f.Add("Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.")
|
f.Add("Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.")
|
||||||
|
|
||||||
f.Fuzz(func(t *testing.T, str string) {
|
f.Fuzz(func(t *testing.T, str string) {
|
||||||
v, ok := maybeTime(str)
|
v, ok := maybeTime([]byte(str))
|
||||||
if ok {
|
if ok {
|
||||||
// Make sure times round-trip to the same string:
|
// Make sure times round-trip to the same string:
|
||||||
// https://pkg.go.dev/database/sql#Rows.Scan
|
// https://pkg.go.dev/database/sql#Rows.Scan
|
||||||
@@ -51,7 +51,7 @@ func Fuzz_stringOrTime_2(f *testing.F) {
|
|||||||
f.Add(int64(-763421161058), int64(222_222_222)) // twosday, year 22222BC
|
f.Add(int64(-763421161058), int64(222_222_222)) // twosday, year 22222BC
|
||||||
|
|
||||||
checkTime := func(t testing.TB, date time.Time) {
|
checkTime := func(t testing.TB, date time.Time) {
|
||||||
v, ok := maybeTime(date.Format(time.RFC3339Nano))
|
v, ok := maybeTime(date.AppendFormat(nil, time.RFC3339Nano))
|
||||||
if ok {
|
if ok {
|
||||||
// Make sure times round-trip to the same time:
|
// Make sure times round-trip to the same time:
|
||||||
if !v.Equal(date) {
|
if !v.Equal(date) {
|
||||||
|
|||||||
119
stmt.go
119
stmt.go
@@ -571,7 +571,7 @@ func (s *Stmt) ColumnBlob(col int, buf []byte) []byte {
|
|||||||
func (s *Stmt) ColumnRawText(col int) []byte {
|
func (s *Stmt) ColumnRawText(col int) []byte {
|
||||||
ptr := ptr_t(s.c.call("sqlite3_column_text",
|
ptr := ptr_t(s.c.call("sqlite3_column_text",
|
||||||
stk_t(s.handle), stk_t(col)))
|
stk_t(s.handle), stk_t(col)))
|
||||||
return s.columnRawBytes(col, ptr)
|
return s.columnRawBytes(col, ptr, 1)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ColumnRawBlob returns the value of the result column as a []byte.
|
// ColumnRawBlob returns the value of the result column as a []byte.
|
||||||
@@ -583,10 +583,10 @@ func (s *Stmt) ColumnRawText(col int) []byte {
|
|||||||
func (s *Stmt) ColumnRawBlob(col int) []byte {
|
func (s *Stmt) ColumnRawBlob(col int) []byte {
|
||||||
ptr := ptr_t(s.c.call("sqlite3_column_blob",
|
ptr := ptr_t(s.c.call("sqlite3_column_blob",
|
||||||
stk_t(s.handle), stk_t(col)))
|
stk_t(s.handle), stk_t(col)))
|
||||||
return s.columnRawBytes(col, ptr)
|
return s.columnRawBytes(col, ptr, 0)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Stmt) columnRawBytes(col int, ptr ptr_t) []byte {
|
func (s *Stmt) columnRawBytes(col int, ptr ptr_t, nul int32) []byte {
|
||||||
if ptr == 0 {
|
if ptr == 0 {
|
||||||
rc := res_t(s.c.call("sqlite3_errcode", stk_t(s.c.handle)))
|
rc := res_t(s.c.call("sqlite3_errcode", stk_t(s.c.handle)))
|
||||||
if rc != _ROW && rc != _DONE {
|
if rc != _ROW && rc != _DONE {
|
||||||
@@ -597,7 +597,7 @@ func (s *Stmt) columnRawBytes(col int, ptr ptr_t) []byte {
|
|||||||
|
|
||||||
n := int32(s.c.call("sqlite3_column_bytes",
|
n := int32(s.c.call("sqlite3_column_bytes",
|
||||||
stk_t(s.handle), stk_t(col)))
|
stk_t(s.handle), stk_t(col)))
|
||||||
return util.View(s.c.mod, ptr, int64(n))
|
return util.View(s.c.mod, ptr, int64(n+nul))[:n]
|
||||||
}
|
}
|
||||||
|
|
||||||
// ColumnJSON parses the JSON-encoded value of the result column
|
// ColumnJSON parses the JSON-encoded value of the result column
|
||||||
@@ -644,22 +644,12 @@ func (s *Stmt) ColumnValue(col int) Value {
|
|||||||
// [INTEGER] columns will be retrieved as int64 values,
|
// [INTEGER] columns will be retrieved as int64 values,
|
||||||
// [FLOAT] as float64, [NULL] as nil,
|
// [FLOAT] as float64, [NULL] as nil,
|
||||||
// [TEXT] as string, and [BLOB] as []byte.
|
// [TEXT] as string, and [BLOB] as []byte.
|
||||||
// Any []byte are owned by SQLite and may be invalidated by
|
|
||||||
// subsequent calls to [Stmt] methods.
|
|
||||||
func (s *Stmt) Columns(dest ...any) error {
|
func (s *Stmt) Columns(dest ...any) error {
|
||||||
defer s.c.arena.mark()()
|
types, ptr, err := s.columns(int64(len(dest)))
|
||||||
count := int64(len(dest))
|
if err != nil {
|
||||||
typePtr := s.c.arena.new(count)
|
|
||||||
dataPtr := s.c.arena.new(count * 8)
|
|
||||||
|
|
||||||
rc := res_t(s.c.call("sqlite3_columns_go",
|
|
||||||
stk_t(s.handle), stk_t(count), stk_t(typePtr), stk_t(dataPtr)))
|
|
||||||
if err := s.c.error(rc); err != nil {
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
types := util.View(s.c.mod, typePtr, count)
|
|
||||||
|
|
||||||
// Avoid bounds checks on types below.
|
// Avoid bounds checks on types below.
|
||||||
if len(types) != len(dest) {
|
if len(types) != len(dest) {
|
||||||
panic(util.AssertErr())
|
panic(util.AssertErr())
|
||||||
@@ -668,30 +658,95 @@ func (s *Stmt) Columns(dest ...any) error {
|
|||||||
for i := range dest {
|
for i := range dest {
|
||||||
switch types[i] {
|
switch types[i] {
|
||||||
case byte(INTEGER):
|
case byte(INTEGER):
|
||||||
dest[i] = util.Read64[int64](s.c.mod, dataPtr)
|
dest[i] = util.Read64[int64](s.c.mod, ptr)
|
||||||
case byte(FLOAT):
|
case byte(FLOAT):
|
||||||
dest[i] = util.ReadFloat64(s.c.mod, dataPtr)
|
dest[i] = util.ReadFloat64(s.c.mod, ptr)
|
||||||
case byte(NULL):
|
case byte(NULL):
|
||||||
dest[i] = nil
|
dest[i] = nil
|
||||||
default:
|
case byte(TEXT):
|
||||||
len := util.Read32[int32](s.c.mod, dataPtr+4)
|
len := util.Read32[int32](s.c.mod, ptr+4)
|
||||||
if len != 0 {
|
if len != 0 {
|
||||||
ptr := util.Read32[ptr_t](s.c.mod, dataPtr)
|
ptr := util.Read32[ptr_t](s.c.mod, ptr)
|
||||||
buf := util.View(s.c.mod, ptr, int64(len))
|
buf := util.View(s.c.mod, ptr, int64(len))
|
||||||
if types[i] == byte(TEXT) {
|
dest[i] = string(buf)
|
||||||
dest[i] = string(buf)
|
|
||||||
} else {
|
|
||||||
dest[i] = buf
|
|
||||||
}
|
|
||||||
} else {
|
} else {
|
||||||
if types[i] == byte(TEXT) {
|
dest[i] = ""
|
||||||
dest[i] = ""
|
}
|
||||||
} else {
|
case byte(BLOB):
|
||||||
dest[i] = []byte{}
|
len := util.Read32[int32](s.c.mod, ptr+4)
|
||||||
}
|
if len != 0 {
|
||||||
|
ptr := util.Read32[ptr_t](s.c.mod, ptr)
|
||||||
|
buf := util.View(s.c.mod, ptr, int64(len))
|
||||||
|
tmp, _ := dest[i].([]byte)
|
||||||
|
dest[i] = append(tmp[:0], buf...)
|
||||||
|
} else {
|
||||||
|
dest[i], _ = dest[i].([]byte)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
dataPtr += 8
|
ptr += 8
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ColumnsRaw populates result columns into the provided slice.
|
||||||
|
// The slice must have [Stmt.ColumnCount] length.
|
||||||
|
//
|
||||||
|
// [INTEGER] columns will be retrieved as int64 values,
|
||||||
|
// [FLOAT] as float64, [NULL] as nil,
|
||||||
|
// [TEXT] and [BLOB] as []byte.
|
||||||
|
// Any []byte are owned by SQLite and may be invalidated by
|
||||||
|
// subsequent calls to [Stmt] methods.
|
||||||
|
func (s *Stmt) ColumnsRaw(dest ...any) error {
|
||||||
|
types, ptr, err := s.columns(int64(len(dest)))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Avoid bounds checks on types below.
|
||||||
|
if len(types) != len(dest) {
|
||||||
|
panic(util.AssertErr())
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := range dest {
|
||||||
|
switch types[i] {
|
||||||
|
case byte(INTEGER):
|
||||||
|
dest[i] = util.Read64[int64](s.c.mod, ptr)
|
||||||
|
case byte(FLOAT):
|
||||||
|
dest[i] = util.ReadFloat64(s.c.mod, ptr)
|
||||||
|
case byte(NULL):
|
||||||
|
dest[i] = nil
|
||||||
|
default:
|
||||||
|
len := util.Read32[int32](s.c.mod, ptr+4)
|
||||||
|
if len == 0 && types[i] == byte(BLOB) {
|
||||||
|
dest[i] = []byte{}
|
||||||
|
} else {
|
||||||
|
cap := len
|
||||||
|
if types[i] == byte(TEXT) {
|
||||||
|
cap++
|
||||||
|
}
|
||||||
|
ptr := util.Read32[ptr_t](s.c.mod, ptr)
|
||||||
|
buf := util.View(s.c.mod, ptr, int64(cap))[:len]
|
||||||
|
dest[i] = buf
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ptr += 8
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Stmt) columns(count int64) ([]byte, ptr_t, error) {
|
||||||
|
defer s.c.arena.mark()()
|
||||||
|
typePtr := s.c.arena.new(count)
|
||||||
|
dataPtr := s.c.arena.new(count * 8)
|
||||||
|
|
||||||
|
rc := res_t(s.c.call("sqlite3_columns_go",
|
||||||
|
stk_t(s.handle), stk_t(count), stk_t(typePtr), stk_t(dataPtr)))
|
||||||
|
if rc == res_t(MISUSE) {
|
||||||
|
return nil, 0, MISUSE
|
||||||
|
}
|
||||||
|
if err := s.c.error(rc); err != nil {
|
||||||
|
return nil, 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return util.View(s.c.mod, typePtr, count), dataPtr, nil
|
||||||
|
}
|
||||||
|
|||||||
@@ -664,6 +664,42 @@ func TestStmt_ColumnValue(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestStmt_Columns(t *testing.T) {
|
||||||
|
db, err := sqlite3.Open(":memory:")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
stmt, _, err := db.Prepare(`SELECT 0, 0.5, 'abc', x'cafe', NULL`)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer stmt.Close()
|
||||||
|
|
||||||
|
if stmt.Step() {
|
||||||
|
var dest [5]any
|
||||||
|
if err := stmt.Columns(dest[:]...); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if got := dest[0]; got != int64(0) {
|
||||||
|
t.Errorf("got %d, want 0", got)
|
||||||
|
}
|
||||||
|
if got := dest[1]; got != float64(0.5) {
|
||||||
|
t.Errorf("got %f, want 0.5", got)
|
||||||
|
}
|
||||||
|
if got := dest[2]; got != "abc" {
|
||||||
|
t.Errorf("got %q, want 'abc'", got)
|
||||||
|
}
|
||||||
|
if got := dest[3]; string(got.([]byte)) != "\xCA\xFE" {
|
||||||
|
t.Errorf("got %q, want x'cafe'", got)
|
||||||
|
}
|
||||||
|
if got := dest[4]; got != nil {
|
||||||
|
t.Errorf("got %q, want nil", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestStmt_Error(t *testing.T) {
|
func TestStmt_Error(t *testing.T) {
|
||||||
if testing.Short() {
|
if testing.Short() {
|
||||||
t.Skip("skipping in short mode")
|
t.Skip("skipping in short mode")
|
||||||
|
|||||||
2
txn.go
2
txn.go
@@ -20,7 +20,7 @@ type Txn struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Begin starts a deferred transaction.
|
// Begin starts a deferred transaction.
|
||||||
// Panics if a transaction is already in-progress.
|
// It panics if a transaction is in-progress.
|
||||||
// For nested transactions, use [Conn.Savepoint].
|
// For nested transactions, use [Conn.Savepoint].
|
||||||
//
|
//
|
||||||
// https://sqlite.org/lang_transaction.html
|
// https://sqlite.org/lang_transaction.html
|
||||||
|
|||||||
8
value.go
8
value.go
@@ -139,7 +139,7 @@ func (v Value) Blob(buf []byte) []byte {
|
|||||||
// https://sqlite.org/c3ref/value_blob.html
|
// https://sqlite.org/c3ref/value_blob.html
|
||||||
func (v Value) RawText() []byte {
|
func (v Value) RawText() []byte {
|
||||||
ptr := ptr_t(v.c.call("sqlite3_value_text", v.protected()))
|
ptr := ptr_t(v.c.call("sqlite3_value_text", v.protected()))
|
||||||
return v.rawBytes(ptr)
|
return v.rawBytes(ptr, 1)
|
||||||
}
|
}
|
||||||
|
|
||||||
// RawBlob returns the value as a []byte.
|
// RawBlob returns the value as a []byte.
|
||||||
@@ -149,16 +149,16 @@ func (v Value) RawText() []byte {
|
|||||||
// https://sqlite.org/c3ref/value_blob.html
|
// https://sqlite.org/c3ref/value_blob.html
|
||||||
func (v Value) RawBlob() []byte {
|
func (v Value) RawBlob() []byte {
|
||||||
ptr := ptr_t(v.c.call("sqlite3_value_blob", v.protected()))
|
ptr := ptr_t(v.c.call("sqlite3_value_blob", v.protected()))
|
||||||
return v.rawBytes(ptr)
|
return v.rawBytes(ptr, 0)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (v Value) rawBytes(ptr ptr_t) []byte {
|
func (v Value) rawBytes(ptr ptr_t, nul int32) []byte {
|
||||||
if ptr == 0 {
|
if ptr == 0 {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
n := int32(v.c.call("sqlite3_value_bytes", v.protected()))
|
n := int32(v.c.call("sqlite3_value_bytes", v.protected()))
|
||||||
return util.View(v.c.mod, ptr, int64(n))
|
return util.View(v.c.mod, ptr, int64(n+nul))[:n]
|
||||||
}
|
}
|
||||||
|
|
||||||
// Pointer gets the pointer associated with this value,
|
// Pointer gets the pointer associated with this value,
|
||||||
|
|||||||
Reference in New Issue
Block a user