diff --git a/conn.go b/conn.go index c8b64de..8e12e0c 100644 --- a/conn.go +++ b/conn.go @@ -5,6 +5,7 @@ import ( "fmt" "math" "runtime" + "sync" ) // Conn is a database connection handle. @@ -17,6 +18,7 @@ type Conn struct { handle uint32 arena arena + mtx sync.Mutex interrupt context.Context waiter chan struct{} pending *Stmt @@ -244,13 +246,7 @@ func (c *Conn) SetInterrupt(ctx context.Context) (old context.Context) { case <-ctx.Done(): // Done was closed. - // This is safe to call from a goroutine - // because it doesn't touch the C stack. - _, err := c.api.interrupt.Call(c.ctx, uint64(c.handle)) - if err != nil { - panic(err) - } - + c.sendInterrupt() // Wait for the next call to SetInterrupt. <-waiter } @@ -265,11 +261,19 @@ func (c *Conn) checkInterrupt() bool { if c.interrupt == nil || c.interrupt.Err() == nil { return false } + c.sendInterrupt() + return true +} + +func (c *Conn) sendInterrupt() { + c.mtx.Lock() + defer c.mtx.Unlock() + // This is safe to call from a goroutine + // because it doesn't touch the C stack. _, err := c.api.interrupt.Call(c.ctx, uint64(c.handle)) if err != nil { panic(err) } - return true } // Savepoint creates a named SQLite transaction using SAVEPOINT.