From dedec8682b9d965f878466d7b855b4099a24e52f Mon Sep 17 00:00:00 2001 From: Nuno Cruces Date: Tue, 30 May 2023 13:39:34 +0100 Subject: [PATCH] Driver improvements. --- conn.go | 12 ++-- driver/driver.go | 178 ++++++++++++++++++++++------------------------- 2 files changed, 92 insertions(+), 98 deletions(-) diff --git a/conn.go b/conn.go index 6be6590..be10c71 100644 --- a/conn.go +++ b/conn.go @@ -325,17 +325,21 @@ func (c *Conn) error(rc uint64, sql ...string) error { // DriverConn is implemented by the SQLite [database/sql] driver connection. // // It can be used to access advanced SQLite features like -// [savepoints] and [incremental BLOB I/O]. +// [savepoints], [online backup] and [incremental BLOB I/O]. // // [savepoints]: https://www.sqlite.org/lang_savepoint.html +// [online backup]: https://www.sqlite.org/backup.html // [incremental BLOB I/O]: https://www.sqlite.org/c3ref/blob_open.html type DriverConn interface { + driver.Conn driver.ConnBeginTx driver.ExecerContext driver.ConnPrepareContext - Savepoint() Savepoint - OpenBlob(db, table, column string, row int64, write bool) (*Blob, error) - SetInterrupt(ctx context.Context) (old context.Context) + + Savepoint() Savepoint + Backup(srcDB, dstURI string) error + Restore(dstDB, srcURI string) error + OpenBlob(db, table, column string, row int64, write bool) (*Blob, error) } diff --git a/driver/driver.go b/driver/driver.go index a3c2bd6..3d85214 100644 --- a/driver/driver.go +++ b/driver/driver.go @@ -49,7 +49,7 @@ type sqlite struct{} func (sqlite) Open(name string) (_ driver.Conn, err error) { var c conn - c.conn, err = sqlite3.Open(name) + c.Conn, err = sqlite3.Open(name) if err != nil { return nil, err } @@ -73,7 +73,7 @@ func (sqlite) Open(name string) (_ driver.Conn, err error) { } } if len(pragmas) == 0 { - err := c.conn.Exec(` + err := c.Conn.Exec(` PRAGMA busy_timeout=60000; PRAGMA locking_mode=normal; `) @@ -83,7 +83,7 @@ func (sqlite) Open(name string) (_ driver.Conn, err error) { } c.reusable = true } else { - s, _, err := c.conn.Prepare(` + s, _, err := c.Conn.Prepare(` SELECT * FROM PRAGMA_locking_mode, PRAGMA_query_only; @@ -102,11 +102,11 @@ func (sqlite) Open(name string) (_ driver.Conn, err error) { return nil, err } } - return c, nil + return &c, nil } type conn struct { - conn *sqlite3.Conn + *sqlite3.Conn txBegin string txCommit string txRollback string @@ -116,25 +116,21 @@ type conn struct { var ( // Ensure these interfaces are implemented: - _ driver.ExecerContext = conn{} - _ driver.ConnBeginTx = conn{} - _ driver.Validator = conn{} - _ sqlite3.DriverConn = conn{} + _ driver.ExecerContext = &conn{} + _ driver.ConnBeginTx = &conn{} + _ driver.Validator = &conn{} + _ sqlite3.DriverConn = &conn{} ) -func (c conn) Close() error { - return c.conn.Close() -} - -func (c conn) IsValid() bool { +func (c *conn) IsValid() bool { return c.reusable } -func (c conn) Begin() (driver.Tx, error) { +func (c *conn) Begin() (driver.Tx, error) { return c.BeginTx(context.Background(), driver.TxOptions{}) } -func (c conn) BeginTx(_ context.Context, opts driver.TxOptions) (driver.Tx, error) { +func (c *conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) { txBegin := c.txBegin c.txCommit = `COMMIT` c.txRollback = `ROLLBACK` @@ -158,33 +154,43 @@ func (c conn) BeginTx(_ context.Context, opts driver.TxOptions) (driver.Tx, erro break } - err := c.conn.Exec(txBegin) + old := c.Conn.SetInterrupt(ctx) + defer c.Conn.SetInterrupt(old) + + err := c.Conn.Exec(txBegin) if err != nil { return nil, err } return c, nil } -func (c conn) Commit() error { - err := c.conn.Exec(c.txCommit) - if err != nil && !c.conn.GetAutocommit() { +func (c *conn) Commit() error { + err := c.Conn.Exec(c.txCommit) + if err != nil && !c.GetAutocommit() { c.Rollback() } return err } -func (c conn) Rollback() error { - return c.conn.Exec(c.txRollback) +func (c *conn) Rollback() error { + return c.Conn.Exec(c.txRollback) } -func (c conn) Prepare(query string) (driver.Stmt, error) { - s, tail, err := c.conn.Prepare(query) +func (c *conn) Prepare(query string) (driver.Stmt, error) { + return c.PrepareContext(context.Background(), query) +} + +func (c *conn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) { + old := c.Conn.SetInterrupt(ctx) + defer c.Conn.SetInterrupt(old) + + s, tail, err := c.Conn.Prepare(query) if err != nil { return nil, err } if tail != "" { // Check if the tail contains any SQL. - st, _, err := c.conn.Prepare(tail) + st, _, err := c.Conn.Prepare(tail) if err != nil { s.Close() return nil, err @@ -195,62 +201,46 @@ func (c conn) Prepare(query string) (driver.Stmt, error) { return nil, util.TailErr } } - return stmt{s, c.conn}, nil + return &stmt{s, c.Conn}, nil } -func (c conn) PrepareContext(_ context.Context, query string) (driver.Stmt, error) { - return c.Prepare(query) -} - -func (c conn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { +func (c *conn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { if len(args) != 0 { // Slow path. return nil, driver.ErrSkip } - old := c.conn.SetInterrupt(ctx) - defer c.conn.SetInterrupt(old) + old := c.Conn.SetInterrupt(ctx) + defer c.Conn.SetInterrupt(old) - err := c.conn.Exec(query) + err := c.Conn.Exec(query) if err != nil { return nil, err } - return newResult(c.conn), nil -} - -func (c conn) Savepoint() sqlite3.Savepoint { - return c.conn.Savepoint() -} - -func (c conn) OpenBlob(db, table, column string, row int64, write bool) (*sqlite3.Blob, error) { - return c.conn.OpenBlob(db, table, column, row, write) -} - -func (c conn) SetInterrupt(ctx context.Context) (old context.Context) { - return c.conn.SetInterrupt(ctx) + return newResult(c.Conn), nil } type stmt struct { - stmt *sqlite3.Stmt - conn *sqlite3.Conn + Stmt *sqlite3.Stmt + Conn *sqlite3.Conn } var ( // Ensure these interfaces are implemented: - _ driver.StmtExecContext = stmt{} - _ driver.StmtQueryContext = stmt{} - _ driver.NamedValueChecker = stmt{} + _ driver.StmtExecContext = &stmt{} + _ driver.StmtQueryContext = &stmt{} + _ driver.NamedValueChecker = &stmt{} ) -func (s stmt) Close() error { - return s.stmt.Close() +func (s *stmt) Close() error { + return s.Stmt.Close() } -func (s stmt) NumInput() int { - n := s.stmt.BindCount() +func (s *stmt) NumInput() int { + n := s.Stmt.BindCount() for i := 1; i <= n; i++ { - if s.stmt.BindName(i) != "" { + if s.Stmt.BindName(i) != "" { return -1 } } @@ -258,16 +248,16 @@ func (s stmt) NumInput() int { } // Deprecated: use ExecContext instead. -func (s stmt) Exec(args []driver.Value) (driver.Result, error) { +func (s *stmt) Exec(args []driver.Value) (driver.Result, error) { return s.ExecContext(context.Background(), namedValues(args)) } // Deprecated: use QueryContext instead. -func (s stmt) Query(args []driver.Value) (driver.Rows, error) { +func (s *stmt) Query(args []driver.Value) (driver.Rows, error) { return s.QueryContext(context.Background(), namedValues(args)) } -func (s stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) { +func (s *stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) { // Use QueryContext to setup bindings. // No need to close rows: that simply resets the statement, exec does the same. _, err := s.QueryContext(ctx, args) @@ -275,16 +265,16 @@ func (s stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver return nil, err } - err = s.stmt.Exec() + err = s.Stmt.Exec() if err != nil { return nil, err } - return newResult(s.conn), nil + return newResult(s.Conn), nil } -func (s stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) { - err := s.stmt.ClearBindings() +func (s *stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) { + err := s.Stmt.ClearBindings() if err != nil { return nil, err } @@ -296,7 +286,7 @@ func (s stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (drive ids = append(ids, arg.Ordinal) } else { for _, prefix := range []string{":", "@", "$"} { - if id := s.stmt.BindIndex(prefix + arg.Name); id != 0 { + if id := s.Stmt.BindIndex(prefix + arg.Name); id != 0 { ids = append(ids, id) } } @@ -305,23 +295,23 @@ func (s stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (drive for _, id := range ids { switch a := arg.Value.(type) { case bool: - err = s.stmt.BindBool(id, a) + err = s.Stmt.BindBool(id, a) case int: - err = s.stmt.BindInt(id, a) + err = s.Stmt.BindInt(id, a) case int64: - err = s.stmt.BindInt64(id, a) + err = s.Stmt.BindInt64(id, a) case float64: - err = s.stmt.BindFloat(id, a) + err = s.Stmt.BindFloat(id, a) case string: - err = s.stmt.BindText(id, a) + err = s.Stmt.BindText(id, a) case []byte: - err = s.stmt.BindBlob(id, a) + err = s.Stmt.BindBlob(id, a) case sqlite3.ZeroBlob: - err = s.stmt.BindZeroBlob(id, int64(a)) + err = s.Stmt.BindZeroBlob(id, int64(a)) case time.Time: - err = s.stmt.BindTime(id, a, sqlite3.TimeFormatDefault) + err = s.Stmt.BindTime(id, a, sqlite3.TimeFormatDefault) case nil: - err = s.stmt.BindNull(id) + err = s.Stmt.BindNull(id) default: panic(util.AssertErr()) } @@ -331,10 +321,10 @@ func (s stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (drive } } - return rows{ctx, s.stmt, s.conn}, nil + return &rows{ctx, s.Stmt, s.Conn}, nil } -func (s stmt) CheckNamedValue(arg *driver.NamedValue) error { +func (s *stmt) CheckNamedValue(arg *driver.NamedValue) error { switch arg.Value.(type) { case bool, int, int64, float64, string, []byte, sqlite3.ZeroBlob, time.Time, nil: @@ -377,44 +367,44 @@ func (r resultRowsAffected) RowsAffected() (int64, error) { type rows struct { ctx context.Context - stmt *sqlite3.Stmt - conn *sqlite3.Conn + Stmt *sqlite3.Stmt + Conn *sqlite3.Conn } -func (r rows) Close() error { - return r.stmt.Reset() +func (r *rows) Close() error { + return r.Stmt.Reset() } -func (r rows) Columns() []string { - count := r.stmt.ColumnCount() +func (r *rows) Columns() []string { + count := r.Stmt.ColumnCount() columns := make([]string, count) for i := range columns { - columns[i] = r.stmt.ColumnName(i) + columns[i] = r.Stmt.ColumnName(i) } return columns } -func (r rows) Next(dest []driver.Value) error { - old := r.conn.SetInterrupt(r.ctx) - defer r.conn.SetInterrupt(old) +func (r *rows) Next(dest []driver.Value) error { + old := r.Conn.SetInterrupt(r.ctx) + defer r.Conn.SetInterrupt(old) - if !r.stmt.Step() { - if err := r.stmt.Err(); err != nil { + if !r.Stmt.Step() { + if err := r.Stmt.Err(); err != nil { return err } return io.EOF } for i := range dest { - switch r.stmt.ColumnType(i) { + switch r.Stmt.ColumnType(i) { case sqlite3.INTEGER: - dest[i] = r.stmt.ColumnInt64(i) + dest[i] = r.Stmt.ColumnInt64(i) case sqlite3.FLOAT: - dest[i] = r.stmt.ColumnFloat(i) + dest[i] = r.Stmt.ColumnFloat(i) case sqlite3.BLOB: - dest[i] = r.stmt.ColumnRawBlob(i) + dest[i] = r.Stmt.ColumnRawBlob(i) case sqlite3.TEXT: - dest[i] = stringOrTime(r.stmt.ColumnRawText(i)) + dest[i] = stringOrTime(r.Stmt.ColumnRawText(i)) case sqlite3.NULL: if buf, ok := dest[i].([]byte); ok { dest[i] = buf[0:0] @@ -426,5 +416,5 @@ func (r rows) Next(dest []driver.Value) error { } } - return r.stmt.Err() + return r.Stmt.Err() }