From 1e4a246d2fbeb598038771f7efab3308d8781a6a Mon Sep 17 00:00:00 2001 From: Nuno Cruces Date: Sat, 25 Feb 2023 15:11:07 +0000 Subject: [PATCH] Error handling. --- conn.go | 12 ++-- driver/driver_test.go | 33 ++--------- error.go | 128 ++++++++++++++++++++++++++++++++++++++++ error_test.go | 133 ++++++++++++++++++++++++++++++++++++++++-- tests/conn_test.go | 43 +++----------- tests/save_test.go | 30 ++++------ 6 files changed, 288 insertions(+), 91 deletions(-) diff --git a/conn.go b/conn.go index 640a6fa..7a53741 100644 --- a/conn.go +++ b/conn.go @@ -2,6 +2,7 @@ package sqlite3 import ( "context" + "errors" "fmt" "math" "runtime" @@ -206,7 +207,7 @@ func (c *Conn) SetInterrupt(ctx context.Context) (old context.Context) { old = c.interrupt c.interrupt = ctx - if ctx == nil || ctx == context.Background() || ctx == context.TODO() || ctx.Done() == nil { + if ctx == nil || ctx.Done() == nil { // Finalize the uncompleted SQL statement. if c.pending != nil { c.pending.Close() @@ -292,11 +293,14 @@ func (conn *Conn) Savepoint() (release func(*error)) { err := conn.Exec(fmt.Sprintf("SAVEPOINT %q;", name)) if err != nil { - return func(errp *error) { - if *errp == nil { - *errp = err + if errors.Is(err, INTERRUPT) { + return func(errp *error) { + if *errp == nil { + *errp = err + } } } + panic(err) } return func(errp *error) { diff --git a/driver/driver_test.go b/driver/driver_test.go index a5cfa8f..2b68dff 100644 --- a/driver/driver_test.go +++ b/driver/driver_test.go @@ -25,15 +25,8 @@ func Test_Open_dir(t *testing.T) { if err == nil { t.Fatal("want error") } - var serr *sqlite3.Error - if !errors.As(err, &serr) { - t.Fatalf("got %T, want sqlite3.Error", err) - } - if rc := serr.Code(); rc != sqlite3.CANTOPEN { - t.Errorf("got %d, want sqlite3.CANTOPEN", rc) - } - if got := err.Error(); got != `sqlite3: unable to open database file` { - t.Error("got message: ", got) + if !errors.Is(err, sqlite3.CANTOPEN) { + t.Errorf("got %v, want sqlite3.CANTOPEN", err) } } @@ -95,20 +88,13 @@ func Test_Open_txLock(t *testing.T) { if err == nil { t.Error("want error") } - var serr *sqlite3.Error - if !errors.As(err, &serr) { - t.Fatalf("got %T, want sqlite3.Error", err) - } - if rc := serr.Code(); rc != sqlite3.BUSY { - t.Errorf("got %d, want sqlite3.BUSY", rc) + if !errors.Is(err, sqlite3.BUSY) { + t.Errorf("got %v, want sqlite3.BUSY", err) } var terr interface{ Temporary() bool } if !errors.As(err, &terr) || !terr.Temporary() { t.Error("not temporary", err) } - if got := err.Error(); got != `sqlite3: database is locked` { - t.Error("got message: ", got) - } err = tx1.Commit() if err != nil { @@ -161,15 +147,8 @@ func Test_BeginTx(t *testing.T) { if err == nil { t.Error("want error") } - var serr *sqlite3.Error - if !errors.As(err, &serr) { - t.Fatalf("got %T, want sqlite3.Error", err) - } - if rc := serr.Code(); rc != sqlite3.READONLY { - t.Errorf("got %d, want sqlite3.READONLY", rc) - } - if got := err.Error(); got != `sqlite3: attempt to write a readonly database` { - t.Error("got message: ", got) + if !errors.Is(err, sqlite3.READONLY) { + t.Errorf("got %v, want sqlite3.READONLY", err) } err = tx2.Commit() diff --git a/error.go b/error.go index 74bc287..1c96733 100644 --- a/error.go +++ b/error.go @@ -50,16 +50,144 @@ func (e *Error) Error() string { return b.String() } +// Is tests whether this error matches a given [ErrorCode] or [ExtendedErrorCode]. +// +// This makes it possible to do: +// +// if errors.Is(err, sqlite3.BUSY) { +// // ... handle BUSY +// } +func (e *Error) Is(err error) bool { + switch c := err.(type) { + case ErrorCode: + return c == e.Code() + case ExtendedErrorCode: + return c == e.ExtendedCode() + } + return false +} + // Temporary returns true for [BUSY] errors. func (e *Error) Temporary() bool { return e.Code() == BUSY } +// Timeout returns true for [BUSY_TIMEOUT] errors. +func (e *Error) Timeout() bool { + return e.ExtendedCode() == BUSY_TIMEOUT +} + // SQL returns the SQL starting at the token that triggered a syntax error. func (e *Error) SQL() string { return e.sql } +// Error implements the error interface. +func (e ErrorCode) Error() string { + switch e { + case _OK: + return "sqlite3: not an error" + case _ROW: + return "sqlite3: another row available" + case _DONE: + return "sqlite3: no more rows available" + + case ERROR: + return "sqlite3: SQL logic error" + case INTERNAL: + break + case PERM: + return "sqlite3: access permission denied" + case ABORT: + return "sqlite3: query aborted" + case BUSY: + return "sqlite3: database is locked" + case LOCKED: + return "sqlite3: database table is locked" + case NOMEM: + return "sqlite3: out of memory" + case READONLY: + return "sqlite3: attempt to write a readonly database" + case INTERRUPT: + return "sqlite3: interrupted" + case IOERR: + return "sqlite3: disk I/O error" + case CORRUPT: + return "sqlite3: database disk image is malformed" + case NOTFOUND: + return "sqlite3: unknown operation" + case FULL: + return "sqlite3: database or disk is full" + case CANTOPEN: + return "sqlite3: unable to open database file" + case PROTOCOL: + return "sqlite3: locking protocol" + case FORMAT: + break + case SCHEMA: + return "sqlite3: database schema has changed" + case TOOBIG: + return "sqlite3: string or blob too big" + case CONSTRAINT: + return "sqlite3: constraint failed" + case MISMATCH: + return "sqlite3: datatype mismatch" + case MISUSE: + return "sqlite3: bad parameter or other API misuse" + case NOLFS: + break + case AUTH: + return "sqlite3: authorization denied" + case EMPTY: + break + case RANGE: + return "sqlite3: column index out of range" + case NOTADB: + return "sqlite3: file is not a database" + case NOTICE: + return "sqlite3: notification message" + case WARNING: + return "sqlite3: warning message" + } + return "sqlite3: unknown error" +} + +// Temporary returns true for [BUSY] errors. +func (e ErrorCode) Temporary() bool { + return e == BUSY +} + +// Error implements the error interface. +func (e ExtendedErrorCode) Error() string { + switch x := ErrorCode(e); { + case e == ABORT_ROLLBACK: + return "sqlite3: abort due to ROLLBACK" + case x < _ROW: + return x.Error() + case e == _ROW: + return "sqlite3: another row available" + case e == _DONE: + return "sqlite3: no more rows available" + } + return "sqlite3: unknown error" +} + +// Is tests whether this error matches a given [ErrorCode]. +func (e ExtendedErrorCode) Is(err error) bool { + c, ok := err.(ErrorCode) + return ok && c == ErrorCode(e) +} + +// Temporary returns true for [BUSY] errors. +func (e ExtendedErrorCode) Temporary() bool { + return ErrorCode(e) == BUSY +} + +// Timeout returns true for [BUSY_TIMEOUT] errors. +func (e ExtendedErrorCode) Timeout() bool { + return e == BUSY_TIMEOUT +} + type errorString string func (e errorString) Error() string { return string(e) } diff --git a/error_test.go b/error_test.go index 1e8d769..72daf63 100644 --- a/error_test.go +++ b/error_test.go @@ -1,26 +1,147 @@ package sqlite3 import ( + "context" + "errors" "strings" "testing" ) +func Test_assertErr(t *testing.T) { + err := assertErr() + if s := err.Error(); !strings.HasPrefix(s, "sqlite3: assertion failed") || !strings.HasSuffix(s, "error_test.go:11)") { + t.Errorf("got %q", s) + } +} + func TestError(t *testing.T) { err := Error{code: 0x8080} if rc := err.Code(); rc != 0x80 { t.Errorf("got %#x, want 0x80", rc) } + if !errors.Is(&err, ErrorCode(0x80)) { + t.Errorf("want true") + } if rc := err.ExtendedCode(); rc != 0x8080 { t.Errorf("got %#x, want 0x8080", rc) } + if !errors.Is(&err, ExtendedErrorCode(0x8080)) { + t.Errorf("want true") + } if s := err.Error(); s != "sqlite3: 32896" { t.Errorf("got %q", s) } -} - -func Test_assertErr(t *testing.T) { - err := assertErr() - if s := err.Error(); !strings.HasPrefix(s, "sqlite3: assertion failed") || !strings.HasSuffix(s, "error_test.go:22)") { - t.Errorf("got %q", s) + if !errors.Is(err.ExtendedCode(), ErrorCode(0x80)) { + t.Errorf("want true") + } +} + +func TestError_Temporary(t *testing.T) { + tests := []struct { + name string + code uint64 + want bool + }{ + {"ERROR", uint64(ERROR), false}, + {"BUSY", uint64(BUSY), true}, + {"BUSY_RECOVERY", uint64(BUSY_RECOVERY), true}, + {"BUSY_SNAPSHOT", uint64(BUSY_SNAPSHOT), true}, + {"BUSY_TIMEOUT", uint64(BUSY_TIMEOUT), true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + { + err := &Error{code: tt.code} + if got := err.Temporary(); got != tt.want { + t.Errorf("Error.Temporary(%d) = %v, want %v", tt.code, got, tt.want) + } + } + { + err := ErrorCode(tt.code) + if got := err.Temporary(); got != tt.want { + t.Errorf("ErrorCode.Temporary(%d) = %v, want %v", tt.code, got, tt.want) + } + } + { + err := ExtendedErrorCode(tt.code) + if got := err.Temporary(); got != tt.want { + t.Errorf("ExtendedErrorCode.Temporary(%d) = %v, want %v", tt.code, got, tt.want) + } + } + }) + } +} + +func TestError_Timeout(t *testing.T) { + tests := []struct { + name string + code uint64 + want bool + }{ + {"ERROR", uint64(ERROR), false}, + {"BUSY", uint64(BUSY), false}, + {"BUSY_RECOVERY", uint64(BUSY_RECOVERY), false}, + {"BUSY_SNAPSHOT", uint64(BUSY_SNAPSHOT), false}, + {"BUSY_TIMEOUT", uint64(BUSY_TIMEOUT), true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + { + err := &Error{code: tt.code} + if got := err.Timeout(); got != tt.want { + t.Errorf("Error.Timeout(%d) = %v, want %v", tt.code, got, tt.want) + } + } + { + err := ExtendedErrorCode(tt.code) + if got := err.Timeout(); got != tt.want { + t.Errorf("Error.Timeout(%d) = %v, want %v", tt.code, got, tt.want) + } + } + }) + } +} + +func Test_ErrorCode_Error(t *testing.T) { + db, err := Open(":memory:") + if err != nil { + t.Fatal(err) + } + defer db.Close() + + // Test all error codes. + for i := 0; i == int(ErrorCode(i)); i++ { + want := "sqlite3: " + r, _ := db.api.errstr.Call(context.TODO(), uint64(i)) + if r != nil { + want += db.mem.readString(uint32(r[0]), _MAX_STRING) + } + + got := ErrorCode(i).Error() + if got != want { + t.Fatalf("got %q, want %q, with %d", got, want, i) + } + } +} + +func Test_ExtendedErrorCode_Error(t *testing.T) { + db, err := Open(":memory:") + if err != nil { + t.Fatal(err) + } + defer db.Close() + + // Test all extended error codes. + for i := 0; i == int(ExtendedErrorCode(i)); i++ { + want := "sqlite3: " + r, _ := db.api.errstr.Call(context.TODO(), uint64(i)) + if r != nil { + want += db.mem.readString(uint32(r[0]), _MAX_STRING) + } + + got := ExtendedErrorCode(i).Error() + if got != want { + t.Fatalf("got %q, want %q, with %d", got, want, i) + } } } diff --git a/tests/conn_test.go b/tests/conn_test.go index 536c92c..5ea5b81 100644 --- a/tests/conn_test.go +++ b/tests/conn_test.go @@ -17,15 +17,8 @@ func TestConn_Open_dir(t *testing.T) { if err == nil { t.Fatal("want error") } - var serr *sqlite3.Error - if !errors.As(err, &serr) { - t.Fatalf("got %T, want sqlite3.Error", err) - } - if rc := serr.Code(); rc != sqlite3.CANTOPEN { - t.Errorf("got %d, want sqlite3.CANTOPEN", rc) - } - if got := err.Error(); got != `sqlite3: unable to open database file` { - t.Error("got message: ", got) + if !errors.Is(err, sqlite3.CANTOPEN) { + t.Errorf("got %v, want sqlite3.CANTOPEN", err) } } @@ -53,12 +46,8 @@ func TestConn_Close_BUSY(t *testing.T) { if err == nil { t.Fatal("want error") } - var serr *sqlite3.Error - if !errors.As(err, &serr) { - t.Fatalf("got %T, want sqlite3.Error", err) - } - if rc := serr.Code(); rc != sqlite3.BUSY { - t.Errorf("got %d, want sqlite3.BUSY", rc) + if !errors.Is(err, sqlite3.BUSY) { + t.Errorf("got %v, want sqlite3.BUSY", err) } var terr interface{ Temporary() bool } if !errors.As(err, &terr) || !terr.Temporary() { @@ -85,7 +74,7 @@ func TestConn_SetInterrupt(t *testing.T) { t.Fatal(err) } - db.SetInterrupt(nil) + db.SetInterrupt(context.Background()) stmt, _, err := db.Prepare(` WITH RECURSIVE @@ -106,30 +95,16 @@ func TestConn_SetInterrupt(t *testing.T) { db.SetInterrupt(ctx) cancel() - var serr *sqlite3.Error - // Interrupting works. err = stmt.Exec() - if !errors.As(err, &serr) { - t.Fatalf("got %T, want sqlite3.Error", err) - } - if rc := serr.Code(); rc != sqlite3.INTERRUPT { - t.Errorf("got %d, want sqlite3.INTERRUPT", rc) - } - if got := err.Error(); got != `sqlite3: interrupted` { - t.Error("got message: ", got) + if !errors.Is(err, sqlite3.INTERRUPT) { + t.Errorf("got %v, want sqlite3.INTERRUPT", err) } // Interrupting sticks. err = db.Exec(`SELECT 1`) - if !errors.As(err, &serr) { - t.Fatalf("got %T, want sqlite3.Error", err) - } - if rc := serr.Code(); rc != sqlite3.INTERRUPT { - t.Errorf("got %d, want sqlite3.INTERRUPT", rc) - } - if got := err.Error(); got != `sqlite3: interrupted` { - t.Error("got message: ", got) + if !errors.Is(err, sqlite3.INTERRUPT) { + t.Errorf("got %v, want sqlite3.INTERRUPT", err) } ctx, cancel = context.WithCancel(context.Background()) diff --git a/tests/save_test.go b/tests/save_test.go index 77d46f5..23b88f2 100644 --- a/tests/save_test.go +++ b/tests/save_test.go @@ -163,28 +163,16 @@ func TestConn_Savepoint_interrupt(t *testing.T) { t.Fatal(err) } - checkInterrupt := func(err error) { - var serr *sqlite3.Error - if err == nil { - t.Fatal("want error") - } - if !errors.As(err, &serr) { - t.Fatalf("got %T, want sqlite3.Error", err) - } - if rc := serr.Code(); rc != sqlite3.INTERRUPT { - t.Errorf("got %d, want sqlite3.INTERRUPT", rc) - } - if got := err.Error(); got != `sqlite3: interrupted` { - t.Error("got message: ", got) - } - } - cancel() db.Savepoint()(&err) - checkInterrupt(err) + if !errors.Is(err, sqlite3.INTERRUPT) { + t.Errorf("got %v, want sqlite3.INTERRUPT", err) + } err = db.Exec(`INSERT INTO test(col) VALUES(4)`) - checkInterrupt(err) + if !errors.Is(err, sqlite3.INTERRUPT) { + t.Errorf("got %v, want sqlite3.INTERRUPT", err) + } err = context.Canceled release2(&err) @@ -194,9 +182,11 @@ func TestConn_Savepoint_interrupt(t *testing.T) { var nilErr error release1(&nilErr) - checkInterrupt(nilErr) + if !errors.Is(nilErr, sqlite3.INTERRUPT) { + t.Errorf("got %v, want sqlite3.INTERRUPT", err) + } - db.SetInterrupt(nil) + db.SetInterrupt(context.Background()) stmt, _, err := db.Prepare(`SELECT count(*) FROM test`) if err != nil { t.Fatal(err)