Context cancellation.

This commit is contained in:
Nuno Cruces
2023-02-18 02:16:11 +00:00
parent 4ac2ccf473
commit f50d5df3d0
3 changed files with 105 additions and 38 deletions

View File

@@ -2,6 +2,7 @@
package driver
import (
"context"
"database/sql"
"database/sql/driver"
"io"
@@ -40,7 +41,6 @@ var (
_ driver.Validator = conn{}
// _ driver.SessionResetter = conn{}
// _ driver.ExecerContext = conn{}
// _ driver.QueryerContext = conn{}
// _ driver.ConnBeginTx = conn{}
)
@@ -101,9 +101,8 @@ type stmt struct {
var (
// Ensure these interfaces are implemented:
// _ driver.StmtExecContext = stmt{}
// _ driver.StmtQueryContext = stmt{}
_ = stmt{}
_ driver.StmtExecContext = stmt{}
_ driver.StmtQueryContext = stmt{}
)
func (s stmt) Close() error {
@@ -114,8 +113,18 @@ func (s stmt) NumInput() int {
return s.stmt.BindCount()
}
// Deprecated: use ExecContext instead.
func (s stmt) Exec(args []driver.Value) (driver.Result, error) {
_, err := s.Query(args)
return s.ExecContext(context.Background(), namedValues(args))
}
// Deprecated: use QueryContext instead.
func (s stmt) Query(args []driver.Value) (driver.Rows, error) {
return s.QueryContext(context.Background(), namedValues(args))
}
func (s stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
_, err := s.QueryContext(ctx, args)
if err != nil {
return nil, err
}
@@ -131,32 +140,51 @@ func (s stmt) Exec(args []driver.Value) (driver.Result, error) {
}, nil
}
func (s stmt) Query(args []driver.Value) (driver.Rows, error) {
var err error
for i, arg := range args {
switch a := arg.(type) {
case bool:
err = s.stmt.BindBool(i+1, a)
case int64:
err = s.stmt.BindInt64(i+1, a)
case float64:
err = s.stmt.BindFloat(i+1, a)
case string:
err = s.stmt.BindText(i+1, a)
case []byte:
err = s.stmt.BindBlob(i+1, a)
case time.Time:
err = s.stmt.BindText(i+1, a.Format(time.RFC3339Nano))
case nil:
err = s.stmt.BindNull(i + 1)
default:
panic(assertErr)
func (s stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
err := s.stmt.ClearBindings()
if err != nil {
return nil, err
}
var ids [3]int
for _, arg := range args {
ids := ids[:0]
if arg.Name == "" {
ids = append(ids, arg.Ordinal)
} else {
for _, prefix := range []string{":", "@", "$"} {
if id := s.stmt.BindIndex(prefix + arg.Name); id != 0 {
ids = append(ids, id)
}
}
}
for _, id := range ids {
switch a := arg.Value.(type) {
case bool:
err = s.stmt.BindBool(id, a)
case int64:
err = s.stmt.BindInt64(id, a)
case float64:
err = s.stmt.BindFloat(id, a)
case string:
err = s.stmt.BindText(id, a)
case []byte:
err = s.stmt.BindBlob(id, a)
case time.Time:
err = s.stmt.BindText(id, a.Format(time.RFC3339Nano))
case nil:
err = s.stmt.BindNull(id)
default:
panic(assertErr)
}
}
if err != nil {
return nil, err
}
}
return rows{s.stmt}, nil
return rows{ctx, s.stmt, s.conn}, nil
}
type result struct{ lastInsertId, rowsAffected int64 }
@@ -169,40 +197,47 @@ func (r result) RowsAffected() (int64, error) {
return r.rowsAffected, nil
}
type rows struct{ s *sqlite3.Stmt }
type rows struct {
ctx context.Context
stmt *sqlite3.Stmt
conn *sqlite3.Conn
}
func (r rows) Close() error {
return r.s.Reset()
return r.stmt.Reset()
}
func (r rows) Columns() []string {
count := r.s.ColumnCount()
count := r.stmt.ColumnCount()
columns := make([]string, count)
for i := range columns {
columns[i] = r.s.ColumnName(i)
columns[i] = r.stmt.ColumnName(i)
}
return columns
}
func (r rows) Next(dest []driver.Value) error {
if !r.s.Step() {
if err := r.s.Err(); err != nil {
ch := r.conn.SetInterrupt(r.ctx.Done())
defer r.conn.SetInterrupt(ch)
if !r.stmt.Step() {
if err := r.stmt.Err(); err != nil {
return err
}
return io.EOF
}
for i := range dest {
switch r.s.ColumnType(i) {
switch r.stmt.ColumnType(i) {
case sqlite3.INTEGER:
dest[i] = r.s.ColumnInt64(i)
dest[i] = r.stmt.ColumnInt64(i)
case sqlite3.FLOAT:
dest[i] = r.s.ColumnFloat(i)
dest[i] = r.stmt.ColumnFloat(i)
case sqlite3.TEXT:
dest[i] = maybeDate(r.s.ColumnText(i))
dest[i] = maybeDate(r.stmt.ColumnText(i))
case sqlite3.BLOB:
buf, _ := dest[i].([]byte)
dest[i] = r.s.ColumnBlob(i, buf)
dest[i] = r.stmt.ColumnBlob(i, buf)
case sqlite3.NULL:
if buf, ok := dest[i].([]byte); ok {
dest[i] = buf[0:0]
@@ -214,5 +249,5 @@ func (r rows) Next(dest []driver.Value) error {
}
}
return r.s.Err()
return r.stmt.Err()
}

14
driver/util.go Normal file
View File

@@ -0,0 +1,14 @@
package driver
import "database/sql/driver"
func namedValues(args []driver.Value) []driver.NamedValue {
named := make([]driver.NamedValue, len(args))
for i, v := range args {
named[i] = driver.NamedValue{
Ordinal: i + 1,
Value: v,
}
}
return named
}

18
driver/util_test.go Normal file
View File

@@ -0,0 +1,18 @@
package driver
import (
"database/sql/driver"
"reflect"
"testing"
)
func Test_namedValues(t *testing.T) {
want := []driver.NamedValue{
{Ordinal: 1, Value: true},
{Ordinal: 2, Value: false},
}
got := namedValues([]driver.Value{true, false})
if !reflect.DeepEqual(got, want) {
t.Errorf("got %v, want %v", got, want)
}
}