diff --git a/driver/driver.go b/driver/driver.go index ef08aa7..b8bf3f8 100644 --- a/driver/driver.go +++ b/driver/driver.go @@ -12,6 +12,18 @@ // // sql.Open("sqlite3", "file:demo.db?_txlock=immediate") // +// Possible values are: "deferred", "immediate", "exclusive". +// A [read-only] transaction is always "deferred", regardless of "_txlock". +// +// The time encoding/decoding format can be specified using "_timefmt": +// +// sql.Open("sqlite3", "file:demo.db?_timefmt=sqlite") +// +// Possible values are: "auto" (the default), "sqlite", "rfc3339"; +// "auto" encodes as RFC 3339 and decodes any [format] supported by SQLite; +// "sqlite" encodes as SQLite and decodes any [format] supported by SQLite; +// "rfc3339" encodes and decodes RFC 3339 only. +// // [PRAGMA] statements can be specified using "_pragma": // // sql.Open("sqlite3", "file:demo.db?_pragma=busy_timeout(10000)") @@ -23,7 +35,9 @@ // // [URI]: https://sqlite.org/uri.html // [PRAGMA]: https://sqlite.org/pragma.html +// [format]: https://sqlite.org/lang_datefunc.html#time_values // [TRANSACTION]: https://sqlite.org/lang_transaction.html#deferred_immediate_and_exclusive_transactions +// [read-only]: https://pkg.go.dev/database/sql#TxOptions package driver import ( @@ -43,7 +57,7 @@ import ( // This variable can be replaced with -ldflags: // -// go build -ldflags="-X github.com/ncruces/go-sqlite3.driverName=sqlite" +// go build -ldflags="-X github.com/ncruces/go-sqlite3/driver.driverName=sqlite" var driverName = "sqlite3" func init() { @@ -81,23 +95,52 @@ func (sqlite) OpenConnector(name string) (driver.Connector, error) { func newConnector(name string, init func(*sqlite3.Conn) error) (*connector, error) { c := connector{name: name, init: init} + + var txlock, timefmt string if strings.HasPrefix(name, "file:") { if _, after, ok := strings.Cut(name, "?"); ok { query, err := url.ParseQuery(after) if err != nil { return nil, err } - c.txlock = query.Get("_txlock") - c.pragmas = len(query["_pragma"]) > 0 + txlock = query.Get("_txlock") + timefmt = query.Get("_timefmt") + c.pragmas = query.Has("_pragma") } } + + switch txlock { + case "": + c.txBegin = "BEGIN" + case "deferred", "immediate", "exclusive": + c.txBegin = "BEGIN " + txlock + default: + return nil, fmt.Errorf("sqlite3: invalid _txlock: %s", txlock) + } + + switch timefmt { + case "": + c.tmRead = sqlite3.TimeFormatAuto + c.tmWrite = sqlite3.TimeFormatDefault + case "sqlite": + c.tmRead = sqlite3.TimeFormatAuto + c.tmWrite = sqlite3.TimeFormat3 + case "rfc3339": + c.tmRead = sqlite3.TimeFormatDefault + c.tmWrite = sqlite3.TimeFormatDefault + default: + c.tmRead = sqlite3.TimeFormat(timefmt) + c.tmWrite = sqlite3.TimeFormat(timefmt) + } return &c, nil } type connector struct { init func(*sqlite3.Conn) error name string - txlock string + txBegin string + tmRead sqlite3.TimeFormat + tmWrite sqlite3.TimeFormat pragmas bool } @@ -106,7 +149,12 @@ func (n *connector) Driver() driver.Driver { } func (n *connector) Connect(ctx context.Context) (_ driver.Conn, err error) { - var c conn + c := &conn{ + txBegin: n.txBegin, + tmRead: n.tmRead, + tmWrite: n.tmWrite, + } + c.Conn, err = sqlite3.Open(n.name) if err != nil { return nil, err @@ -120,14 +168,6 @@ func (n *connector) Connect(ctx context.Context) (_ driver.Conn, err error) { old := c.Conn.SetInterrupt(ctx) defer c.Conn.SetInterrupt(old) - switch n.txlock { - case "": - c.txBegin = "BEGIN" - case "deferred", "immediate", "exclusive": - c.txBegin = "BEGIN " + n.txlock - default: - return nil, fmt.Errorf("sqlite3: invalid _txlock: %s", n.txlock) - } if !n.pragmas { err = c.Conn.Exec(`PRAGMA busy_timeout=60000`) if err != nil { @@ -155,7 +195,7 @@ func (n *connector) Connect(ctx context.Context) (_ driver.Conn, err error) { return nil, err } } - return &c, nil + return c, nil } type conn struct { @@ -163,6 +203,8 @@ type conn struct { txBegin string txCommit string txRollback string + tmRead sqlite3.TimeFormat + tmWrite sqlite3.TimeFormat readOnly byte } @@ -250,7 +292,7 @@ func (c *conn) PrepareContext(ctx context.Context, query string) (driver.Stmt, e s.Close() return nil, util.TailErr } - return &stmt{s}, nil + return &stmt{Stmt: s, tmRead: c.tmRead, tmWrite: c.tmWrite}, nil } func (c *conn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { @@ -261,7 +303,7 @@ func (c *conn) ExecContext(ctx context.Context, query string, args []driver.Name if savept, ok := ctx.(*saveptCtx); ok { // Called from driver.Savepoint. - savept.Savepoint = c.Savepoint() + savept.Savepoint = c.Conn.Savepoint() return resultRowsAffected(0), nil } @@ -282,6 +324,8 @@ func (*conn) CheckNamedValue(arg *driver.NamedValue) error { type stmt struct { *sqlite3.Stmt + tmWrite sqlite3.TimeFormat + tmRead sqlite3.TimeFormat } var ( @@ -333,7 +377,7 @@ func (s *stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driv if err != nil { return nil, err } - return &rows{ctx, s.Stmt}, nil + return &rows{s, ctx}, nil } func (s *stmt) setupBindings(args []driver.NamedValue) error { @@ -372,7 +416,7 @@ func (s *stmt) setupBindings(args []driver.NamedValue) error { case sqlite3.ZeroBlob: err = s.Stmt.BindZeroBlob(id, int64(a)) case time.Time: - err = s.Stmt.BindTime(id, a, sqlite3.TimeFormatDefault) + err = s.Stmt.BindTime(id, a, s.tmWrite) case interface{ Pointer() any }: err = s.Stmt.BindPointer(id, a.Pointer()) case interface{ JSON() any }: @@ -435,8 +479,8 @@ func (r resultRowsAffected) RowsAffected() (int64, error) { } type rows struct { - ctx context.Context - Stmt *sqlite3.Stmt + *stmt + ctx context.Context } func (r *rows) Close() error { @@ -475,6 +519,10 @@ func (r *rows) Next(dest []driver.Value) error { } for i := range dest { + if t, ok := r.decodeTime(i); ok { + dest[i] = t + continue + } switch r.Stmt.ColumnType(i) { case sqlite3.INTEGER: dest[i] = r.Stmt.ColumnInt64(i) @@ -493,3 +541,22 @@ func (r *rows) Next(dest []driver.Value) error { return r.Stmt.Err() } + +func (s *stmt) decodeTime(i int) (_ time.Time, _ bool) { + if s.tmRead == "" { + return + } + switch s.Stmt.ColumnType(i) { + case sqlite3.INTEGER, sqlite3.FLOAT, sqlite3.TEXT: + // maybe + default: + return + } + switch strings.ToUpper(s.Stmt.ColumnDeclType(i)) { + case "DATE", "TIME", "DATETIME", "TIMESTAMP": + // maybe + default: + return + } + return s.Stmt.ColumnTime(i, s.tmRead), s.Stmt.Err() == nil +} diff --git a/driver/driver_test.go b/driver/driver_test.go index cf32f74..2b09930 100644 --- a/driver/driver_test.go +++ b/driver/driver_test.go @@ -114,13 +114,7 @@ func Test_Open_txLock(t *testing.T) { func Test_Open_txLock_invalid(t *testing.T) { t.Parallel() - db, err := sql.Open("sqlite3", "file::memory:?_txlock=xclusive") - if err != nil { - t.Fatal(err) - } - defer db.Close() - - _, err = db.Conn(context.TODO()) + _, err := sql.Open("sqlite3", "file::memory:?_txlock=xclusive") if err == nil { t.Fatal("want error") } diff --git a/time.go b/time.go index e5cc3ed..47ac672 100644 --- a/time.go +++ b/time.go @@ -20,7 +20,7 @@ type TimeFormat string // TimeFormats recognized by SQLite to encode/decode time values. // -// https://sqlite.org/lang_datefunc.html +// https://sqlite.org/lang_datefunc.html#time_values const ( TimeFormatDefault TimeFormat = "" // time.RFC3339Nano