From 1190c21684dcaed9ac2e14b0b0faaabe7e2fcd4d Mon Sep 17 00:00:00 2001 From: Nuno Cruces Date: Fri, 24 Feb 2023 15:06:19 +0000 Subject: [PATCH] Refactor. --- conn.go | 224 +++++++++++++++++++++++------------ save.go | 77 ------------ stmt.go | 15 +++ util_test.go => stmt_test.go | 0 util.go | 16 --- 5 files changed, 163 insertions(+), 169 deletions(-) delete mode 100644 save.go rename util_test.go => stmt_test.go (100%) delete mode 100644 util.go diff --git a/conn.go b/conn.go index 1f747e6..c8b64de 100644 --- a/conn.go +++ b/conn.go @@ -2,7 +2,9 @@ package sqlite3 import ( "context" + "fmt" "math" + "runtime" ) // Conn is a database connection handle. @@ -91,6 +93,64 @@ func (c *Conn) Close() error { return c.mem.mod.Close(c.ctx) } +// Exec is a convenience function that allows an application to run +// multiple statements of SQL without having to use a lot of code. +// +// https://www.sqlite.org/c3ref/exec.html +func (c *Conn) Exec(sql string) error { + c.checkInterrupt() + defer c.arena.reset() + sqlPtr := c.arena.string(sql) + + r, err := c.api.exec.Call(c.ctx, uint64(c.handle), uint64(sqlPtr), 0, 0, 0) + if err != nil { + panic(err) + } + return c.error(r[0]) +} + +// Prepare calls [Conn.PrepareFlags] with no flags. +func (c *Conn) Prepare(sql string) (stmt *Stmt, tail string, err error) { + return c.PrepareFlags(sql, 0) +} + +// PrepareFlags compiles the first SQL statement in sql; +// tail is left pointing to what remains uncompiled. +// If the input text contains no SQL (if the input is an empty string or a comment), +// both stmt and err will be nil. +// +// https://www.sqlite.org/c3ref/prepare.html +func (c *Conn) PrepareFlags(sql string, flags PrepareFlag) (stmt *Stmt, tail string, err error) { + if emptyStatement(sql) { + return nil, "", nil + } + + defer c.arena.reset() + stmtPtr := c.arena.new(ptrlen) + tailPtr := c.arena.new(ptrlen) + sqlPtr := c.arena.string(sql) + + r, err := c.api.prepare.Call(c.ctx, uint64(c.handle), + uint64(sqlPtr), uint64(len(sql)+1), uint64(flags), + uint64(stmtPtr), uint64(tailPtr)) + if err != nil { + panic(err) + } + + stmt = &Stmt{c: c} + stmt.handle = c.mem.readUint32(stmtPtr) + i := c.mem.readUint32(tailPtr) + tail = sql[i-sqlPtr:] + + if err := c.error(r[0], sql); err != nil { + return nil, "", err + } + if stmt.handle == 0 { + return nil, "", nil + } + return +} + // GetAutocommit tests the connection for auto-commit mode. // // https://www.sqlite.org/c3ref/get_autocommit.html @@ -102,6 +162,31 @@ func (c *Conn) GetAutocommit() bool { return r[0] != 0 } +// LastInsertRowID returns the rowid of the most recent successful INSERT +// on the database connection. +// +// https://www.sqlite.org/c3ref/last_insert_rowid.html +func (c *Conn) LastInsertRowID() uint64 { + r, err := c.api.lastRowid.Call(c.ctx, uint64(c.handle)) + if err != nil { + panic(err) + } + return r[0] +} + +// Changes returns the number of rows modified, inserted or deleted +// by the most recently completed INSERT, UPDATE or DELETE statement +// on the database connection. +// +// https://www.sqlite.org/c3ref/changes.html +func (c *Conn) Changes() uint64 { + r, err := c.api.changes.Call(c.ctx, uint64(c.handle)) + if err != nil { + panic(err) + } + return r[0] +} + // SetInterrupt interrupts a long-running query when a context is done. // // Subsequent uses of the connection will return [INTERRUPT] @@ -187,87 +272,74 @@ func (c *Conn) checkInterrupt() bool { return true } -// Exec is a convenience function that allows an application to run -// multiple statements of SQL without having to use a lot of code. +// Savepoint creates a named SQLite transaction using SAVEPOINT. // -// https://www.sqlite.org/c3ref/exec.html -func (c *Conn) Exec(sql string) error { - c.checkInterrupt() - defer c.arena.reset() - sqlPtr := c.arena.string(sql) - - r, err := c.api.exec.Call(c.ctx, uint64(c.handle), uint64(sqlPtr), 0, 0, 0) - if err != nil { - panic(err) - } - return c.error(r[0]) -} - -// Prepare calls [Conn.PrepareFlags] with no flags. -func (c *Conn) Prepare(sql string) (stmt *Stmt, tail string, err error) { - return c.PrepareFlags(sql, 0) -} - -// PrepareFlags compiles the first SQL statement in sql; -// tail is left pointing to what remains uncompiled. -// If the input text contains no SQL (if the input is an empty string or a comment), -// both stmt and err will be nil. +// On success Savepoint returns a release func that will call +// either RELEASE or ROLLBACK depending on whether the parameter *error +// points to a nil or non-nil error. // -// https://www.sqlite.org/c3ref/prepare.html -func (c *Conn) PrepareFlags(sql string, flags PrepareFlag) (stmt *Stmt, tail string, err error) { - if emptyStatement(sql) { - return nil, "", nil - } - - defer c.arena.reset() - stmtPtr := c.arena.new(ptrlen) - tailPtr := c.arena.new(ptrlen) - sqlPtr := c.arena.string(sql) - - r, err := c.api.prepare.Call(c.ctx, uint64(c.handle), - uint64(sqlPtr), uint64(len(sql)+1), uint64(flags), - uint64(stmtPtr), uint64(tailPtr)) - if err != nil { - panic(err) - } - - stmt = &Stmt{c: c} - stmt.handle = c.mem.readUint32(stmtPtr) - i := c.mem.readUint32(tailPtr) - tail = sql[i-sqlPtr:] - - if err := c.error(r[0], sql); err != nil { - return nil, "", err - } - if stmt.handle == 0 { - return nil, "", nil - } - return -} - -// LastInsertRowID returns the rowid of the most recent successful INSERT -// on the database connection. +// This is meant to be deferred: // -// https://www.sqlite.org/c3ref/last_insert_rowid.html -func (c *Conn) LastInsertRowID() uint64 { - r, err := c.api.lastRowid.Call(c.ctx, uint64(c.handle)) - if err != nil { - panic(err) - } - return r[0] -} - -// Changes returns the number of rows modified, inserted or deleted -// by the most recently completed INSERT, UPDATE or DELETE statement -// on the database connection. +// func doWork(conn *sqlite3.Conn) (err error) { +// defer conn.Savepoint()(&err) // -// https://www.sqlite.org/c3ref/changes.html -func (c *Conn) Changes() uint64 { - r, err := c.api.changes.Call(c.ctx, uint64(c.handle)) - if err != nil { - panic(err) +// // ... do work in the transaction +// } +func (conn *Conn) Savepoint() (release func(*error)) { + name := "sqlite3.Savepoint" // names can be reused + var pc [1]uintptr + if n := runtime.Callers(2, pc[:]); n > 0 { + frames := runtime.CallersFrames(pc[:n]) + frame, _ := frames.Next() + if frame.Function != "" { + name = frame.Function + } + } + + err := conn.Exec(fmt.Sprintf("SAVEPOINT %q;", name)) + if err != nil { + return func(errp *error) { + if *errp == nil { + *errp = err + } + } + } + + return func(errp *error) { + recovered := recover() + if recovered != nil { + defer panic(recovered) + } + + if conn.GetAutocommit() { + // There is nothing to commit/rollback. + return + } + + if *errp == nil && recovered == nil { + // Success path. + // RELEASE the savepoint successfully. + *errp = conn.Exec(fmt.Sprintf("RELEASE %q;", name)) + if *errp == nil { + return + } + // Possible interrupt, fall through to the error path. + } + + // Error path. + // Always ROLLBACK even if the connection has been interrupted. + old := conn.SetInterrupt(context.Background()) + defer conn.SetInterrupt(old) + + err := conn.Exec(fmt.Sprintf("ROLLBACK TO %q;", name)) + if err != nil { + panic(err) + } + err = conn.Exec(fmt.Sprintf("RELEASE %q;", name)) + if err != nil { + panic(err) + } } - return r[0] } func (c *Conn) error(rc uint64, sql ...string) error { diff --git a/save.go b/save.go deleted file mode 100644 index a7cef4b..0000000 --- a/save.go +++ /dev/null @@ -1,77 +0,0 @@ -package sqlite3 - -import ( - "context" - "fmt" - "runtime" -) - -// Savepoint creates a named SQLite transaction using SAVEPOINT. -// -// On success Savepoint returns a release func that will call -// either RELEASE or ROLLBACK depending on whether the parameter *error -// points to a nil or non-nil error. -// -// This is meant to be deferred: -// -// func doWork(conn *sqlite3.Conn) (err error) { -// defer conn.Savepoint()(&err) -// -// // ... do work in the transaction -// } -func (conn *Conn) Savepoint() (release func(*error)) { - name := "sqlite3.Savepoint" // names can be reused - var pc [1]uintptr - if n := runtime.Callers(2, pc[:]); n > 0 { - frames := runtime.CallersFrames(pc[:n]) - frame, _ := frames.Next() - if frame.Function != "" { - name = frame.Function - } - } - - err := conn.Exec(fmt.Sprintf("SAVEPOINT %q;", name)) - if err != nil { - return func(errp *error) { - if *errp == nil { - *errp = err - } - } - } - - return func(errp *error) { - recovered := recover() - if recovered != nil { - defer panic(recovered) - } - - if conn.GetAutocommit() { - // There is nothing to commit/rollback. - return - } - - if *errp == nil && recovered == nil { - // Success path. - // RELEASE the savepoint successfully. - *errp = conn.Exec(fmt.Sprintf("RELEASE %q;", name)) - if *errp == nil { - return - } - // Possible interrupt, fall through to the error path. - } - - // Error path. - // Always ROLLBACK even if the connection has been interrupted. - old := conn.SetInterrupt(context.Background()) - defer conn.SetInterrupt(old) - - err := conn.Exec(fmt.Sprintf("ROLLBACK TO %q;", name)) - if err != nil { - panic(err) - } - err = conn.Exec(fmt.Sprintf("RELEASE %q;", name)) - if err != nil { - panic(err) - } - } -} diff --git a/stmt.go b/stmt.go index 919ea2a..1215482 100644 --- a/stmt.go +++ b/stmt.go @@ -445,3 +445,18 @@ func (s *Stmt) ColumnBlob(col int, buf []byte) []byte { mem := s.c.mem.view(ptr, uint32(r[0])) return append(buf[0:0], mem...) } + +// Return true if stmt is an empty SQL statement. +// This is used as an optimization. +// It's OK to always return false here. +func emptyStatement(stmt string) bool { + for _, b := range []byte(stmt) { + switch b { + case ' ', '\n', '\r', '\t', '\v', '\f': + case ';': + default: + return false + } + } + return true +} diff --git a/util_test.go b/stmt_test.go similarity index 100% rename from util_test.go rename to stmt_test.go diff --git a/util.go b/util.go deleted file mode 100644 index 7a5bf13..0000000 --- a/util.go +++ /dev/null @@ -1,16 +0,0 @@ -package sqlite3 - -// Return true if stmt is an empty SQL statement. -// This is used as an optimization. -// It's OK to always return false here. -func emptyStatement(stmt string) bool { - for _, b := range []byte(stmt) { - switch b { - case ' ', '\n', '\r', '\t', '\v', '\f': - case ';': - default: - return false - } - } - return true -}