Use finalizers to detect unclosed connections.

This commit is contained in:
Nuno Cruces
2023-03-03 13:15:24 +00:00
parent 416c3863a0
commit 35b1a97b88
5 changed files with 27 additions and 5 deletions

View File

@@ -59,7 +59,7 @@ As a work around for other Unixes, you can use [`nolock=1`](https://www.sqlite.o
- [ ] snapshots - [ ] snapshots
- [ ] session extension - [ ] session extension
- [ ] resumable bulk update - [ ] resumable bulk update
- [ ] shared cache mode - [ ] shared-cache mode
- [ ] unlock-notify - [ ] unlock-notify
- [ ] custom SQL functions - [ ] custom SQL functions
- [ ] custom VFSes - [ ] custom VFSes

View File

@@ -8,7 +8,6 @@ import "io"
type ZeroBlob int64 type ZeroBlob int64
// Blob is a handle to an open BLOB. // Blob is a handle to an open BLOB.
//
// It implements [io.ReadWriteSeeker] for incremental BLOB I/O. // It implements [io.ReadWriteSeeker] for incremental BLOB I/O.
// //
// https://www.sqlite.org/c3ref/blob.html // https://www.sqlite.org/c3ref/blob.html

16
conn.go
View File

@@ -6,6 +6,7 @@ import (
"fmt" "fmt"
"math" "math"
"net/url" "net/url"
"runtime"
"strings" "strings"
"sync/atomic" "sync/atomic"
"unsafe" "unsafe"
@@ -30,8 +31,8 @@ type Conn struct {
} }
// Open calls [OpenFlags] with [OPEN_READWRITE], [OPEN_CREATE] and [OPEN_URI]. // Open calls [OpenFlags] with [OPEN_READWRITE], [OPEN_CREATE] and [OPEN_URI].
func Open(filename string) (conn *Conn, err error) { func Open(filename string) (*Conn, error) {
return OpenFlags(filename, OPEN_READWRITE|OPEN_CREATE|OPEN_URI) return openFlags(filename, OPEN_READWRITE|OPEN_CREATE|OPEN_URI)
} }
// OpenFlags opens an SQLite database file as specified by the filename argument. // OpenFlags opens an SQLite database file as specified by the filename argument.
@@ -41,7 +42,11 @@ func Open(filename string) (conn *Conn, err error) {
// sqlite3.Open("file:demo.db?_pragma=busy_timeout(10000)&_pragma=locking_mode(normal)") // sqlite3.Open("file:demo.db?_pragma=busy_timeout(10000)&_pragma=locking_mode(normal)")
// //
// https://www.sqlite.org/c3ref/open.html // https://www.sqlite.org/c3ref/open.html
func OpenFlags(filename string, flags OpenFlag) (conn *Conn, err error) { func OpenFlags(filename string, flags OpenFlag) (*Conn, error) {
return openFlags(filename, flags)
}
func openFlags(filename string, flags OpenFlag) (conn *Conn, err error) {
ctx := context.Background() ctx := context.Background()
module, err := sqlite3.instantiateModule(ctx) module, err := sqlite3.instantiateModule(ctx)
if err != nil { if err != nil {
@@ -50,6 +55,8 @@ func OpenFlags(filename string, flags OpenFlag) (conn *Conn, err error) {
defer func() { defer func() {
if conn == nil { if conn == nil {
module.Close(ctx) module.Close(ctx)
} else {
runtime.SetFinalizer(conn, finalizer[Conn](3))
} }
}() }()
@@ -109,6 +116,7 @@ func (c *Conn) Close() error {
} }
c.handle = 0 c.handle = 0
runtime.SetFinalizer(c, nil)
return c.mem.mod.Close(c.ctx) return c.mem.mod.Close(c.ctx)
} }
@@ -133,9 +141,11 @@ func (c *Conn) MustPrepare(sql string) *Stmt {
panic(err) panic(err)
} }
if s == nil { if s == nil {
s.Close()
panic(emptyErr) panic(emptyErr)
} }
if !emptyStatement(tail) { if !emptyStatement(tail) {
s.Close()
panic(tailErr) panic(tailErr)
} }
return s return s

View File

@@ -1,6 +1,7 @@
package sqlite3 package sqlite3
import ( import (
"fmt"
"runtime" "runtime"
"strconv" "strconv"
"strings" "strings"
@@ -215,3 +216,11 @@ func assertErr() errorString {
} }
return errorString(msg) return errorString(msg)
} }
func finalizer[T any](skip int) func(*T) {
msg := fmt.Sprintf("sqlite3: %T not closed", new(T))
if _, file, line, ok := runtime.Caller(skip + 1); ok && skip >= 0 {
msg += " (" + file + ":" + strconv.Itoa(line) + ")"
}
return func(*T) { panic(errorString(msg)) }
}

View File

@@ -29,6 +29,7 @@ func TestConn_Transaction_exec(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
defer stmt.Close()
if stmt.Step() { if stmt.Step() {
return stmt.ColumnInt(0) return stmt.ColumnInt(0)
} }
@@ -117,6 +118,7 @@ func TestConn_Transaction_panic(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
defer stmt.Close()
if stmt.Step() { if stmt.Step() {
got := stmt.ColumnInt(0) got := stmt.ColumnInt(0)
if got != 1 { if got != 1 {
@@ -275,6 +277,7 @@ func TestConn_Savepoint_exec(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
defer stmt.Close()
if stmt.Step() { if stmt.Step() {
return stmt.ColumnInt(0) return stmt.ColumnInt(0)
} }
@@ -361,6 +364,7 @@ func TestConn_Savepoint_panic(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
defer stmt.Close()
if stmt.Step() { if stmt.Step() {
got := stmt.ColumnInt(0) got := stmt.ColumnInt(0)
if got != 1 { if got != 1 {