diff --git a/driver/driver.go b/driver/driver.go index 33eb740..a65d166 100644 --- a/driver/driver.go +++ b/driver/driver.go @@ -269,6 +269,12 @@ func (c *conn) ExecContext(ctx context.Context, query string, args []driver.Name return nil, driver.ErrSkip } + if savept, ok := ctx.(*saveptCtx); ok { + // Called from driver.Savepoint. + savept.Savepoint = c.Savepoint() + return resultRowsAffected(-1), nil + } + old := c.Conn.SetInterrupt(ctx) defer c.Conn.SetInterrupt(old) diff --git a/driver/savepoint.go b/driver/savepoint.go new file mode 100644 index 0000000..0e2f48f --- /dev/null +++ b/driver/savepoint.go @@ -0,0 +1,27 @@ +package driver + +import ( + "database/sql" + "time" + + "github.com/ncruces/go-sqlite3" +) + +// Savepoint establishes a new transaction savepoint. +// +// https://www.sqlite.org/lang_savepoint.html +func Savepoint(tx *sql.Tx) sqlite3.Savepoint { + var ctx saveptCtx + tx.ExecContext(&ctx, "") + return ctx.Savepoint +} + +type saveptCtx struct{ sqlite3.Savepoint } + +func (*saveptCtx) Deadline() (deadline time.Time, ok bool) { return } + +func (*saveptCtx) Done() <-chan struct{} { return nil } + +func (*saveptCtx) Err() error { return nil } + +func (*saveptCtx) Value(key any) any { return nil } diff --git a/ext/blob/blob_test.go b/ext/blob/blob_test.go index b2265f2..e8e8206 100644 --- a/ext/blob/blob_test.go +++ b/ext/blob/blob_test.go @@ -21,7 +21,6 @@ func Example() { if err != nil { log.Fatal(err) } - defer os.Remove("demo.db") defer db.Close() _, err = db.Exec(`CREATE TABLE IF NOT EXISTS test (col)`) diff --git a/tx.go b/tx.go index 5fe144c..64fbe6b 100644 --- a/tx.go +++ b/tx.go @@ -7,6 +7,7 @@ import ( "math/rand" "runtime" "strconv" + "strings" ) // Tx is an in-progress database transaction. @@ -119,17 +120,8 @@ type Savepoint struct { // // https://www.sqlite.org/lang_savepoint.html func (c *Conn) Savepoint() Savepoint { - name := "sqlite3.Savepoint" - var pc [1]uintptr - if n := runtime.Callers(2, pc[:]); n > 0 { - frames := runtime.CallersFrames(pc[:n]) - frame, _ := frames.Next() - if frame.Function != "" { - name = frame.Function - } - } // Names can be reused; this makes catching bugs more likely. - name += "#" + strconv.Itoa(int(rand.Int31())) + name := saveptName() + "_" + strconv.Itoa(int(rand.Int31())) err := c.txExecInterrupted(fmt.Sprintf("SAVEPOINT %q;", name)) if err != nil { @@ -138,6 +130,27 @@ func (c *Conn) Savepoint() Savepoint { return Savepoint{c: c, name: name} } +func saveptName() (name string) { + defer func() { + if name == "" { + name = "sqlite3.Savepoint" + } + }() + + var pc [8]uintptr + n := runtime.Callers(3, pc[:]) + if n <= 0 { + return "" + } + frames := runtime.CallersFrames(pc[:n]) + frame, more := frames.Next() + for more && (strings.HasPrefix(frame.Function, "database/sql.") || + strings.HasPrefix(frame.Function, "github.com/ncruces/go-sqlite3/driver.")) { + frame, more = frames.Next() + } + return frame.Function +} + // Release releases the savepoint rolling back any changes // if *error points to a non-nil error. //