Interrupt API.

This commit is contained in:
Nuno Cruces
2023-02-24 14:56:49 +00:00
parent 0146496036
commit 8c28c3a6f4
5 changed files with 32 additions and 31 deletions

44
conn.go
View File

@@ -14,10 +14,10 @@ type Conn struct {
mem memory
handle uint32
arena arena
pending *Stmt
waiter chan struct{}
done <-chan struct{}
arena arena
interrupt context.Context
waiter chan struct{}
pending *Stmt
}
// Open calls [OpenFlags] with [OPEN_READWRITE] and [OPEN_CREATE].
@@ -76,7 +76,7 @@ func (c *Conn) Close() error {
return nil
}
c.SetInterrupt(nil)
c.SetInterrupt(context.Background())
r, err := c.api.close.Call(c.ctx, uint64(c.handle))
if err != nil {
@@ -102,19 +102,21 @@ func (c *Conn) GetAutocommit() bool {
return r[0] != 0
}
// SetInterrupt interrupts a long-running query when done is closed.
// SetInterrupt interrupts a long-running query when a context is done.
//
// Subsequent uses of the connection will return [INTERRUPT]
// until done is reset by another call to SetInterrupt.
// until the context is reset by another call to SetInterrupt.
//
// Typically, done is provided by [context.Context.Done]:
// For example, a timeout can be associated with a connection:
//
// ctx, cancel := context.WithTimeout(context.TODO(), 100*time.Millisecond)
// conn.SetInterrupt(ctx.Done())
// conn.SetInterrupt(ctx)
// defer cancel()
//
// SetInterrupt returns the old context assigned to the connection.
//
// https://www.sqlite.org/c3ref/interrupt.html
func (c *Conn) SetInterrupt(done <-chan struct{}) (old <-chan struct{}) {
func (c *Conn) SetInterrupt(ctx context.Context) (old context.Context) {
// Is a waiter running?
if c.waiter != nil {
c.waiter <- struct{}{} // Cancel the waiter.
@@ -122,9 +124,9 @@ func (c *Conn) SetInterrupt(done <-chan struct{}) (old <-chan struct{}) {
c.waiter = nil
}
old = c.done
c.done = done
if done == nil {
old = c.interrupt
c.interrupt = ctx
if ctx == nil || ctx == context.Background() || ctx == context.TODO() || ctx.Done() == nil {
// Finalize the uncompleted SQL statement.
if c.pending != nil {
c.pending.Close()
@@ -155,7 +157,7 @@ func (c *Conn) SetInterrupt(done <-chan struct{}) (old <-chan struct{}) {
case <-waiter: // Waiter was cancelled.
break
case <-done: // Done was closed.
case <-ctx.Done(): // Done was closed.
// This is safe to call from a goroutine
// because it doesn't touch the C stack.
@@ -175,16 +177,14 @@ func (c *Conn) SetInterrupt(done <-chan struct{}) (old <-chan struct{}) {
}
func (c *Conn) checkInterrupt() bool {
select {
case <-c.done: // Done was closed.
_, err := c.api.interrupt.Call(c.ctx, uint64(c.handle))
if err != nil {
panic(err)
}
return true
default:
if c.interrupt == nil || c.interrupt.Err() == nil {
return false
}
_, err := c.api.interrupt.Call(c.ctx, uint64(c.handle))
if err != nil {
panic(err)
}
return true
}
// Exec is a convenience function that allows an application to run

View File

@@ -147,8 +147,8 @@ func (c conn) ExecContext(ctx context.Context, query string, args []driver.Named
return nil, driver.ErrSkip
}
ch := c.conn.SetInterrupt(ctx.Done())
defer c.conn.SetInterrupt(ch)
old := c.conn.SetInterrupt(ctx)
defer c.conn.SetInterrupt(old)
err := c.conn.Exec(query)
if err != nil {
@@ -307,8 +307,8 @@ func (r rows) Columns() []string {
}
func (r rows) Next(dest []driver.Value) error {
ch := r.conn.SetInterrupt(r.ctx.Done())
defer r.conn.SetInterrupt(ch)
old := r.conn.SetInterrupt(r.ctx)
defer r.conn.SetInterrupt(old)
if !r.stmt.Step() {
if err := r.stmt.Err(); err != nil {

View File

@@ -1,6 +1,7 @@
package sqlite3
import (
"context"
"fmt"
"runtime"
)
@@ -61,7 +62,7 @@ func (conn *Conn) Savepoint() (release func(*error)) {
// Error path.
// Always ROLLBACK even if the connection has been interrupted.
old := conn.SetInterrupt(nil)
old := conn.SetInterrupt(context.Background())
defer conn.SetInterrupt(old)
err := conn.Exec(fmt.Sprintf("ROLLBACK TO %q;", name))

View File

@@ -77,7 +77,7 @@ func TestConn_SetInterrupt(t *testing.T) {
defer db.Close()
ctx, cancel := context.WithCancel(context.Background())
db.SetInterrupt(ctx.Done())
db.SetInterrupt(ctx)
// Interrupt doesn't interrupt this.
err = db.Exec(`SELECT 1`)
@@ -103,7 +103,7 @@ func TestConn_SetInterrupt(t *testing.T) {
}
defer stmt.Close()
db.SetInterrupt(ctx.Done())
db.SetInterrupt(ctx)
cancel()
var serr *sqlite3.Error
@@ -134,7 +134,7 @@ func TestConn_SetInterrupt(t *testing.T) {
ctx, cancel = context.WithCancel(context.Background())
defer cancel()
db.SetInterrupt(ctx.Done())
db.SetInterrupt(ctx)
// Interrupting can be cleared.
err = db.Exec(`SELECT 1`)

View File

@@ -150,7 +150,7 @@ func TestConn_Savepoint_interrupt(t *testing.T) {
}
ctx, cancel := context.WithCancel(context.Background())
db.SetInterrupt(ctx.Done())
db.SetInterrupt(ctx)
release1 := db.Savepoint()
err = db.Exec(`INSERT INTO test(col) VALUES(2)`)