diff --git a/driver/driver.go b/driver/driver.go index 9848d81..cd464f8 100644 --- a/driver/driver.go +++ b/driver/driver.go @@ -118,6 +118,10 @@ func (s stmt) Query(args []driver.Value) (driver.Rows, error) { err = s.stmt.BindBlob(i+1, a) case time.Time: err = s.stmt.BindText(i+1, a.Format(time.RFC3339Nano)) + case nil: + err = s.stmt.BindNull(i + 1) + default: + panic(assertErr) } if err != nil { return nil, err @@ -167,9 +171,13 @@ func (r rows) Next(dest []driver.Value) error { case sqlite3.FLOAT: dest[i] = r.s.ColumnFloat(i) case sqlite3.TEXT: - dest[i] = r.s.ColumnText(i) + dest[i] = maybeDate(r.s.ColumnText(i)) case sqlite3.BLOB: dest[i] = r.s.ColumnBlob(i, nil) + case sqlite3.NULL: + dest[i] = nil + default: + panic(assertErr) } } diff --git a/driver/error.go b/driver/error.go new file mode 100644 index 0000000..cc8b2fd --- /dev/null +++ b/driver/error.go @@ -0,0 +1,7 @@ +package driver + +type errorString string + +func (e errorString) Error() string { return string(e) } + +const assertErr = errorString("sqlite3: assertion failed") diff --git a/driver/time.go b/driver/time.go new file mode 100644 index 0000000..c8f592e --- /dev/null +++ b/driver/time.go @@ -0,0 +1,19 @@ +package driver + +import ( + "database/sql/driver" + "time" +) + +// Convert a string in [time.RFC3339Nano] format into a [time.Time] +// if it roundtrips back to the same string. +// This way times can be persisted to, and recovered from, the database, +// but if a string is needed, [database.sql] will recover the same string. +// TODO: optimize and fuzz test. +func maybeDate(text string) driver.Value { + date, err := time.Parse(time.RFC3339Nano, text) + if err == nil && date.Format(time.RFC3339Nano) == text { + return date + } + return text +}