diff --git a/conn.go b/conn.go index 1ff2f1c..0ee1152 100644 --- a/conn.go +++ b/conn.go @@ -12,11 +12,12 @@ type Conn struct { ctx context.Context api sqliteAPI mem memory - arena arena handle uint32 - waiter chan struct{} - done <-chan struct{} + arena arena + pending *Stmt + waiter chan struct{} + done <-chan struct{} } // Open calls [OpenFlags] with [OPEN_READWRITE] and [OPEN_CREATE]. @@ -108,32 +109,44 @@ func (c *Conn) SetInterrupt(done <-chan struct{}) (old <-chan struct{}) { c.waiter = nil } + // Finalize the uncompleted SQL statement. + if c.pending != nil { + c.pending.Close() + c.pending = nil + } + old = c.done c.done = done if done == nil { return old } + // Creating an uncompleted SQL statement prevents SQLite from ignoring + // an interrupt that comes before any other statements are started. + c.pending, _, _ = c.Prepare(`SELECT 1 UNION ALL SELECT 2`) + c.pending.Step() + waiter := make(chan struct{}) c.waiter = waiter go func() { select { - case <-waiter: - // Waiter was cancelled. - case <-done: - // Done was closed. + case <-waiter: // Waiter was cancelled. + break - // Because it doesn't touch the C stack, - // sqlite3_interrupt is safe to call from a goroutine. + case <-done: // Done was closed. + + // This is safe to call from a goroutine + // because it doesn't touch the C stack. _, err := c.api.interrupt.Call(c.ctx, uint64(c.handle)) if err != nil { panic(err) } // Wait for the next call to SetInterrupt. - <-waiter // Waiter was cancelled. + <-waiter } - // Signal that the waiter is finished. + + // Signal that the waiter has finished. waiter <- struct{}{} }() return old @@ -205,28 +218,26 @@ func (c *Conn) error(rc uint64, sql ...string) error { var r []uint64 - // sqlite3_errmsg is guaranteed to never change the value of the error code. + r, _ = c.api.errstr.Call(c.ctx, rc) + if r != nil { + err.str = c.mem.readString(uint32(r[0]), 512) + } + r, _ = c.api.errmsg.Call(c.ctx, uint64(c.handle)) if r != nil { err.msg = c.mem.readString(uint32(r[0]), 512) } if sql != nil { - // sqlite3_error_offset is guaranteed to never change the value of the error code. r, _ = c.api.erroff.Call(c.ctx, uint64(c.handle)) if r != nil && r[0] != math.MaxUint32 { err.sql = sql[0][r[0]:] } } - r, _ = c.api.errstr.Call(c.ctx, rc) - if r != nil { - err.str = c.mem.readString(uint32(r[0]), 512) - } - - if err.msg == err.str { + switch err.msg { + case err.str, "not an error": err.msg = "" - } return &err } diff --git a/conn_test.go b/conn_test.go index 5690009..e7b8ded 100644 --- a/conn_test.go +++ b/conn_test.go @@ -2,11 +2,9 @@ package sqlite3 import ( "bytes" - "context" "errors" "math" "testing" - "time" ) func TestConn_Close(t *testing.T) { @@ -45,9 +43,7 @@ func TestConn_Close_BUSY(t *testing.T) { } } -func TestConn_Interrupt(t *testing.T) { - t.Parallel() - +func TestConn_SetInterrupt(t *testing.T) { db, err := Open(":memory:") if err != nil { t.Fatal(err) @@ -61,7 +57,7 @@ func TestConn_Interrupt(t *testing.T) { SELECT 0, 1 UNION ALL SELECT next, curr + next FROM fibonacci - LIMIT 10e6 + LIMIT 1e6 ) SELECT min(curr) FROM fibonacci `) @@ -70,29 +66,47 @@ func TestConn_Interrupt(t *testing.T) { } defer stmt.Close() - ctx, cancel := context.WithTimeout(context.TODO(), 100*time.Millisecond) - db.SetInterrupt(ctx.Done()) - defer cancel() + done := make(chan struct{}) + close(done) + db.SetInterrupt(done) - for stmt.Step() { - } - - err = stmt.Err() - if err == nil { - t.Fatal("want error") - } var serr *Error - if !errors.As(err, &serr) { - t.Fatalf("got %T, want sqlite3.Error", err) + + // Interrupting works. + err = stmt.Exec() + if err != nil { + if !errors.As(err, &serr) { + t.Fatalf("got %T, want sqlite3.Error", err) + } + if rc := serr.Code(); rc != INTERRUPT { + t.Errorf("got %d, want sqlite3.INTERRUPT", rc) + } + if got := err.Error(); got != `sqlite3: interrupted` { + t.Error("got message: ", got) + } } - if rc := serr.Code(); rc != INTERRUPT { - t.Errorf("got %d, want sqlite3.INTERRUPT", rc) - } - if got := err.Error(); got != `sqlite3: interrupted` { - t.Error("got message: ", got) + + // Interrupting sticks. + err = db.Exec(`SELECT 1`) + if err != nil { + if !errors.As(err, &serr) { + t.Fatalf("got %T, want sqlite3.Error", err) + } + if rc := serr.Code(); rc != INTERRUPT { + t.Errorf("got %d, want sqlite3.INTERRUPT", rc) + } + if got := err.Error(); got != `sqlite3: interrupted` { + t.Error("got message: ", got) + } } db.SetInterrupt(nil) + + // Interrupting can be cleared. + err = db.Exec(`SELECT 1`) + if err != nil { + t.Fatal(err) + } } func TestConn_Prepare_Empty(t *testing.T) {