diff --git a/tests/time_test.go b/tests/time_test.go index 156f355..eb28113 100644 --- a/tests/time_test.go +++ b/tests/time_test.go @@ -6,6 +6,7 @@ import ( "time" "github.com/ncruces/go-sqlite3" + "github.com/ncruces/go-sqlite3/driver" ) func TestTimeFormat_Encode(t *testing.T) { @@ -119,6 +120,38 @@ func TestTimeFormat_Decode(t *testing.T) { } } +func TestTimeFormat_Scanner(t *testing.T) { + t.Parallel() + + db, err := driver.Open(":memory:", nil) + if err != nil { + t.Fatal(err) + } + defer db.Close() + + _, err = db.Exec( + `CREATE TABLE IF NOT EXISTS test (col)`) + if err != nil { + t.Fatal(err) + } + + reference := time.Date(2013, 10, 7, 4, 23, 19, 120_000_000, time.FixedZone("", -4*3600)) + + _, err = db.Exec(`INSERT INTO test VALUES (?)`, sqlite3.TimeFormat7TZ.Encode(reference)) + if err != nil { + t.Fatal(err) + } + + var got time.Time + err = db.QueryRow("SELECT * FROM test").Scan(sqlite3.TimeFormatAuto.Scanner(&got)) + if err != nil { + t.Fatal(err) + } + if !got.Equal(reference) { + t.Errorf("got %v, want %v", got, reference) + } +} + func TestDB_timeCollation(t *testing.T) { t.Parallel() diff --git a/time.go b/time.go index 6fddc3e..5449bec 100644 --- a/time.go +++ b/time.go @@ -338,3 +338,19 @@ func (f TimeFormat) parseRelaxed(s string) (time.Time, error) { } return t, nil } + +// Scanner returns a [database/sql.Scanner] that +// decodes a time value into dest using this format. +func (f TimeFormat) Scanner(dest *time.Time) interface{ Scan(any) error } { + return timeScanner{dest, f} +} + +type timeScanner struct { + *time.Time + TimeFormat +} + +func (s timeScanner) Scan(src any) (err error) { + *s.Time, err = s.Decode(src) + return +}