From 41dc46af7ee58935d44290875a50a3460dc0c340 Mon Sep 17 00:00:00 2001 From: Nuno Cruces Date: Fri, 28 Mar 2025 11:51:45 +0000 Subject: [PATCH] Optimize errors a bit. --- error.go | 18 +++++------------- error_test.go | 10 +++++----- internal/util/error.go | 4 ++-- internal/util/mem.go | 5 ++--- sqlite.go | 28 ++++++++++++++-------------- 5 files changed, 28 insertions(+), 37 deletions(-) diff --git a/error.go b/error.go index 6d4bd63..59982ea 100644 --- a/error.go +++ b/error.go @@ -2,7 +2,6 @@ package sqlite3 import ( "errors" - "strconv" "strings" "github.com/ncruces/go-sqlite3/internal/util" @@ -12,7 +11,6 @@ import ( // // https://sqlite.org/c3ref/errcode.html type Error struct { - str string msg string sql string code res_t @@ -29,19 +27,13 @@ func (e *Error) Code() ErrorCode { // // https://sqlite.org/rescode.html func (e *Error) ExtendedCode() ExtendedErrorCode { - return ExtendedErrorCode(e.code) + return xErrorCode(e.code) } // Error implements the error interface. func (e *Error) Error() string { var b strings.Builder - b.WriteString("sqlite3: ") - - if e.str != "" { - b.WriteString(e.str) - } else { - b.WriteString(strconv.Itoa(int(e.code))) - } + b.WriteString(util.ErrorCodeString(uint32(e.code))) if e.msg != "" { b.WriteString(": ") @@ -103,12 +95,12 @@ func (e ErrorCode) Error() string { // Temporary returns true for [BUSY] errors. func (e ErrorCode) Temporary() bool { - return e == BUSY + return e == BUSY || e == INTERRUPT } // ExtendedCode returns the extended error code for this error. func (e ErrorCode) ExtendedCode() ExtendedErrorCode { - return ExtendedErrorCode(e) + return xErrorCode(e) } // Error implements the error interface. @@ -133,7 +125,7 @@ func (e ExtendedErrorCode) As(err any) bool { // Temporary returns true for [BUSY] errors. func (e ExtendedErrorCode) Temporary() bool { - return ErrorCode(e) == BUSY + return ErrorCode(e) == BUSY || ErrorCode(e) == INTERRUPT } // Timeout returns true for [BUSY_TIMEOUT] errors. diff --git a/error_test.go b/error_test.go index 2ec3f49..586039a 100644 --- a/error_test.go +++ b/error_test.go @@ -43,7 +43,7 @@ func TestError(t *testing.T) { if !errors.Is(err, xErrorCode(0x8080)) { t.Errorf("want true") } - if s := err.Error(); s != "sqlite3: 32896" { + if s := err.Error(); s != "sqlite3: unknown error" { t.Errorf("got %q", s) } if ok := errors.As(err.ExtendedCode(), &ecode); !ok || ecode != ErrorCode(0x80) { @@ -83,7 +83,7 @@ func TestError_Temporary(t *testing.T) { } } { - err := ExtendedErrorCode(tt.code) + err := xErrorCode(tt.code) if got := err.Temporary(); got != tt.want { t.Errorf("ExtendedErrorCode.Temporary(%d) = %v, want %v", tt.code, got, tt.want) } @@ -115,7 +115,7 @@ func TestError_Timeout(t *testing.T) { } } { - err := ExtendedErrorCode(tt.code) + err := xErrorCode(tt.code) if got := err.Timeout(); got != tt.want { t.Errorf("Error.Timeout(%d) = %v, want %v", tt.code, got, tt.want) } @@ -156,12 +156,12 @@ func Test_ExtendedErrorCode_Error(t *testing.T) { defer db.Close() // Test all extended error codes. - for i := 0; i == int(ExtendedErrorCode(i)); i++ { + for i := 0; i == int(xErrorCode(i)); i++ { want := "sqlite3: " ptr := ptr_t(db.call("sqlite3_errstr", stk_t(i))) want += util.ReadString(db.mod, ptr, _MAX_NAME) - got := ExtendedErrorCode(i).Error() + got := xErrorCode(i).Error() if got != want { t.Fatalf("got %q, want %q, with %d", got, want, i) } diff --git a/internal/util/error.go b/internal/util/error.go index 2aecac9..76769ed 100644 --- a/internal/util/error.go +++ b/internal/util/error.go @@ -75,7 +75,7 @@ func ErrorCodeString(rc uint32) string { return "sqlite3: unable to open database file" case PROTOCOL: return "sqlite3: locking protocol" - case FORMAT: + case EMPTY: break case SCHEMA: return "sqlite3: database schema has changed" @@ -91,7 +91,7 @@ func ErrorCodeString(rc uint32) string { break case AUTH: return "sqlite3: authorization denied" - case EMPTY: + case FORMAT: break case RANGE: return "sqlite3: column index out of range" diff --git a/internal/util/mem.go b/internal/util/mem.go index d2fea08..90c0e9e 100644 --- a/internal/util/mem.go +++ b/internal/util/mem.go @@ -135,11 +135,10 @@ func ReadString(mod api.Module, ptr Ptr_t, maxlen int64) string { panic(RangeErr) } } - if i := bytes.IndexByte(buf, 0); i < 0 { - panic(NoNulErr) - } else { + if i := bytes.IndexByte(buf, 0); i >= 0 { return string(buf[:i]) } + panic(NoNulErr) } func WriteBytes(mod api.Module, ptr Ptr_t, b []byte) { diff --git a/sqlite.go b/sqlite.go index 9e2d1d3..df32271 100644 --- a/sqlite.go +++ b/sqlite.go @@ -120,33 +120,33 @@ func (sqlt *sqlite) error(rc res_t, handle ptr_t, sql ...string) error { return nil } - err := Error{code: rc} - - if err.Code() == NOMEM || err.ExtendedCode() == IOERR_NOMEM { + if ErrorCode(rc) == NOMEM || xErrorCode(rc) == IOERR_NOMEM { panic(util.OOMErr) } - if ptr := ptr_t(sqlt.call("sqlite3_errstr", stk_t(rc))); ptr != 0 { - err.str = util.ReadString(sqlt.mod, ptr, _MAX_NAME) - } - if handle != 0 { + var msg, query string if ptr := ptr_t(sqlt.call("sqlite3_errmsg", stk_t(handle))); ptr != 0 { - err.msg = util.ReadString(sqlt.mod, ptr, _MAX_LENGTH) + msg = util.ReadString(sqlt.mod, ptr, _MAX_LENGTH) + switch { + case msg == "not an error": + msg = "" + case msg == util.ErrorCodeString(uint32(rc))[len("sqlite3: "):]: + msg = "" + } } if len(sql) != 0 { if i := int32(sqlt.call("sqlite3_error_offset", stk_t(handle))); i != -1 { - err.sql = sql[0][i:] + query = sql[0][i:] } } - } - switch err.msg { - case err.str, "not an error": - err.msg = "" + if msg != "" || query != "" { + return &Error{code: rc, msg: msg, sql: query} + } } - return &err + return xErrorCode(rc) } func (sqlt *sqlite) getfn(name string) api.Function {