Scan improvements.

This commit is contained in:
Nuno Cruces
2025-09-24 13:56:50 +01:00
parent 08f9fc758a
commit 5cf06c45f7
6 changed files with 120 additions and 80 deletions

View File

@@ -607,14 +607,24 @@ func (r resultRowsAffected) RowsAffected() (int64, error) {
type scantype byte
const (
_ANY scantype = iota
_INT scantype = scantype(sqlite3.INTEGER)
_REAL scantype = scantype(sqlite3.FLOAT)
_TEXT scantype = scantype(sqlite3.TEXT)
_BLOB scantype = scantype(sqlite3.BLOB)
_NULL scantype = scantype(sqlite3.NULL)
_BOOL scantype = iota
_ANY scantype = iota
_INT
_REAL
_TEXT
_BLOB
_NULL
_BOOL
_TIME
_NOT_NULL
)
var (
_ [0]struct{} = [scantype(sqlite3.INTEGER) - _INT]struct{}{}
_ [0]struct{} = [scantype(sqlite3.FLOAT) - _REAL]struct{}{}
_ [0]struct{} = [scantype(sqlite3.TEXT) - _TEXT]struct{}{}
_ [0]struct{} = [scantype(sqlite3.BLOB) - _BLOB]struct{}{}
_ [0]struct{} = [scantype(sqlite3.NULL) - _NULL]struct{}{}
_ [0]struct{} = [_NOT_NULL & (_NOT_NULL - 1)]struct{}{}
)
func scanFromDecl(decl string) scantype {
@@ -644,8 +654,8 @@ type rows struct {
*stmt
names []string
types []string
nulls []bool
scans []scantype
dest []driver.Value
}
var (
@@ -675,34 +685,36 @@ func (r *rows) Columns() []string {
func (r *rows) scanType(index int) scantype {
if r.scans == nil {
count := r.Stmt.ColumnCount()
count := len(r.names)
scans := make([]scantype, count)
for i := range scans {
scans[i] = scanFromDecl(strings.ToUpper(r.Stmt.ColumnDeclType(i)))
}
r.scans = scans
}
return r.scans[index]
return r.scans[index] &^ _NOT_NULL
}
func (r *rows) loadColumnMetadata() {
if r.nulls == nil {
if r.types == nil {
c := r.Stmt.Conn()
count := r.Stmt.ColumnCount()
nulls := make([]bool, count)
count := len(r.names)
types := make([]string, count)
scans := make([]scantype, count)
for i := range nulls {
for i := range types {
var notnull bool
if col := r.Stmt.ColumnOriginName(i); col != "" {
types[i], _, nulls[i], _, _, _ = c.TableColumnMetadata(
types[i], _, notnull, _, _, _ = c.TableColumnMetadata(
r.Stmt.ColumnDatabaseName(i),
r.Stmt.ColumnTableName(i),
col)
types[i] = strings.ToUpper(types[i])
scans[i] = scanFromDecl(types[i])
if notnull {
scans[i] |= _NOT_NULL
}
}
}
r.nulls = nulls
r.types = types
r.scans = scans
}
@@ -721,15 +733,13 @@ func (r *rows) ColumnTypeDatabaseTypeName(index int) string {
func (r *rows) ColumnTypeNullable(index int) (nullable, ok bool) {
r.loadColumnMetadata()
if r.nulls[index] {
return false, true
}
return true, false
nullable = r.scans[index]&^_NOT_NULL == 0
return nullable, !nullable
}
func (r *rows) ColumnTypeScanType(index int) (typ reflect.Type) {
r.loadColumnMetadata()
scan := r.scans[index]
scan := r.scans[index] &^ _NOT_NULL
if r.Stmt.Busy() {
// SQLite is dynamically typed and we now have a row.
@@ -772,6 +782,7 @@ func (r *rows) ColumnTypeScanType(index int) (typ reflect.Type) {
}
func (r *rows) Next(dest []driver.Value) error {
r.dest = nil
c := r.Stmt.Conn()
if old := c.SetInterrupt(r.ctx); old != r.ctx {
defer c.SetInterrupt(old)
@@ -829,10 +840,11 @@ func (r *rows) Next(dest []driver.Value) error {
}
}
}
r.dest = dest
return nil
}
func (r *rows) ScanColumn(dest any, index int) error {
func (r *rows) ScanColumn(dest any, index int) (err error) {
// notest // Go 1.26
var tm *time.Time
var ok *bool
@@ -848,10 +860,13 @@ func (r *rows) ScanColumn(dest any, index int) error {
default:
return driver.ErrSkip
}
*tm = r.Stmt.ColumnTime(index, r.tmRead)
err := r.Stmt.Err()
if ok != nil && err == nil {
*ok = r.stmt.ColumnType(index) != sqlite3.NULL
value := r.dest[index]
*tm, err = r.tmRead.Decode(value)
if ok != nil {
*ok = err == nil
if value == nil {
return nil
}
}
return err
}

View File

@@ -8,6 +8,7 @@ import (
"math"
"net/url"
"reflect"
"strings"
"testing"
"time"
@@ -33,7 +34,7 @@ func Test_Open_error(t *testing.T) {
func Test_Open_dir(t *testing.T) {
t.Parallel()
db, err := sql.Open("sqlite3", ".")
db, err := Open(".")
if err != nil {
t.Fatal(err)
}
@@ -54,7 +55,7 @@ func Test_Open_pragma(t *testing.T) {
"_pragma": {"busy_timeout(1000)"},
})
db, err := sql.Open("sqlite3", tmp)
db, err := Open(tmp)
if err != nil {
t.Fatal(err)
}
@@ -76,7 +77,7 @@ func Test_Open_pragma_invalid(t *testing.T) {
"_pragma": {"busy_timeout 1000"},
})
db, err := sql.Open("sqlite3", tmp)
db, err := Open(tmp)
if err != nil {
t.Fatal(err)
}
@@ -105,7 +106,7 @@ func Test_Open_txLock(t *testing.T) {
"_pragma": {"busy_timeout(1000)"},
})
db, err := sql.Open("sqlite3", tmp)
db, err := Open(tmp)
if err != nil {
t.Fatal(err)
}
@@ -140,7 +141,7 @@ func Test_Open_txLock_invalid(t *testing.T) {
"_txlock": {"xclusive"},
})
_, err := sql.Open("sqlite3", tmp+"_txlock=xclusive")
_, err := Open(tmp)
if err == nil {
t.Fatal("want error")
}
@@ -156,7 +157,7 @@ func Test_BeginTx(t *testing.T) {
"_pragma": {"busy_timeout(0)"},
})
db, err := sql.Open("sqlite3", tmp)
db, err := Open(tmp)
if err != nil {
t.Fatal(err)
}
@@ -200,7 +201,7 @@ func Test_nested_context(t *testing.T) {
t.Parallel()
tmp := memdb.TestDB(t)
db, err := sql.Open("sqlite3", tmp)
db, err := Open(tmp)
if err != nil {
t.Fatal(err)
}
@@ -258,7 +259,7 @@ func Test_Prepare(t *testing.T) {
t.Parallel()
tmp := memdb.TestDB(t)
db, err := sql.Open("sqlite3", tmp)
db, err := Open(tmp)
if err != nil {
t.Fatal(err)
}
@@ -299,7 +300,7 @@ func Test_QueryRow_named(t *testing.T) {
t.Parallel()
tmp := memdb.TestDB(t)
db, err := sql.Open("sqlite3", tmp)
db, err := Open(tmp)
if err != nil {
t.Fatal(err)
}
@@ -349,7 +350,7 @@ func Test_QueryRow_blob_null(t *testing.T) {
t.Parallel()
tmp := memdb.TestDB(t)
db, err := sql.Open("sqlite3", tmp)
db, err := Open(tmp)
if err != nil {
t.Fatal(err)
}
@@ -388,7 +389,7 @@ func Test_time(t *testing.T) {
"_timefmt": {fmt},
})
db, err := sql.Open("sqlite3", tmp)
db, err := Open(tmp)
if err != nil {
t.Fatal(err)
}
@@ -433,7 +434,7 @@ func Test_ColumnType_ScanType(t *testing.T) {
t.Parallel()
tmp := memdb.TestDB(t)
db, err := sql.Open("sqlite3", tmp)
db, err := Open(tmp)
if err != nil {
t.Fatal(err)
}
@@ -520,6 +521,39 @@ func Test_ColumnType_ScanType(t *testing.T) {
}
}
func Test_rows_ScanColumn(t *testing.T) {
t.Parallel()
tmp := memdb.TestDB(t)
db, err := Open(tmp)
if err != nil {
t.Fatal(err)
}
defer db.Close()
var tm time.Time
err = db.QueryRow(`SELECT NULL`).Scan(&tm)
if err == nil {
t.Error("want error")
}
// Go 1.26
err = db.QueryRow(`SELECT datetime()`).Scan(&tm)
if err != nil && !strings.HasPrefix(err.Error(), "sql: Scan error") {
t.Error(err)
}
var nt sql.NullTime
err = db.QueryRow(`SELECT NULL`).Scan(&nt)
if err != nil {
t.Error(err)
}
// Go 1.26
err = db.QueryRow(`SELECT datetime()`).Scan(&nt)
if err != nil && !strings.HasPrefix(err.Error(), "sql: Scan error") {
t.Error(err)
}
}
func Benchmark_loop(b *testing.B) {
db, err := Open(":memory:")
if err != nil {
@@ -533,8 +567,7 @@ func Benchmark_loop(b *testing.B) {
b.Fatal(err)
}
b.ResetTimer()
for range b.N {
for b.Loop() {
_, err := db.ExecContext(b.Context(),
`WITH RECURSIVE c(x) AS (VALUES(1) UNION ALL SELECT x+1 FROM c WHERE x < 1000000) SELECT x FROM c;`)
if err != nil {