diff --git a/conn.go b/conn.go index 7e88d8c..a7eca16 100644 --- a/conn.go +++ b/conn.go @@ -444,20 +444,27 @@ func (c *Conn) Status(op DBStatus, reset bool) (current, highwater int, err erro // https://sqlite.org/c3ref/table_column_metadata.html func (c *Conn) TableColumnMetadata(schema, table, column string) (declType, collSeq string, notNull, primaryKey, autoInc bool, err error) { defer c.arena.mark()() - - var schemaPtr, columnPtr ptr_t - declTypePtr := c.arena.new(ptrlen) - collSeqPtr := c.arena.new(ptrlen) - notNullPtr := c.arena.new(ptrlen) - autoIncPtr := c.arena.new(ptrlen) - primaryKeyPtr := c.arena.new(ptrlen) + var ( + declTypePtr ptr_t + collSeqPtr ptr_t + notNullPtr ptr_t + primaryKeyPtr ptr_t + autoIncPtr ptr_t + columnPtr ptr_t + schemaPtr ptr_t + ) + if column != "" { + declTypePtr = c.arena.new(ptrlen) + collSeqPtr = c.arena.new(ptrlen) + notNullPtr = c.arena.new(ptrlen) + primaryKeyPtr = c.arena.new(ptrlen) + autoIncPtr = c.arena.new(ptrlen) + columnPtr = c.arena.string(column) + } if schema != "" { schemaPtr = c.arena.string(schema) } tablePtr := c.arena.string(table) - if column != "" { - columnPtr = c.arena.string(column) - } rc := res_t(c.call("sqlite3_table_column_metadata", stk_t(c.handle), stk_t(schemaPtr), stk_t(tablePtr), stk_t(columnPtr), diff --git a/driver/driver.go b/driver/driver.go index 37b5ac5..6cfb8cd 100644 --- a/driver/driver.go +++ b/driver/driver.go @@ -607,14 +607,24 @@ func (r resultRowsAffected) RowsAffected() (int64, error) { type scantype byte const ( - _ANY scantype = iota - _INT scantype = scantype(sqlite3.INTEGER) - _REAL scantype = scantype(sqlite3.FLOAT) - _TEXT scantype = scantype(sqlite3.TEXT) - _BLOB scantype = scantype(sqlite3.BLOB) - _NULL scantype = scantype(sqlite3.NULL) - _BOOL scantype = iota + _ANY scantype = iota + _INT + _REAL + _TEXT + _BLOB + _NULL + _BOOL _TIME + _NOT_NULL +) + +var ( + _ [0]struct{} = [scantype(sqlite3.INTEGER) - _INT]struct{}{} + _ [0]struct{} = [scantype(sqlite3.FLOAT) - _REAL]struct{}{} + _ [0]struct{} = [scantype(sqlite3.TEXT) - _TEXT]struct{}{} + _ [0]struct{} = [scantype(sqlite3.BLOB) - _BLOB]struct{}{} + _ [0]struct{} = [scantype(sqlite3.NULL) - _NULL]struct{}{} + _ [0]struct{} = [_NOT_NULL & (_NOT_NULL - 1)]struct{}{} ) func scanFromDecl(decl string) scantype { @@ -644,8 +654,8 @@ type rows struct { *stmt names []string types []string - nulls []bool scans []scantype + dest []driver.Value } var ( @@ -675,34 +685,36 @@ func (r *rows) Columns() []string { func (r *rows) scanType(index int) scantype { if r.scans == nil { - count := r.Stmt.ColumnCount() + count := len(r.names) scans := make([]scantype, count) for i := range scans { scans[i] = scanFromDecl(strings.ToUpper(r.Stmt.ColumnDeclType(i))) } r.scans = scans } - return r.scans[index] + return r.scans[index] &^ _NOT_NULL } func (r *rows) loadColumnMetadata() { - if r.nulls == nil { + if r.types == nil { c := r.Stmt.Conn() - count := r.Stmt.ColumnCount() - nulls := make([]bool, count) + count := len(r.names) types := make([]string, count) scans := make([]scantype, count) - for i := range nulls { + for i := range types { + var notnull bool if col := r.Stmt.ColumnOriginName(i); col != "" { - types[i], _, nulls[i], _, _, _ = c.TableColumnMetadata( + types[i], _, notnull, _, _, _ = c.TableColumnMetadata( r.Stmt.ColumnDatabaseName(i), r.Stmt.ColumnTableName(i), col) types[i] = strings.ToUpper(types[i]) scans[i] = scanFromDecl(types[i]) + if notnull { + scans[i] |= _NOT_NULL + } } } - r.nulls = nulls r.types = types r.scans = scans } @@ -721,15 +733,13 @@ func (r *rows) ColumnTypeDatabaseTypeName(index int) string { func (r *rows) ColumnTypeNullable(index int) (nullable, ok bool) { r.loadColumnMetadata() - if r.nulls[index] { - return false, true - } - return true, false + nullable = r.scans[index]&^_NOT_NULL == 0 + return nullable, !nullable } func (r *rows) ColumnTypeScanType(index int) (typ reflect.Type) { r.loadColumnMetadata() - scan := r.scans[index] + scan := r.scans[index] &^ _NOT_NULL if r.Stmt.Busy() { // SQLite is dynamically typed and we now have a row. @@ -772,6 +782,7 @@ func (r *rows) ColumnTypeScanType(index int) (typ reflect.Type) { } func (r *rows) Next(dest []driver.Value) error { + r.dest = nil c := r.Stmt.Conn() if old := c.SetInterrupt(r.ctx); old != r.ctx { defer c.SetInterrupt(old) @@ -829,10 +840,11 @@ func (r *rows) Next(dest []driver.Value) error { } } } + r.dest = dest return nil } -func (r *rows) ScanColumn(dest any, index int) error { +func (r *rows) ScanColumn(dest any, index int) (err error) { // notest // Go 1.26 var tm *time.Time var ok *bool @@ -848,10 +860,13 @@ func (r *rows) ScanColumn(dest any, index int) error { default: return driver.ErrSkip } - *tm = r.Stmt.ColumnTime(index, r.tmRead) - err := r.Stmt.Err() - if ok != nil && err == nil { - *ok = r.stmt.ColumnType(index) != sqlite3.NULL + value := r.dest[index] + *tm, err = r.tmRead.Decode(value) + if ok != nil { + *ok = err == nil + if value == nil { + return nil + } } return err } diff --git a/driver/driver_test.go b/driver/driver_test.go index cc8e8b3..fad2ac3 100644 --- a/driver/driver_test.go +++ b/driver/driver_test.go @@ -8,6 +8,7 @@ import ( "math" "net/url" "reflect" + "strings" "testing" "time" @@ -33,7 +34,7 @@ func Test_Open_error(t *testing.T) { func Test_Open_dir(t *testing.T) { t.Parallel() - db, err := sql.Open("sqlite3", ".") + db, err := Open(".") if err != nil { t.Fatal(err) } @@ -54,7 +55,7 @@ func Test_Open_pragma(t *testing.T) { "_pragma": {"busy_timeout(1000)"}, }) - db, err := sql.Open("sqlite3", tmp) + db, err := Open(tmp) if err != nil { t.Fatal(err) } @@ -76,7 +77,7 @@ func Test_Open_pragma_invalid(t *testing.T) { "_pragma": {"busy_timeout 1000"}, }) - db, err := sql.Open("sqlite3", tmp) + db, err := Open(tmp) if err != nil { t.Fatal(err) } @@ -105,7 +106,7 @@ func Test_Open_txLock(t *testing.T) { "_pragma": {"busy_timeout(1000)"}, }) - db, err := sql.Open("sqlite3", tmp) + db, err := Open(tmp) if err != nil { t.Fatal(err) } @@ -140,7 +141,7 @@ func Test_Open_txLock_invalid(t *testing.T) { "_txlock": {"xclusive"}, }) - _, err := sql.Open("sqlite3", tmp+"_txlock=xclusive") + _, err := Open(tmp) if err == nil { t.Fatal("want error") } @@ -156,7 +157,7 @@ func Test_BeginTx(t *testing.T) { "_pragma": {"busy_timeout(0)"}, }) - db, err := sql.Open("sqlite3", tmp) + db, err := Open(tmp) if err != nil { t.Fatal(err) } @@ -200,7 +201,7 @@ func Test_nested_context(t *testing.T) { t.Parallel() tmp := memdb.TestDB(t) - db, err := sql.Open("sqlite3", tmp) + db, err := Open(tmp) if err != nil { t.Fatal(err) } @@ -258,7 +259,7 @@ func Test_Prepare(t *testing.T) { t.Parallel() tmp := memdb.TestDB(t) - db, err := sql.Open("sqlite3", tmp) + db, err := Open(tmp) if err != nil { t.Fatal(err) } @@ -299,7 +300,7 @@ func Test_QueryRow_named(t *testing.T) { t.Parallel() tmp := memdb.TestDB(t) - db, err := sql.Open("sqlite3", tmp) + db, err := Open(tmp) if err != nil { t.Fatal(err) } @@ -349,7 +350,7 @@ func Test_QueryRow_blob_null(t *testing.T) { t.Parallel() tmp := memdb.TestDB(t) - db, err := sql.Open("sqlite3", tmp) + db, err := Open(tmp) if err != nil { t.Fatal(err) } @@ -388,7 +389,7 @@ func Test_time(t *testing.T) { "_timefmt": {fmt}, }) - db, err := sql.Open("sqlite3", tmp) + db, err := Open(tmp) if err != nil { t.Fatal(err) } @@ -433,7 +434,7 @@ func Test_ColumnType_ScanType(t *testing.T) { t.Parallel() tmp := memdb.TestDB(t) - db, err := sql.Open("sqlite3", tmp) + db, err := Open(tmp) if err != nil { t.Fatal(err) } @@ -520,6 +521,39 @@ func Test_ColumnType_ScanType(t *testing.T) { } } +func Test_rows_ScanColumn(t *testing.T) { + t.Parallel() + tmp := memdb.TestDB(t) + + db, err := Open(tmp) + if err != nil { + t.Fatal(err) + } + defer db.Close() + + var tm time.Time + err = db.QueryRow(`SELECT NULL`).Scan(&tm) + if err == nil { + t.Error("want error") + } + // Go 1.26 + err = db.QueryRow(`SELECT datetime()`).Scan(&tm) + if err != nil && !strings.HasPrefix(err.Error(), "sql: Scan error") { + t.Error(err) + } + + var nt sql.NullTime + err = db.QueryRow(`SELECT NULL`).Scan(&nt) + if err != nil { + t.Error(err) + } + // Go 1.26 + err = db.QueryRow(`SELECT datetime()`).Scan(&nt) + if err != nil && !strings.HasPrefix(err.Error(), "sql: Scan error") { + t.Error(err) + } +} + func Benchmark_loop(b *testing.B) { db, err := Open(":memory:") if err != nil { @@ -533,8 +567,7 @@ func Benchmark_loop(b *testing.B) { b.Fatal(err) } - b.ResetTimer() - for range b.N { + for b.Loop() { _, err := db.ExecContext(b.Context(), `WITH RECURSIVE c(x) AS (VALUES(1) UNION ALL SELECT x+1 FROM c WHERE x < 1000000) SELECT x FROM c;`) if err != nil { diff --git a/sqlite3/libc/libc_test.go b/sqlite3/libc/libc_test.go index 2b52091..551616d 100644 --- a/sqlite3/libc/libc_test.go +++ b/sqlite3/libc/libc_test.go @@ -75,8 +75,7 @@ func Benchmark_memset(b *testing.B) { clear(memory) b.SetBytes(size) - b.ResetTimer() - for range b.N { + for b.Loop() { call(memset, ptr1, 3, size) } } @@ -86,8 +85,7 @@ func Benchmark_memcpy(b *testing.B) { fill(memory[ptr2:ptr2+size], 5) b.SetBytes(size) - b.ResetTimer() - for range b.N { + for b.Loop() { call(memcpy, ptr1, ptr2, size) } } @@ -97,8 +95,7 @@ func Benchmark_strlen(b *testing.B) { fill(memory[ptr1:ptr1+size-1], 5) b.SetBytes(size) - b.ResetTimer() - for range b.N { + for b.Loop() { call(strlen, ptr1) } } @@ -109,8 +106,7 @@ func Benchmark_memchr(b *testing.B) { fill(memory[ptr1+size/2:ptr1+size], 5) b.SetBytes(size/2 + 1) - b.ResetTimer() - for range b.N { + for b.Loop() { call(memchr, ptr1, 5, size) } } @@ -121,8 +117,7 @@ func Benchmark_strchr(b *testing.B) { fill(memory[ptr1+size/2:ptr1+size-1], 5) b.SetBytes(size/2 + 1) - b.ResetTimer() - for range b.N { + for b.Loop() { call(strchr, ptr1, 5) } } @@ -133,8 +128,7 @@ func Benchmark_strrchr(b *testing.B) { fill(memory[ptr1+size/2:ptr1+size-1], 7) b.SetBytes(size/2 + 1) - b.ResetTimer() - for range b.N { + for b.Loop() { call(strrchr, ptr1, 5) } } @@ -146,8 +140,7 @@ func Benchmark_memcmp(b *testing.B) { fill(memory[ptr2+size/2:ptr2+size], 5) b.SetBytes(size/2 + 1) - b.ResetTimer() - for range b.N { + for b.Loop() { call(memcmp, ptr1, ptr2, size) } } @@ -162,8 +155,7 @@ func Benchmark_strspn(b *testing.B) { memory[ptr2+3] = 9 b.SetBytes(size) - b.ResetTimer() - for range b.N { + for b.Loop() { call(strspn, ptr1, ptr2) } } @@ -176,8 +168,7 @@ func Benchmark_strcspn(b *testing.B) { memory[ptr2+1] = 9 b.SetBytes(size) - b.ResetTimer() - for range b.N { + for b.Loop() { call(strcspn, ptr1, ptr2) } } diff --git a/vfs/adiantum/adiantum_test.go b/vfs/adiantum/adiantum_test.go index 3b66df2..41ef4e1 100644 --- a/vfs/adiantum/adiantum_test.go +++ b/vfs/adiantum/adiantum_test.go @@ -54,9 +54,8 @@ func Test_fileformat(t *testing.T) { func Benchmark_nokey(b *testing.B) { tmp := filepath.Join(b.TempDir(), "test.db") sqlite3.Initialize() - b.ResetTimer() - for range b.N { + for b.Loop() { db, err := sqlite3.Open("file:" + filepath.ToSlash(tmp) + "?nolock=1") if err != nil { b.Fatal(err) @@ -68,9 +67,8 @@ func Benchmark_nokey(b *testing.B) { func Benchmark_hexkey(b *testing.B) { tmp := filepath.Join(b.TempDir(), "test.db") sqlite3.Initialize() - b.ResetTimer() - for range b.N { + for b.Loop() { db, err := sqlite3.Open("file:" + filepath.ToSlash(tmp) + "?nolock=1" + "&vfs=adiantum&hexkey=e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855") if err != nil { @@ -83,9 +81,8 @@ func Benchmark_hexkey(b *testing.B) { func Benchmark_textkey(b *testing.B) { tmp := filepath.Join(b.TempDir(), "test.db") sqlite3.Initialize() - b.ResetTimer() - for range b.N { + for b.Loop() { db, err := sqlite3.Open("file:" + filepath.ToSlash(tmp) + "?nolock=1" + "&vfs=adiantum&textkey=correct+horse+battery+staple") if err != nil { diff --git a/vfs/xts/aes_test.go b/vfs/xts/aes_test.go index ecc0309..e93050b 100644 --- a/vfs/xts/aes_test.go +++ b/vfs/xts/aes_test.go @@ -54,9 +54,8 @@ func Test_fileformat(t *testing.T) { func Benchmark_nokey(b *testing.B) { tmp := filepath.Join(b.TempDir(), "test.db") sqlite3.Initialize() - b.ResetTimer() - for range b.N { + for b.Loop() { db, err := sqlite3.Open("file:" + filepath.ToSlash(tmp) + "?nolock=1") if err != nil { b.Fatal(err) @@ -68,9 +67,8 @@ func Benchmark_nokey(b *testing.B) { func Benchmark_hexkey(b *testing.B) { tmp := filepath.Join(b.TempDir(), "test.db") sqlite3.Initialize() - b.ResetTimer() - for range b.N { + for b.Loop() { db, err := sqlite3.Open("file:" + filepath.ToSlash(tmp) + "?nolock=1" + "&vfs=xts&hexkey=e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855") if err != nil { @@ -83,9 +81,8 @@ func Benchmark_hexkey(b *testing.B) { func Benchmark_textkey(b *testing.B) { tmp := filepath.Join(b.TempDir(), "test.db") sqlite3.Initialize() - b.ResetTimer() - for range b.N { + for b.Loop() { db, err := sqlite3.Open("file:" + filepath.ToSlash(tmp) + "?nolock=1" + "&vfs=xts&textkey=correct+horse+battery+staple") if err != nil {