mirror of
https://github.com/ncruces/go-sqlite3.git
synced 2026-01-12 05:59:14 +00:00
Context cancellation.
This commit is contained in:
111
driver/driver.go
111
driver/driver.go
@@ -2,6 +2,7 @@
|
||||
package driver
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"io"
|
||||
@@ -40,7 +41,6 @@ var (
|
||||
_ driver.Validator = conn{}
|
||||
// _ driver.SessionResetter = conn{}
|
||||
// _ driver.ExecerContext = conn{}
|
||||
// _ driver.QueryerContext = conn{}
|
||||
// _ driver.ConnBeginTx = conn{}
|
||||
)
|
||||
|
||||
@@ -101,9 +101,8 @@ type stmt struct {
|
||||
|
||||
var (
|
||||
// Ensure these interfaces are implemented:
|
||||
// _ driver.StmtExecContext = stmt{}
|
||||
// _ driver.StmtQueryContext = stmt{}
|
||||
_ = stmt{}
|
||||
_ driver.StmtExecContext = stmt{}
|
||||
_ driver.StmtQueryContext = stmt{}
|
||||
)
|
||||
|
||||
func (s stmt) Close() error {
|
||||
@@ -114,8 +113,18 @@ func (s stmt) NumInput() int {
|
||||
return s.stmt.BindCount()
|
||||
}
|
||||
|
||||
// Deprecated: use ExecContext instead.
|
||||
func (s stmt) Exec(args []driver.Value) (driver.Result, error) {
|
||||
_, err := s.Query(args)
|
||||
return s.ExecContext(context.Background(), namedValues(args))
|
||||
}
|
||||
|
||||
// Deprecated: use QueryContext instead.
|
||||
func (s stmt) Query(args []driver.Value) (driver.Rows, error) {
|
||||
return s.QueryContext(context.Background(), namedValues(args))
|
||||
}
|
||||
|
||||
func (s stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
|
||||
_, err := s.QueryContext(ctx, args)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -131,32 +140,51 @@ func (s stmt) Exec(args []driver.Value) (driver.Result, error) {
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s stmt) Query(args []driver.Value) (driver.Rows, error) {
|
||||
var err error
|
||||
for i, arg := range args {
|
||||
switch a := arg.(type) {
|
||||
case bool:
|
||||
err = s.stmt.BindBool(i+1, a)
|
||||
case int64:
|
||||
err = s.stmt.BindInt64(i+1, a)
|
||||
case float64:
|
||||
err = s.stmt.BindFloat(i+1, a)
|
||||
case string:
|
||||
err = s.stmt.BindText(i+1, a)
|
||||
case []byte:
|
||||
err = s.stmt.BindBlob(i+1, a)
|
||||
case time.Time:
|
||||
err = s.stmt.BindText(i+1, a.Format(time.RFC3339Nano))
|
||||
case nil:
|
||||
err = s.stmt.BindNull(i + 1)
|
||||
default:
|
||||
panic(assertErr)
|
||||
func (s stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
|
||||
err := s.stmt.ClearBindings()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var ids [3]int
|
||||
for _, arg := range args {
|
||||
ids := ids[:0]
|
||||
if arg.Name == "" {
|
||||
ids = append(ids, arg.Ordinal)
|
||||
} else {
|
||||
for _, prefix := range []string{":", "@", "$"} {
|
||||
if id := s.stmt.BindIndex(prefix + arg.Name); id != 0 {
|
||||
ids = append(ids, id)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, id := range ids {
|
||||
switch a := arg.Value.(type) {
|
||||
case bool:
|
||||
err = s.stmt.BindBool(id, a)
|
||||
case int64:
|
||||
err = s.stmt.BindInt64(id, a)
|
||||
case float64:
|
||||
err = s.stmt.BindFloat(id, a)
|
||||
case string:
|
||||
err = s.stmt.BindText(id, a)
|
||||
case []byte:
|
||||
err = s.stmt.BindBlob(id, a)
|
||||
case time.Time:
|
||||
err = s.stmt.BindText(id, a.Format(time.RFC3339Nano))
|
||||
case nil:
|
||||
err = s.stmt.BindNull(id)
|
||||
default:
|
||||
panic(assertErr)
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return rows{s.stmt}, nil
|
||||
|
||||
return rows{ctx, s.stmt, s.conn}, nil
|
||||
}
|
||||
|
||||
type result struct{ lastInsertId, rowsAffected int64 }
|
||||
@@ -169,40 +197,47 @@ func (r result) RowsAffected() (int64, error) {
|
||||
return r.rowsAffected, nil
|
||||
}
|
||||
|
||||
type rows struct{ s *sqlite3.Stmt }
|
||||
type rows struct {
|
||||
ctx context.Context
|
||||
stmt *sqlite3.Stmt
|
||||
conn *sqlite3.Conn
|
||||
}
|
||||
|
||||
func (r rows) Close() error {
|
||||
return r.s.Reset()
|
||||
return r.stmt.Reset()
|
||||
}
|
||||
|
||||
func (r rows) Columns() []string {
|
||||
count := r.s.ColumnCount()
|
||||
count := r.stmt.ColumnCount()
|
||||
columns := make([]string, count)
|
||||
for i := range columns {
|
||||
columns[i] = r.s.ColumnName(i)
|
||||
columns[i] = r.stmt.ColumnName(i)
|
||||
}
|
||||
return columns
|
||||
}
|
||||
|
||||
func (r rows) Next(dest []driver.Value) error {
|
||||
if !r.s.Step() {
|
||||
if err := r.s.Err(); err != nil {
|
||||
ch := r.conn.SetInterrupt(r.ctx.Done())
|
||||
defer r.conn.SetInterrupt(ch)
|
||||
|
||||
if !r.stmt.Step() {
|
||||
if err := r.stmt.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
return io.EOF
|
||||
}
|
||||
|
||||
for i := range dest {
|
||||
switch r.s.ColumnType(i) {
|
||||
switch r.stmt.ColumnType(i) {
|
||||
case sqlite3.INTEGER:
|
||||
dest[i] = r.s.ColumnInt64(i)
|
||||
dest[i] = r.stmt.ColumnInt64(i)
|
||||
case sqlite3.FLOAT:
|
||||
dest[i] = r.s.ColumnFloat(i)
|
||||
dest[i] = r.stmt.ColumnFloat(i)
|
||||
case sqlite3.TEXT:
|
||||
dest[i] = maybeDate(r.s.ColumnText(i))
|
||||
dest[i] = maybeDate(r.stmt.ColumnText(i))
|
||||
case sqlite3.BLOB:
|
||||
buf, _ := dest[i].([]byte)
|
||||
dest[i] = r.s.ColumnBlob(i, buf)
|
||||
dest[i] = r.stmt.ColumnBlob(i, buf)
|
||||
case sqlite3.NULL:
|
||||
if buf, ok := dest[i].([]byte); ok {
|
||||
dest[i] = buf[0:0]
|
||||
@@ -214,5 +249,5 @@ func (r rows) Next(dest []driver.Value) error {
|
||||
}
|
||||
}
|
||||
|
||||
return r.s.Err()
|
||||
return r.stmt.Err()
|
||||
}
|
||||
|
||||
14
driver/util.go
Normal file
14
driver/util.go
Normal file
@@ -0,0 +1,14 @@
|
||||
package driver
|
||||
|
||||
import "database/sql/driver"
|
||||
|
||||
func namedValues(args []driver.Value) []driver.NamedValue {
|
||||
named := make([]driver.NamedValue, len(args))
|
||||
for i, v := range args {
|
||||
named[i] = driver.NamedValue{
|
||||
Ordinal: i + 1,
|
||||
Value: v,
|
||||
}
|
||||
}
|
||||
return named
|
||||
}
|
||||
18
driver/util_test.go
Normal file
18
driver/util_test.go
Normal file
@@ -0,0 +1,18 @@
|
||||
package driver
|
||||
|
||||
import (
|
||||
"database/sql/driver"
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func Test_namedValues(t *testing.T) {
|
||||
want := []driver.NamedValue{
|
||||
{Ordinal: 1, Value: true},
|
||||
{Ordinal: 2, Value: false},
|
||||
}
|
||||
got := namedValues([]driver.Value{true, false})
|
||||
if !reflect.DeepEqual(got, want) {
|
||||
t.Errorf("got %v, want %v", got, want)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user