mirror of
https://github.com/ncruces/go-sqlite3.git
synced 2026-01-12 05:59:14 +00:00
Driver improvements.
This commit is contained in:
178
driver/driver.go
178
driver/driver.go
@@ -49,7 +49,7 @@ type sqlite struct{}
|
||||
|
||||
func (sqlite) Open(name string) (_ driver.Conn, err error) {
|
||||
var c conn
|
||||
c.conn, err = sqlite3.Open(name)
|
||||
c.Conn, err = sqlite3.Open(name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -73,7 +73,7 @@ func (sqlite) Open(name string) (_ driver.Conn, err error) {
|
||||
}
|
||||
}
|
||||
if len(pragmas) == 0 {
|
||||
err := c.conn.Exec(`
|
||||
err := c.Conn.Exec(`
|
||||
PRAGMA busy_timeout=60000;
|
||||
PRAGMA locking_mode=normal;
|
||||
`)
|
||||
@@ -83,7 +83,7 @@ func (sqlite) Open(name string) (_ driver.Conn, err error) {
|
||||
}
|
||||
c.reusable = true
|
||||
} else {
|
||||
s, _, err := c.conn.Prepare(`
|
||||
s, _, err := c.Conn.Prepare(`
|
||||
SELECT * FROM
|
||||
PRAGMA_locking_mode,
|
||||
PRAGMA_query_only;
|
||||
@@ -102,11 +102,11 @@ func (sqlite) Open(name string) (_ driver.Conn, err error) {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return c, nil
|
||||
return &c, nil
|
||||
}
|
||||
|
||||
type conn struct {
|
||||
conn *sqlite3.Conn
|
||||
*sqlite3.Conn
|
||||
txBegin string
|
||||
txCommit string
|
||||
txRollback string
|
||||
@@ -116,25 +116,21 @@ type conn struct {
|
||||
|
||||
var (
|
||||
// Ensure these interfaces are implemented:
|
||||
_ driver.ExecerContext = conn{}
|
||||
_ driver.ConnBeginTx = conn{}
|
||||
_ driver.Validator = conn{}
|
||||
_ sqlite3.DriverConn = conn{}
|
||||
_ driver.ExecerContext = &conn{}
|
||||
_ driver.ConnBeginTx = &conn{}
|
||||
_ driver.Validator = &conn{}
|
||||
_ sqlite3.DriverConn = &conn{}
|
||||
)
|
||||
|
||||
func (c conn) Close() error {
|
||||
return c.conn.Close()
|
||||
}
|
||||
|
||||
func (c conn) IsValid() bool {
|
||||
func (c *conn) IsValid() bool {
|
||||
return c.reusable
|
||||
}
|
||||
|
||||
func (c conn) Begin() (driver.Tx, error) {
|
||||
func (c *conn) Begin() (driver.Tx, error) {
|
||||
return c.BeginTx(context.Background(), driver.TxOptions{})
|
||||
}
|
||||
|
||||
func (c conn) BeginTx(_ context.Context, opts driver.TxOptions) (driver.Tx, error) {
|
||||
func (c *conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) {
|
||||
txBegin := c.txBegin
|
||||
c.txCommit = `COMMIT`
|
||||
c.txRollback = `ROLLBACK`
|
||||
@@ -158,33 +154,43 @@ func (c conn) BeginTx(_ context.Context, opts driver.TxOptions) (driver.Tx, erro
|
||||
break
|
||||
}
|
||||
|
||||
err := c.conn.Exec(txBegin)
|
||||
old := c.Conn.SetInterrupt(ctx)
|
||||
defer c.Conn.SetInterrupt(old)
|
||||
|
||||
err := c.Conn.Exec(txBegin)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func (c conn) Commit() error {
|
||||
err := c.conn.Exec(c.txCommit)
|
||||
if err != nil && !c.conn.GetAutocommit() {
|
||||
func (c *conn) Commit() error {
|
||||
err := c.Conn.Exec(c.txCommit)
|
||||
if err != nil && !c.GetAutocommit() {
|
||||
c.Rollback()
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (c conn) Rollback() error {
|
||||
return c.conn.Exec(c.txRollback)
|
||||
func (c *conn) Rollback() error {
|
||||
return c.Conn.Exec(c.txRollback)
|
||||
}
|
||||
|
||||
func (c conn) Prepare(query string) (driver.Stmt, error) {
|
||||
s, tail, err := c.conn.Prepare(query)
|
||||
func (c *conn) Prepare(query string) (driver.Stmt, error) {
|
||||
return c.PrepareContext(context.Background(), query)
|
||||
}
|
||||
|
||||
func (c *conn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) {
|
||||
old := c.Conn.SetInterrupt(ctx)
|
||||
defer c.Conn.SetInterrupt(old)
|
||||
|
||||
s, tail, err := c.Conn.Prepare(query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if tail != "" {
|
||||
// Check if the tail contains any SQL.
|
||||
st, _, err := c.conn.Prepare(tail)
|
||||
st, _, err := c.Conn.Prepare(tail)
|
||||
if err != nil {
|
||||
s.Close()
|
||||
return nil, err
|
||||
@@ -195,62 +201,46 @@ func (c conn) Prepare(query string) (driver.Stmt, error) {
|
||||
return nil, util.TailErr
|
||||
}
|
||||
}
|
||||
return stmt{s, c.conn}, nil
|
||||
return &stmt{s, c.Conn}, nil
|
||||
}
|
||||
|
||||
func (c conn) PrepareContext(_ context.Context, query string) (driver.Stmt, error) {
|
||||
return c.Prepare(query)
|
||||
}
|
||||
|
||||
func (c conn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
|
||||
func (c *conn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
|
||||
if len(args) != 0 {
|
||||
// Slow path.
|
||||
return nil, driver.ErrSkip
|
||||
}
|
||||
|
||||
old := c.conn.SetInterrupt(ctx)
|
||||
defer c.conn.SetInterrupt(old)
|
||||
old := c.Conn.SetInterrupt(ctx)
|
||||
defer c.Conn.SetInterrupt(old)
|
||||
|
||||
err := c.conn.Exec(query)
|
||||
err := c.Conn.Exec(query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return newResult(c.conn), nil
|
||||
}
|
||||
|
||||
func (c conn) Savepoint() sqlite3.Savepoint {
|
||||
return c.conn.Savepoint()
|
||||
}
|
||||
|
||||
func (c conn) OpenBlob(db, table, column string, row int64, write bool) (*sqlite3.Blob, error) {
|
||||
return c.conn.OpenBlob(db, table, column, row, write)
|
||||
}
|
||||
|
||||
func (c conn) SetInterrupt(ctx context.Context) (old context.Context) {
|
||||
return c.conn.SetInterrupt(ctx)
|
||||
return newResult(c.Conn), nil
|
||||
}
|
||||
|
||||
type stmt struct {
|
||||
stmt *sqlite3.Stmt
|
||||
conn *sqlite3.Conn
|
||||
Stmt *sqlite3.Stmt
|
||||
Conn *sqlite3.Conn
|
||||
}
|
||||
|
||||
var (
|
||||
// Ensure these interfaces are implemented:
|
||||
_ driver.StmtExecContext = stmt{}
|
||||
_ driver.StmtQueryContext = stmt{}
|
||||
_ driver.NamedValueChecker = stmt{}
|
||||
_ driver.StmtExecContext = &stmt{}
|
||||
_ driver.StmtQueryContext = &stmt{}
|
||||
_ driver.NamedValueChecker = &stmt{}
|
||||
)
|
||||
|
||||
func (s stmt) Close() error {
|
||||
return s.stmt.Close()
|
||||
func (s *stmt) Close() error {
|
||||
return s.Stmt.Close()
|
||||
}
|
||||
|
||||
func (s stmt) NumInput() int {
|
||||
n := s.stmt.BindCount()
|
||||
func (s *stmt) NumInput() int {
|
||||
n := s.Stmt.BindCount()
|
||||
for i := 1; i <= n; i++ {
|
||||
if s.stmt.BindName(i) != "" {
|
||||
if s.Stmt.BindName(i) != "" {
|
||||
return -1
|
||||
}
|
||||
}
|
||||
@@ -258,16 +248,16 @@ func (s stmt) NumInput() int {
|
||||
}
|
||||
|
||||
// Deprecated: use ExecContext instead.
|
||||
func (s stmt) Exec(args []driver.Value) (driver.Result, error) {
|
||||
func (s *stmt) Exec(args []driver.Value) (driver.Result, error) {
|
||||
return s.ExecContext(context.Background(), namedValues(args))
|
||||
}
|
||||
|
||||
// Deprecated: use QueryContext instead.
|
||||
func (s stmt) Query(args []driver.Value) (driver.Rows, error) {
|
||||
func (s *stmt) Query(args []driver.Value) (driver.Rows, error) {
|
||||
return s.QueryContext(context.Background(), namedValues(args))
|
||||
}
|
||||
|
||||
func (s stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
|
||||
func (s *stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
|
||||
// Use QueryContext to setup bindings.
|
||||
// No need to close rows: that simply resets the statement, exec does the same.
|
||||
_, err := s.QueryContext(ctx, args)
|
||||
@@ -275,16 +265,16 @@ func (s stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = s.stmt.Exec()
|
||||
err = s.Stmt.Exec()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return newResult(s.conn), nil
|
||||
return newResult(s.Conn), nil
|
||||
}
|
||||
|
||||
func (s stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
|
||||
err := s.stmt.ClearBindings()
|
||||
func (s *stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
|
||||
err := s.Stmt.ClearBindings()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -296,7 +286,7 @@ func (s stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (drive
|
||||
ids = append(ids, arg.Ordinal)
|
||||
} else {
|
||||
for _, prefix := range []string{":", "@", "$"} {
|
||||
if id := s.stmt.BindIndex(prefix + arg.Name); id != 0 {
|
||||
if id := s.Stmt.BindIndex(prefix + arg.Name); id != 0 {
|
||||
ids = append(ids, id)
|
||||
}
|
||||
}
|
||||
@@ -305,23 +295,23 @@ func (s stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (drive
|
||||
for _, id := range ids {
|
||||
switch a := arg.Value.(type) {
|
||||
case bool:
|
||||
err = s.stmt.BindBool(id, a)
|
||||
err = s.Stmt.BindBool(id, a)
|
||||
case int:
|
||||
err = s.stmt.BindInt(id, a)
|
||||
err = s.Stmt.BindInt(id, a)
|
||||
case int64:
|
||||
err = s.stmt.BindInt64(id, a)
|
||||
err = s.Stmt.BindInt64(id, a)
|
||||
case float64:
|
||||
err = s.stmt.BindFloat(id, a)
|
||||
err = s.Stmt.BindFloat(id, a)
|
||||
case string:
|
||||
err = s.stmt.BindText(id, a)
|
||||
err = s.Stmt.BindText(id, a)
|
||||
case []byte:
|
||||
err = s.stmt.BindBlob(id, a)
|
||||
err = s.Stmt.BindBlob(id, a)
|
||||
case sqlite3.ZeroBlob:
|
||||
err = s.stmt.BindZeroBlob(id, int64(a))
|
||||
err = s.Stmt.BindZeroBlob(id, int64(a))
|
||||
case time.Time:
|
||||
err = s.stmt.BindTime(id, a, sqlite3.TimeFormatDefault)
|
||||
err = s.Stmt.BindTime(id, a, sqlite3.TimeFormatDefault)
|
||||
case nil:
|
||||
err = s.stmt.BindNull(id)
|
||||
err = s.Stmt.BindNull(id)
|
||||
default:
|
||||
panic(util.AssertErr())
|
||||
}
|
||||
@@ -331,10 +321,10 @@ func (s stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (drive
|
||||
}
|
||||
}
|
||||
|
||||
return rows{ctx, s.stmt, s.conn}, nil
|
||||
return &rows{ctx, s.Stmt, s.Conn}, nil
|
||||
}
|
||||
|
||||
func (s stmt) CheckNamedValue(arg *driver.NamedValue) error {
|
||||
func (s *stmt) CheckNamedValue(arg *driver.NamedValue) error {
|
||||
switch arg.Value.(type) {
|
||||
case bool, int, int64, float64, string, []byte,
|
||||
sqlite3.ZeroBlob, time.Time, nil:
|
||||
@@ -377,44 +367,44 @@ func (r resultRowsAffected) RowsAffected() (int64, error) {
|
||||
|
||||
type rows struct {
|
||||
ctx context.Context
|
||||
stmt *sqlite3.Stmt
|
||||
conn *sqlite3.Conn
|
||||
Stmt *sqlite3.Stmt
|
||||
Conn *sqlite3.Conn
|
||||
}
|
||||
|
||||
func (r rows) Close() error {
|
||||
return r.stmt.Reset()
|
||||
func (r *rows) Close() error {
|
||||
return r.Stmt.Reset()
|
||||
}
|
||||
|
||||
func (r rows) Columns() []string {
|
||||
count := r.stmt.ColumnCount()
|
||||
func (r *rows) Columns() []string {
|
||||
count := r.Stmt.ColumnCount()
|
||||
columns := make([]string, count)
|
||||
for i := range columns {
|
||||
columns[i] = r.stmt.ColumnName(i)
|
||||
columns[i] = r.Stmt.ColumnName(i)
|
||||
}
|
||||
return columns
|
||||
}
|
||||
|
||||
func (r rows) Next(dest []driver.Value) error {
|
||||
old := r.conn.SetInterrupt(r.ctx)
|
||||
defer r.conn.SetInterrupt(old)
|
||||
func (r *rows) Next(dest []driver.Value) error {
|
||||
old := r.Conn.SetInterrupt(r.ctx)
|
||||
defer r.Conn.SetInterrupt(old)
|
||||
|
||||
if !r.stmt.Step() {
|
||||
if err := r.stmt.Err(); err != nil {
|
||||
if !r.Stmt.Step() {
|
||||
if err := r.Stmt.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
return io.EOF
|
||||
}
|
||||
|
||||
for i := range dest {
|
||||
switch r.stmt.ColumnType(i) {
|
||||
switch r.Stmt.ColumnType(i) {
|
||||
case sqlite3.INTEGER:
|
||||
dest[i] = r.stmt.ColumnInt64(i)
|
||||
dest[i] = r.Stmt.ColumnInt64(i)
|
||||
case sqlite3.FLOAT:
|
||||
dest[i] = r.stmt.ColumnFloat(i)
|
||||
dest[i] = r.Stmt.ColumnFloat(i)
|
||||
case sqlite3.BLOB:
|
||||
dest[i] = r.stmt.ColumnRawBlob(i)
|
||||
dest[i] = r.Stmt.ColumnRawBlob(i)
|
||||
case sqlite3.TEXT:
|
||||
dest[i] = stringOrTime(r.stmt.ColumnRawText(i))
|
||||
dest[i] = stringOrTime(r.Stmt.ColumnRawText(i))
|
||||
case sqlite3.NULL:
|
||||
if buf, ok := dest[i].([]byte); ok {
|
||||
dest[i] = buf[0:0]
|
||||
@@ -426,5 +416,5 @@ func (r rows) Next(dest []driver.Value) error {
|
||||
}
|
||||
}
|
||||
|
||||
return r.stmt.Err()
|
||||
return r.Stmt.Err()
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user