Refactor.

This commit is contained in:
Nuno Cruces
2023-02-24 15:06:19 +00:00
parent 8c28c3a6f4
commit 1190c21684
5 changed files with 163 additions and 169 deletions

224
conn.go
View File

@@ -2,7 +2,9 @@ package sqlite3
import (
"context"
"fmt"
"math"
"runtime"
)
// Conn is a database connection handle.
@@ -91,6 +93,64 @@ func (c *Conn) Close() error {
return c.mem.mod.Close(c.ctx)
}
// Exec is a convenience function that allows an application to run
// multiple statements of SQL without having to use a lot of code.
//
// https://www.sqlite.org/c3ref/exec.html
func (c *Conn) Exec(sql string) error {
c.checkInterrupt()
defer c.arena.reset()
sqlPtr := c.arena.string(sql)
r, err := c.api.exec.Call(c.ctx, uint64(c.handle), uint64(sqlPtr), 0, 0, 0)
if err != nil {
panic(err)
}
return c.error(r[0])
}
// Prepare calls [Conn.PrepareFlags] with no flags.
func (c *Conn) Prepare(sql string) (stmt *Stmt, tail string, err error) {
return c.PrepareFlags(sql, 0)
}
// PrepareFlags compiles the first SQL statement in sql;
// tail is left pointing to what remains uncompiled.
// If the input text contains no SQL (if the input is an empty string or a comment),
// both stmt and err will be nil.
//
// https://www.sqlite.org/c3ref/prepare.html
func (c *Conn) PrepareFlags(sql string, flags PrepareFlag) (stmt *Stmt, tail string, err error) {
if emptyStatement(sql) {
return nil, "", nil
}
defer c.arena.reset()
stmtPtr := c.arena.new(ptrlen)
tailPtr := c.arena.new(ptrlen)
sqlPtr := c.arena.string(sql)
r, err := c.api.prepare.Call(c.ctx, uint64(c.handle),
uint64(sqlPtr), uint64(len(sql)+1), uint64(flags),
uint64(stmtPtr), uint64(tailPtr))
if err != nil {
panic(err)
}
stmt = &Stmt{c: c}
stmt.handle = c.mem.readUint32(stmtPtr)
i := c.mem.readUint32(tailPtr)
tail = sql[i-sqlPtr:]
if err := c.error(r[0], sql); err != nil {
return nil, "", err
}
if stmt.handle == 0 {
return nil, "", nil
}
return
}
// GetAutocommit tests the connection for auto-commit mode.
//
// https://www.sqlite.org/c3ref/get_autocommit.html
@@ -102,6 +162,31 @@ func (c *Conn) GetAutocommit() bool {
return r[0] != 0
}
// LastInsertRowID returns the rowid of the most recent successful INSERT
// on the database connection.
//
// https://www.sqlite.org/c3ref/last_insert_rowid.html
func (c *Conn) LastInsertRowID() uint64 {
r, err := c.api.lastRowid.Call(c.ctx, uint64(c.handle))
if err != nil {
panic(err)
}
return r[0]
}
// Changes returns the number of rows modified, inserted or deleted
// by the most recently completed INSERT, UPDATE or DELETE statement
// on the database connection.
//
// https://www.sqlite.org/c3ref/changes.html
func (c *Conn) Changes() uint64 {
r, err := c.api.changes.Call(c.ctx, uint64(c.handle))
if err != nil {
panic(err)
}
return r[0]
}
// SetInterrupt interrupts a long-running query when a context is done.
//
// Subsequent uses of the connection will return [INTERRUPT]
@@ -187,87 +272,74 @@ func (c *Conn) checkInterrupt() bool {
return true
}
// Exec is a convenience function that allows an application to run
// multiple statements of SQL without having to use a lot of code.
// Savepoint creates a named SQLite transaction using SAVEPOINT.
//
// https://www.sqlite.org/c3ref/exec.html
func (c *Conn) Exec(sql string) error {
c.checkInterrupt()
defer c.arena.reset()
sqlPtr := c.arena.string(sql)
r, err := c.api.exec.Call(c.ctx, uint64(c.handle), uint64(sqlPtr), 0, 0, 0)
if err != nil {
panic(err)
}
return c.error(r[0])
}
// Prepare calls [Conn.PrepareFlags] with no flags.
func (c *Conn) Prepare(sql string) (stmt *Stmt, tail string, err error) {
return c.PrepareFlags(sql, 0)
}
// PrepareFlags compiles the first SQL statement in sql;
// tail is left pointing to what remains uncompiled.
// If the input text contains no SQL (if the input is an empty string or a comment),
// both stmt and err will be nil.
// On success Savepoint returns a release func that will call
// either RELEASE or ROLLBACK depending on whether the parameter *error
// points to a nil or non-nil error.
//
// https://www.sqlite.org/c3ref/prepare.html
func (c *Conn) PrepareFlags(sql string, flags PrepareFlag) (stmt *Stmt, tail string, err error) {
if emptyStatement(sql) {
return nil, "", nil
}
defer c.arena.reset()
stmtPtr := c.arena.new(ptrlen)
tailPtr := c.arena.new(ptrlen)
sqlPtr := c.arena.string(sql)
r, err := c.api.prepare.Call(c.ctx, uint64(c.handle),
uint64(sqlPtr), uint64(len(sql)+1), uint64(flags),
uint64(stmtPtr), uint64(tailPtr))
if err != nil {
panic(err)
}
stmt = &Stmt{c: c}
stmt.handle = c.mem.readUint32(stmtPtr)
i := c.mem.readUint32(tailPtr)
tail = sql[i-sqlPtr:]
if err := c.error(r[0], sql); err != nil {
return nil, "", err
}
if stmt.handle == 0 {
return nil, "", nil
}
return
}
// LastInsertRowID returns the rowid of the most recent successful INSERT
// on the database connection.
// This is meant to be deferred:
//
// https://www.sqlite.org/c3ref/last_insert_rowid.html
func (c *Conn) LastInsertRowID() uint64 {
r, err := c.api.lastRowid.Call(c.ctx, uint64(c.handle))
if err != nil {
panic(err)
}
return r[0]
}
// Changes returns the number of rows modified, inserted or deleted
// by the most recently completed INSERT, UPDATE or DELETE statement
// on the database connection.
// func doWork(conn *sqlite3.Conn) (err error) {
// defer conn.Savepoint()(&err)
//
// https://www.sqlite.org/c3ref/changes.html
func (c *Conn) Changes() uint64 {
r, err := c.api.changes.Call(c.ctx, uint64(c.handle))
if err != nil {
panic(err)
// // ... do work in the transaction
// }
func (conn *Conn) Savepoint() (release func(*error)) {
name := "sqlite3.Savepoint" // names can be reused
var pc [1]uintptr
if n := runtime.Callers(2, pc[:]); n > 0 {
frames := runtime.CallersFrames(pc[:n])
frame, _ := frames.Next()
if frame.Function != "" {
name = frame.Function
}
}
err := conn.Exec(fmt.Sprintf("SAVEPOINT %q;", name))
if err != nil {
return func(errp *error) {
if *errp == nil {
*errp = err
}
}
}
return func(errp *error) {
recovered := recover()
if recovered != nil {
defer panic(recovered)
}
if conn.GetAutocommit() {
// There is nothing to commit/rollback.
return
}
if *errp == nil && recovered == nil {
// Success path.
// RELEASE the savepoint successfully.
*errp = conn.Exec(fmt.Sprintf("RELEASE %q;", name))
if *errp == nil {
return
}
// Possible interrupt, fall through to the error path.
}
// Error path.
// Always ROLLBACK even if the connection has been interrupted.
old := conn.SetInterrupt(context.Background())
defer conn.SetInterrupt(old)
err := conn.Exec(fmt.Sprintf("ROLLBACK TO %q;", name))
if err != nil {
panic(err)
}
err = conn.Exec(fmt.Sprintf("RELEASE %q;", name))
if err != nil {
panic(err)
}
}
return r[0]
}
func (c *Conn) error(rc uint64, sql ...string) error {

77
save.go
View File

@@ -1,77 +0,0 @@
package sqlite3
import (
"context"
"fmt"
"runtime"
)
// Savepoint creates a named SQLite transaction using SAVEPOINT.
//
// On success Savepoint returns a release func that will call
// either RELEASE or ROLLBACK depending on whether the parameter *error
// points to a nil or non-nil error.
//
// This is meant to be deferred:
//
// func doWork(conn *sqlite3.Conn) (err error) {
// defer conn.Savepoint()(&err)
//
// // ... do work in the transaction
// }
func (conn *Conn) Savepoint() (release func(*error)) {
name := "sqlite3.Savepoint" // names can be reused
var pc [1]uintptr
if n := runtime.Callers(2, pc[:]); n > 0 {
frames := runtime.CallersFrames(pc[:n])
frame, _ := frames.Next()
if frame.Function != "" {
name = frame.Function
}
}
err := conn.Exec(fmt.Sprintf("SAVEPOINT %q;", name))
if err != nil {
return func(errp *error) {
if *errp == nil {
*errp = err
}
}
}
return func(errp *error) {
recovered := recover()
if recovered != nil {
defer panic(recovered)
}
if conn.GetAutocommit() {
// There is nothing to commit/rollback.
return
}
if *errp == nil && recovered == nil {
// Success path.
// RELEASE the savepoint successfully.
*errp = conn.Exec(fmt.Sprintf("RELEASE %q;", name))
if *errp == nil {
return
}
// Possible interrupt, fall through to the error path.
}
// Error path.
// Always ROLLBACK even if the connection has been interrupted.
old := conn.SetInterrupt(context.Background())
defer conn.SetInterrupt(old)
err := conn.Exec(fmt.Sprintf("ROLLBACK TO %q;", name))
if err != nil {
panic(err)
}
err = conn.Exec(fmt.Sprintf("RELEASE %q;", name))
if err != nil {
panic(err)
}
}
}

15
stmt.go
View File

@@ -445,3 +445,18 @@ func (s *Stmt) ColumnBlob(col int, buf []byte) []byte {
mem := s.c.mem.view(ptr, uint32(r[0]))
return append(buf[0:0], mem...)
}
// Return true if stmt is an empty SQL statement.
// This is used as an optimization.
// It's OK to always return false here.
func emptyStatement(stmt string) bool {
for _, b := range []byte(stmt) {
switch b {
case ' ', '\n', '\r', '\t', '\v', '\f':
case ';':
default:
return false
}
}
return true
}

16
util.go
View File

@@ -1,16 +0,0 @@
package sqlite3
// Return true if stmt is an empty SQL statement.
// This is used as an optimization.
// It's OK to always return false here.
func emptyStatement(stmt string) bool {
for _, b := range []byte(stmt) {
switch b {
case ' ', '\n', '\r', '\t', '\v', '\f':
case ';':
default:
return false
}
}
return true
}