From 5d6f92b73395a9edf51ae2af3b1be28a8090e2bd Mon Sep 17 00:00:00 2001 From: Nuno Cruces Date: Tue, 7 Mar 2023 14:45:54 +0000 Subject: [PATCH] Documentation, tests, tweaks. --- backup.go | 6 +++--- conn.go | 15 +++++++-------- error_test.go | 15 +++++---------- tx.go | 2 +- 4 files changed, 16 insertions(+), 22 deletions(-) diff --git a/backup.go b/backup.go index de5759f..c931347 100644 --- a/backup.go +++ b/backup.go @@ -11,7 +11,7 @@ type Backup struct { // Backup backs up srcDB on the src connection to the "main" database in dstURI. // -// Backup calls [Conn.Open] to open the SQLite database file dstURI, +// Backup calls [Open] to open the SQLite database file dstURI, // and blocks until the entire backup is complete. // Use [Conn.BackupInit] for incremental backup. // @@ -28,7 +28,7 @@ func (src *Conn) Backup(srcDB, dstURI string) error { // Restore restores dstDB on the dst connection from the "main" database in srcURI. // -// Restore calls [Conn.Open] to open the SQLite database file srcURI, +// Restore calls [Open] to open the SQLite database file srcURI, // and blocks until the entire restore is complete. // // https://www.sqlite.org/backup.html @@ -48,7 +48,7 @@ func (dst *Conn) Restore(dstDB, srcURI string) error { // BackupInit initializes a backup operation to copy the content of one database into another. // -// BackupInit calls [Conn.Open] to open the SQLite database file dstURI, +// BackupInit calls [Open] to open the SQLite database file dstURI, // then initializes a backup that copies the contents of srcDB on the src connection // to the "main" database in dstURI. // diff --git a/conn.go b/conn.go index e644d54..be0c9aa 100644 --- a/conn.go +++ b/conn.go @@ -119,6 +119,8 @@ func (c *Conn) Close() error { } c.SetInterrupt(context.Background()) + c.pending.Close() + c.pending = nil r := c.call(c.api.close, uint64(c.handle)) if err := c.error(r[0]); err != nil { @@ -247,15 +249,14 @@ func (c *Conn) SetInterrupt(ctx context.Context) (old context.Context) { <-c.waiter // Wait for it to finish. c.waiter = nil } + // Reset the pending statement. + if c.pending != nil { + c.pending.Reset() + } old = c.interrupt c.interrupt = ctx if ctx == nil || ctx.Done() == nil { - // Finalize the uncompleted SQL statement. - if c.pending != nil { - c.pending.Close() - c.pending = nil - } return old } @@ -263,10 +264,8 @@ func (c *Conn) SetInterrupt(ctx context.Context) (old context.Context) { // an interrupt that comes before any other statements are started. if c.pending == nil { c.pending = c.MustPrepare(`SELECT 1 UNION ALL SELECT 2`) - c.pending.Step() - } else { - c.pending.Reset() } + c.pending.Step() // Don't create the goroutine if we're already interrupted. // This happens frequently while restoring to a previously interrupted state. diff --git a/error_test.go b/error_test.go index 17eb612..927b5ad 100644 --- a/error_test.go +++ b/error_test.go @@ -1,7 +1,6 @@ package sqlite3 import ( - "context" "errors" "strings" "testing" @@ -9,7 +8,7 @@ import ( func Test_assertErr(t *testing.T) { err := assertErr() - if s := err.Error(); !strings.HasPrefix(s, "sqlite3: assertion failed") || !strings.HasSuffix(s, "error_test.go:11)") { + if s := err.Error(); !strings.HasPrefix(s, "sqlite3: assertion failed") || !strings.HasSuffix(s, "error_test.go:10)") { t.Errorf("got %q", s) } } @@ -120,10 +119,8 @@ func Test_ErrorCode_Error(t *testing.T) { // Test all error codes. for i := 0; i == int(ErrorCode(i)); i++ { want := "sqlite3: " - r, _ := db.api.errstr.Call(context.TODO(), uint64(i)) - if r != nil { - want += db.mem.readString(uint32(r[0]), _MAX_STRING) - } + r := db.call(db.api.errstr, uint64(i)) + want += db.mem.readString(uint32(r[0]), _MAX_STRING) got := ErrorCode(i).Error() if got != want { @@ -144,10 +141,8 @@ func Test_ExtendedErrorCode_Error(t *testing.T) { // Test all extended error codes. for i := 0; i == int(ExtendedErrorCode(i)); i++ { want := "sqlite3: " - r, _ := db.api.errstr.Call(context.TODO(), uint64(i)) - if r != nil { - want += db.mem.readString(uint32(r[0]), _MAX_STRING) - } + r := db.call(db.api.errstr, uint64(i)) + want += db.mem.readString(uint32(r[0]), _MAX_STRING) got := ExtendedErrorCode(i).Error() if got != want { diff --git a/tx.go b/tx.go index 73506ac..71300cf 100644 --- a/tx.go +++ b/tx.go @@ -92,7 +92,7 @@ func (tx Tx) Commit() error { return tx.c.Exec(`COMMIT`) } -// Rollback rollsback the transaction. +// Rollback rolls back the transaction. // // https://www.sqlite.org/lang_transaction.html func (tx Tx) Rollback() error {