From 36bbd674c21b75661f552bf788b1fcd55bfbdfb6 Mon Sep 17 00:00:00 2001 From: Nuno Cruces Date: Wed, 11 Dec 2024 18:35:50 +0000 Subject: [PATCH] Add ColumnTypeScanType to driver (#199). --- driver/driver.go | 85 +++++++++++++++++++++++++++++++++-- driver/driver_test.go | 102 ++++++++++++++++++++++++++++++++++++++++++ driver/time.go | 4 +- 3 files changed, 185 insertions(+), 6 deletions(-) diff --git a/driver/driver.go b/driver/driver.go index af04af6..477e9a9 100644 --- a/driver/driver.go +++ b/driver/driver.go @@ -81,6 +81,7 @@ import ( "fmt" "io" "net/url" + "reflect" "strings" "time" "unsafe" @@ -579,8 +580,22 @@ type rows struct { names []string types []string nulls []bool + scans []scantype } +type scantype byte + +const ( + _ANY scantype = iota + _INT scantype = scantype(sqlite3.INTEGER) + _REAL scantype = scantype(sqlite3.FLOAT) + _TEXT scantype = scantype(sqlite3.TEXT) + _BLOB scantype = scantype(sqlite3.BLOB) + _NULL scantype = scantype(sqlite3.NULL) + _BOOL scantype = iota + _TIME +) + var ( // Ensure these interfaces are implemented: _ driver.RowsColumnTypeDatabaseTypeName = &rows{} @@ -604,11 +619,12 @@ func (r *rows) Columns() []string { return r.names } -func (r *rows) loadTypes() { +func (r *rows) loadColumnMetadata() { if r.nulls == nil { count := r.Stmt.ColumnCount() nulls := make([]bool, count) types := make([]string, count) + scans := make([]scantype, count) for i := range nulls { if col := r.Stmt.ColumnOriginName(i); col != "" { types[i], _, nulls[i], _, _, _ = r.Stmt.Conn().TableColumnMetadata( @@ -616,10 +632,29 @@ func (r *rows) loadTypes() { r.Stmt.ColumnTableName(i), col) types[i] = strings.ToUpper(types[i]) + // These types are only used before we have rows, + // and otherwise as type hints. + // The first few ensure STRICT tables are strictly typed. + // The other two are type hints for booleans and time. + switch types[i] { + case "INT", "INTEGER": + scans[i] = _INT + case "REAL": + scans[i] = _REAL + case "TEXT": + scans[i] = _TEXT + case "BLOB": + scans[i] = _BLOB + case "BOOLEAN": + scans[i] = _BOOL + case "DATE", "TIME", "DATETIME", "TIMESTAMP": + scans[i] = _TIME + } } } r.nulls = nulls r.types = types + r.scans = scans } } @@ -636,7 +671,7 @@ func (r *rows) declType(index int) string { } func (r *rows) ColumnTypeDatabaseTypeName(index int) string { - r.loadTypes() + r.loadColumnMetadata() decltype := r.types[index] if len := len(decltype); len > 0 && decltype[len-1] == ')' { if i := strings.LastIndexByte(decltype, '('); i >= 0 { @@ -647,13 +682,57 @@ func (r *rows) ColumnTypeDatabaseTypeName(index int) string { } func (r *rows) ColumnTypeNullable(index int) (nullable, ok bool) { - r.loadTypes() + r.loadColumnMetadata() if r.nulls[index] { return false, true } return true, false } +func (r *rows) ColumnTypeScanType(index int) (typ reflect.Type) { + r.loadColumnMetadata() + scan := r.scans[index] + + if r.Stmt.Busy() { + // SQLite is dynamically typed and we now have a row. + // Always use the type of the value itself, + // unless the scan type is more specific + // and can scan the actual value. + val := scantype(r.Stmt.ColumnType(index)) + useValType := true + switch { + case scan == _TIME && val != _BLOB && val != _NULL: + t := r.Stmt.ColumnTime(index, r.tmRead) + useValType = t == time.Time{} + case scan == _BOOL && val == _INT: + i := r.Stmt.ColumnInt64(index) + useValType = i != 0 && i != 1 + case scan == _BLOB && val == _NULL: + useValType = false + } + if useValType { + scan = val + } + } + + switch scan { + case _INT: + return reflect.TypeOf(int64(0)) + case _REAL: + return reflect.TypeOf(float64(0)) + case _TEXT: + return reflect.TypeOf("") + case _BLOB: + return reflect.TypeOf([]byte{}) + case _BOOL: + return reflect.TypeOf(false) + case _TIME: + return reflect.TypeOf(time.Time{}) + default: + return reflect.TypeOf((*any)(nil)).Elem() + } +} + func (r *rows) Next(dest []driver.Value) error { old := r.Stmt.Conn().SetInterrupt(r.ctx) defer r.Stmt.Conn().SetInterrupt(old) diff --git a/driver/driver_test.go b/driver/driver_test.go index 97b331c..bef4bf6 100644 --- a/driver/driver_test.go +++ b/driver/driver_test.go @@ -7,6 +7,7 @@ import ( "errors" "math" "net/url" + "reflect" "testing" "time" @@ -365,3 +366,104 @@ func Test_time(t *testing.T) { }) } } + +func Test_ColumnType_ScanType(t *testing.T) { + var ( + INT = reflect.TypeOf(int64(0)) + REAL = reflect.TypeOf(float64(0)) + TEXT = reflect.TypeOf("") + BLOB = reflect.TypeOf([]byte{}) + BOOL = reflect.TypeOf(false) + TIME = reflect.TypeOf(time.Time{}) + ANY = reflect.TypeOf((*any)(nil)).Elem() + ) + + t.Parallel() + tmp := memdb.TestDB(t) + + db, err := sql.Open("sqlite3", tmp) + if err != nil { + t.Fatal(err) + } + defer db.Close() + + _, err = db.Exec(` + CREATE TABLE test ( + col_int INTEGER, + col_real REAL, + col_text TEXT, + col_blob BLOB, + col_bool BOOLEAN, + col_time DATETIME, + col_decimal DECIMAL + ); + INSERT INTO test VALUES + (1, 1, 1, 1, 1, 1, 1), + (2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0), + ('1', '1', '1', '1', '1', '1', '1'), + ('x', 'x', 'x', 'x', 'x', 'x', 'x'), + (x'', x'', x'', x'', x'', x'', x''), + ('2006-01-02T15:04:05Z', '2006-01-02T15:04:05Z', '2006-01-02T15:04:05Z', '2006-01-02T15:04:05Z', + '2006-01-02T15:04:05Z', '2006-01-02T15:04:05Z', '2006-01-02T15:04:05Z'), + (TRUE, TRUE, TRUE, TRUE, TRUE, TRUE, TRUE), + (NULL, NULL, NULL, NULL, NULL, NULL, NULL); + `) + if err != nil { + t.Fatal(err) + } + + rows, err := db.Query(`SELECT * FROM test`) + if err != nil { + t.Fatal(err) + } + defer rows.Close() + + cols, err := rows.ColumnTypes() + if err != nil { + t.Fatal(err) + } + + want := [][]reflect.Type{ + {INT, REAL, TEXT, BLOB, BOOL, TIME, ANY}, + {INT, REAL, TEXT, INT, BOOL, TIME, INT}, + {INT, REAL, TEXT, REAL, INT, TIME, INT}, + {INT, REAL, TEXT, TEXT, BOOL, TIME, INT}, + {TEXT, TEXT, TEXT, TEXT, TEXT, TEXT, TEXT}, + {BLOB, BLOB, BLOB, BLOB, BLOB, BLOB, BLOB}, + {TEXT, TEXT, TEXT, TEXT, TEXT, TIME, TEXT}, + {INT, REAL, TEXT, INT, BOOL, TIME, INT}, + {ANY, ANY, ANY, BLOB, ANY, ANY, ANY}, + } + for j, c := range cols { + got := c.ScanType() + if got != want[0][j] { + t.Errorf("want %v, got %v, at column %d", want[0][j], got, j) + } + } + + dest := make([]any, len(cols)) + for i := 1; rows.Next(); i++ { + cols, err := rows.ColumnTypes() + if err != nil { + t.Fatal(err) + } + + for j, c := range cols { + got := c.ScanType() + if got != want[i][j] { + t.Errorf("want %v, got %v, at row %d column %d", want[i][j], got, i, j) + } + dest[j] = reflect.New(got).Interface() + } + + err = rows.Scan(dest...) + if err != nil { + t.Error(err) + } + } + + err = rows.Err() + if err != nil { + t.Fatal(err) + } +} diff --git a/driver/time.go b/driver/time.go index 630a5b1..b3ebdd2 100644 --- a/driver/time.go +++ b/driver/time.go @@ -1,8 +1,6 @@ package driver -import ( - "time" -) +import "time" // Convert a string in [time.RFC3339Nano] format into a [time.Time] // if it roundtrips back to the same string.