Driver time formatting.

This commit is contained in:
Nuno Cruces
2023-12-07 03:09:37 -08:00
parent 089a0c0670
commit 7c820ede3c
3 changed files with 89 additions and 28 deletions

View File

@@ -12,6 +12,18 @@
//
// sql.Open("sqlite3", "file:demo.db?_txlock=immediate")
//
// Possible values are: "deferred", "immediate", "exclusive".
// A [read-only] transaction is always "deferred", regardless of "_txlock".
//
// The time encoding/decoding format can be specified using "_timefmt":
//
// sql.Open("sqlite3", "file:demo.db?_timefmt=sqlite")
//
// Possible values are: "auto" (the default), "sqlite", "rfc3339";
// "auto" encodes as RFC 3339 and decodes any [format] supported by SQLite;
// "sqlite" encodes as SQLite and decodes any [format] supported by SQLite;
// "rfc3339" encodes and decodes RFC 3339 only.
//
// [PRAGMA] statements can be specified using "_pragma":
//
// sql.Open("sqlite3", "file:demo.db?_pragma=busy_timeout(10000)")
@@ -23,7 +35,9 @@
//
// [URI]: https://sqlite.org/uri.html
// [PRAGMA]: https://sqlite.org/pragma.html
// [format]: https://sqlite.org/lang_datefunc.html#time_values
// [TRANSACTION]: https://sqlite.org/lang_transaction.html#deferred_immediate_and_exclusive_transactions
// [read-only]: https://pkg.go.dev/database/sql#TxOptions
package driver
import (
@@ -43,7 +57,7 @@ import (
// This variable can be replaced with -ldflags:
//
// go build -ldflags="-X github.com/ncruces/go-sqlite3.driverName=sqlite"
// go build -ldflags="-X github.com/ncruces/go-sqlite3/driver.driverName=sqlite"
var driverName = "sqlite3"
func init() {
@@ -81,23 +95,52 @@ func (sqlite) OpenConnector(name string) (driver.Connector, error) {
func newConnector(name string, init func(*sqlite3.Conn) error) (*connector, error) {
c := connector{name: name, init: init}
var txlock, timefmt string
if strings.HasPrefix(name, "file:") {
if _, after, ok := strings.Cut(name, "?"); ok {
query, err := url.ParseQuery(after)
if err != nil {
return nil, err
}
c.txlock = query.Get("_txlock")
c.pragmas = len(query["_pragma"]) > 0
txlock = query.Get("_txlock")
timefmt = query.Get("_timefmt")
c.pragmas = query.Has("_pragma")
}
}
switch txlock {
case "":
c.txBegin = "BEGIN"
case "deferred", "immediate", "exclusive":
c.txBegin = "BEGIN " + txlock
default:
return nil, fmt.Errorf("sqlite3: invalid _txlock: %s", txlock)
}
switch timefmt {
case "":
c.tmRead = sqlite3.TimeFormatAuto
c.tmWrite = sqlite3.TimeFormatDefault
case "sqlite":
c.tmRead = sqlite3.TimeFormatAuto
c.tmWrite = sqlite3.TimeFormat3
case "rfc3339":
c.tmRead = sqlite3.TimeFormatDefault
c.tmWrite = sqlite3.TimeFormatDefault
default:
c.tmRead = sqlite3.TimeFormat(timefmt)
c.tmWrite = sqlite3.TimeFormat(timefmt)
}
return &c, nil
}
type connector struct {
init func(*sqlite3.Conn) error
name string
txlock string
txBegin string
tmRead sqlite3.TimeFormat
tmWrite sqlite3.TimeFormat
pragmas bool
}
@@ -106,7 +149,12 @@ func (n *connector) Driver() driver.Driver {
}
func (n *connector) Connect(ctx context.Context) (_ driver.Conn, err error) {
var c conn
c := &conn{
txBegin: n.txBegin,
tmRead: n.tmRead,
tmWrite: n.tmWrite,
}
c.Conn, err = sqlite3.Open(n.name)
if err != nil {
return nil, err
@@ -120,14 +168,6 @@ func (n *connector) Connect(ctx context.Context) (_ driver.Conn, err error) {
old := c.Conn.SetInterrupt(ctx)
defer c.Conn.SetInterrupt(old)
switch n.txlock {
case "":
c.txBegin = "BEGIN"
case "deferred", "immediate", "exclusive":
c.txBegin = "BEGIN " + n.txlock
default:
return nil, fmt.Errorf("sqlite3: invalid _txlock: %s", n.txlock)
}
if !n.pragmas {
err = c.Conn.Exec(`PRAGMA busy_timeout=60000`)
if err != nil {
@@ -155,7 +195,7 @@ func (n *connector) Connect(ctx context.Context) (_ driver.Conn, err error) {
return nil, err
}
}
return &c, nil
return c, nil
}
type conn struct {
@@ -163,6 +203,8 @@ type conn struct {
txBegin string
txCommit string
txRollback string
tmRead sqlite3.TimeFormat
tmWrite sqlite3.TimeFormat
readOnly byte
}
@@ -250,7 +292,7 @@ func (c *conn) PrepareContext(ctx context.Context, query string) (driver.Stmt, e
s.Close()
return nil, util.TailErr
}
return &stmt{s}, nil
return &stmt{Stmt: s, tmRead: c.tmRead, tmWrite: c.tmWrite}, nil
}
func (c *conn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
@@ -261,7 +303,7 @@ func (c *conn) ExecContext(ctx context.Context, query string, args []driver.Name
if savept, ok := ctx.(*saveptCtx); ok {
// Called from driver.Savepoint.
savept.Savepoint = c.Savepoint()
savept.Savepoint = c.Conn.Savepoint()
return resultRowsAffected(0), nil
}
@@ -282,6 +324,8 @@ func (*conn) CheckNamedValue(arg *driver.NamedValue) error {
type stmt struct {
*sqlite3.Stmt
tmWrite sqlite3.TimeFormat
tmRead sqlite3.TimeFormat
}
var (
@@ -333,7 +377,7 @@ func (s *stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driv
if err != nil {
return nil, err
}
return &rows{ctx, s.Stmt}, nil
return &rows{s, ctx}, nil
}
func (s *stmt) setupBindings(args []driver.NamedValue) error {
@@ -372,7 +416,7 @@ func (s *stmt) setupBindings(args []driver.NamedValue) error {
case sqlite3.ZeroBlob:
err = s.Stmt.BindZeroBlob(id, int64(a))
case time.Time:
err = s.Stmt.BindTime(id, a, sqlite3.TimeFormatDefault)
err = s.Stmt.BindTime(id, a, s.tmWrite)
case interface{ Pointer() any }:
err = s.Stmt.BindPointer(id, a.Pointer())
case interface{ JSON() any }:
@@ -435,8 +479,8 @@ func (r resultRowsAffected) RowsAffected() (int64, error) {
}
type rows struct {
ctx context.Context
Stmt *sqlite3.Stmt
*stmt
ctx context.Context
}
func (r *rows) Close() error {
@@ -475,6 +519,10 @@ func (r *rows) Next(dest []driver.Value) error {
}
for i := range dest {
if t, ok := r.decodeTime(i); ok {
dest[i] = t
continue
}
switch r.Stmt.ColumnType(i) {
case sqlite3.INTEGER:
dest[i] = r.Stmt.ColumnInt64(i)
@@ -493,3 +541,22 @@ func (r *rows) Next(dest []driver.Value) error {
return r.Stmt.Err()
}
func (s *stmt) decodeTime(i int) (_ time.Time, _ bool) {
if s.tmRead == "" {
return
}
switch s.Stmt.ColumnType(i) {
case sqlite3.INTEGER, sqlite3.FLOAT, sqlite3.TEXT:
// maybe
default:
return
}
switch strings.ToUpper(s.Stmt.ColumnDeclType(i)) {
case "DATE", "TIME", "DATETIME", "TIMESTAMP":
// maybe
default:
return
}
return s.Stmt.ColumnTime(i, s.tmRead), s.Stmt.Err() == nil
}

View File

@@ -114,13 +114,7 @@ func Test_Open_txLock(t *testing.T) {
func Test_Open_txLock_invalid(t *testing.T) {
t.Parallel()
db, err := sql.Open("sqlite3", "file::memory:?_txlock=xclusive")
if err != nil {
t.Fatal(err)
}
defer db.Close()
_, err = db.Conn(context.TODO())
_, err := sql.Open("sqlite3", "file::memory:?_txlock=xclusive")
if err == nil {
t.Fatal("want error")
}

View File

@@ -20,7 +20,7 @@ type TimeFormat string
// TimeFormats recognized by SQLite to encode/decode time values.
//
// https://sqlite.org/lang_datefunc.html
// https://sqlite.org/lang_datefunc.html#time_values
const (
TimeFormatDefault TimeFormat = "" // time.RFC3339Nano