Fix flakiness.

This commit is contained in:
Nuno Cruces
2023-02-16 13:30:31 +00:00
parent f85426022d
commit 110f36bdf9
2 changed files with 68 additions and 43 deletions

51
conn.go
View File

@@ -12,11 +12,12 @@ type Conn struct {
ctx context.Context
api sqliteAPI
mem memory
arena arena
handle uint32
waiter chan struct{}
done <-chan struct{}
arena arena
pending *Stmt
waiter chan struct{}
done <-chan struct{}
}
// Open calls [OpenFlags] with [OPEN_READWRITE] and [OPEN_CREATE].
@@ -108,32 +109,44 @@ func (c *Conn) SetInterrupt(done <-chan struct{}) (old <-chan struct{}) {
c.waiter = nil
}
// Finalize the uncompleted SQL statement.
if c.pending != nil {
c.pending.Close()
c.pending = nil
}
old = c.done
c.done = done
if done == nil {
return old
}
// Creating an uncompleted SQL statement prevents SQLite from ignoring
// an interrupt that comes before any other statements are started.
c.pending, _, _ = c.Prepare(`SELECT 1 UNION ALL SELECT 2`)
c.pending.Step()
waiter := make(chan struct{})
c.waiter = waiter
go func() {
select {
case <-waiter:
// Waiter was cancelled.
case <-done:
// Done was closed.
case <-waiter: // Waiter was cancelled.
break
// Because it doesn't touch the C stack,
// sqlite3_interrupt is safe to call from a goroutine.
case <-done: // Done was closed.
// This is safe to call from a goroutine
// because it doesn't touch the C stack.
_, err := c.api.interrupt.Call(c.ctx, uint64(c.handle))
if err != nil {
panic(err)
}
// Wait for the next call to SetInterrupt.
<-waiter // Waiter was cancelled.
<-waiter
}
// Signal that the waiter is finished.
// Signal that the waiter has finished.
waiter <- struct{}{}
}()
return old
@@ -205,28 +218,26 @@ func (c *Conn) error(rc uint64, sql ...string) error {
var r []uint64
// sqlite3_errmsg is guaranteed to never change the value of the error code.
r, _ = c.api.errstr.Call(c.ctx, rc)
if r != nil {
err.str = c.mem.readString(uint32(r[0]), 512)
}
r, _ = c.api.errmsg.Call(c.ctx, uint64(c.handle))
if r != nil {
err.msg = c.mem.readString(uint32(r[0]), 512)
}
if sql != nil {
// sqlite3_error_offset is guaranteed to never change the value of the error code.
r, _ = c.api.erroff.Call(c.ctx, uint64(c.handle))
if r != nil && r[0] != math.MaxUint32 {
err.sql = sql[0][r[0]:]
}
}
r, _ = c.api.errstr.Call(c.ctx, rc)
if r != nil {
err.str = c.mem.readString(uint32(r[0]), 512)
}
if err.msg == err.str {
switch err.msg {
case err.str, "not an error":
err.msg = ""
}
return &err
}

View File

@@ -2,11 +2,9 @@ package sqlite3
import (
"bytes"
"context"
"errors"
"math"
"testing"
"time"
)
func TestConn_Close(t *testing.T) {
@@ -45,9 +43,7 @@ func TestConn_Close_BUSY(t *testing.T) {
}
}
func TestConn_Interrupt(t *testing.T) {
t.Parallel()
func TestConn_SetInterrupt(t *testing.T) {
db, err := Open(":memory:")
if err != nil {
t.Fatal(err)
@@ -61,7 +57,7 @@ func TestConn_Interrupt(t *testing.T) {
SELECT 0, 1
UNION ALL
SELECT next, curr + next FROM fibonacci
LIMIT 10e6
LIMIT 1e6
)
SELECT min(curr) FROM fibonacci
`)
@@ -70,29 +66,47 @@ func TestConn_Interrupt(t *testing.T) {
}
defer stmt.Close()
ctx, cancel := context.WithTimeout(context.TODO(), 100*time.Millisecond)
db.SetInterrupt(ctx.Done())
defer cancel()
done := make(chan struct{})
close(done)
db.SetInterrupt(done)
for stmt.Step() {
}
err = stmt.Err()
if err == nil {
t.Fatal("want error")
}
var serr *Error
if !errors.As(err, &serr) {
t.Fatalf("got %T, want sqlite3.Error", err)
// Interrupting works.
err = stmt.Exec()
if err != nil {
if !errors.As(err, &serr) {
t.Fatalf("got %T, want sqlite3.Error", err)
}
if rc := serr.Code(); rc != INTERRUPT {
t.Errorf("got %d, want sqlite3.INTERRUPT", rc)
}
if got := err.Error(); got != `sqlite3: interrupted` {
t.Error("got message: ", got)
}
}
if rc := serr.Code(); rc != INTERRUPT {
t.Errorf("got %d, want sqlite3.INTERRUPT", rc)
}
if got := err.Error(); got != `sqlite3: interrupted` {
t.Error("got message: ", got)
// Interrupting sticks.
err = db.Exec(`SELECT 1`)
if err != nil {
if !errors.As(err, &serr) {
t.Fatalf("got %T, want sqlite3.Error", err)
}
if rc := serr.Code(); rc != INTERRUPT {
t.Errorf("got %d, want sqlite3.INTERRUPT", rc)
}
if got := err.Error(); got != `sqlite3: interrupted` {
t.Error("got message: ", got)
}
}
db.SetInterrupt(nil)
// Interrupting can be cleared.
err = db.Exec(`SELECT 1`)
if err != nil {
t.Fatal(err)
}
}
func TestConn_Prepare_Empty(t *testing.T) {