Optimize errors a bit.

This commit is contained in:
Nuno Cruces
2025-03-28 11:51:45 +00:00
parent e5c285b783
commit 41dc46af7e
5 changed files with 28 additions and 37 deletions

View File

@@ -2,7 +2,6 @@ package sqlite3
import ( import (
"errors" "errors"
"strconv"
"strings" "strings"
"github.com/ncruces/go-sqlite3/internal/util" "github.com/ncruces/go-sqlite3/internal/util"
@@ -12,7 +11,6 @@ import (
// //
// https://sqlite.org/c3ref/errcode.html // https://sqlite.org/c3ref/errcode.html
type Error struct { type Error struct {
str string
msg string msg string
sql string sql string
code res_t code res_t
@@ -29,19 +27,13 @@ func (e *Error) Code() ErrorCode {
// //
// https://sqlite.org/rescode.html // https://sqlite.org/rescode.html
func (e *Error) ExtendedCode() ExtendedErrorCode { func (e *Error) ExtendedCode() ExtendedErrorCode {
return ExtendedErrorCode(e.code) return xErrorCode(e.code)
} }
// Error implements the error interface. // Error implements the error interface.
func (e *Error) Error() string { func (e *Error) Error() string {
var b strings.Builder var b strings.Builder
b.WriteString("sqlite3: ") b.WriteString(util.ErrorCodeString(uint32(e.code)))
if e.str != "" {
b.WriteString(e.str)
} else {
b.WriteString(strconv.Itoa(int(e.code)))
}
if e.msg != "" { if e.msg != "" {
b.WriteString(": ") b.WriteString(": ")
@@ -103,12 +95,12 @@ func (e ErrorCode) Error() string {
// Temporary returns true for [BUSY] errors. // Temporary returns true for [BUSY] errors.
func (e ErrorCode) Temporary() bool { func (e ErrorCode) Temporary() bool {
return e == BUSY return e == BUSY || e == INTERRUPT
} }
// ExtendedCode returns the extended error code for this error. // ExtendedCode returns the extended error code for this error.
func (e ErrorCode) ExtendedCode() ExtendedErrorCode { func (e ErrorCode) ExtendedCode() ExtendedErrorCode {
return ExtendedErrorCode(e) return xErrorCode(e)
} }
// Error implements the error interface. // Error implements the error interface.
@@ -133,7 +125,7 @@ func (e ExtendedErrorCode) As(err any) bool {
// Temporary returns true for [BUSY] errors. // Temporary returns true for [BUSY] errors.
func (e ExtendedErrorCode) Temporary() bool { func (e ExtendedErrorCode) Temporary() bool {
return ErrorCode(e) == BUSY return ErrorCode(e) == BUSY || ErrorCode(e) == INTERRUPT
} }
// Timeout returns true for [BUSY_TIMEOUT] errors. // Timeout returns true for [BUSY_TIMEOUT] errors.

View File

@@ -43,7 +43,7 @@ func TestError(t *testing.T) {
if !errors.Is(err, xErrorCode(0x8080)) { if !errors.Is(err, xErrorCode(0x8080)) {
t.Errorf("want true") t.Errorf("want true")
} }
if s := err.Error(); s != "sqlite3: 32896" { if s := err.Error(); s != "sqlite3: unknown error" {
t.Errorf("got %q", s) t.Errorf("got %q", s)
} }
if ok := errors.As(err.ExtendedCode(), &ecode); !ok || ecode != ErrorCode(0x80) { if ok := errors.As(err.ExtendedCode(), &ecode); !ok || ecode != ErrorCode(0x80) {
@@ -83,7 +83,7 @@ func TestError_Temporary(t *testing.T) {
} }
} }
{ {
err := ExtendedErrorCode(tt.code) err := xErrorCode(tt.code)
if got := err.Temporary(); got != tt.want { if got := err.Temporary(); got != tt.want {
t.Errorf("ExtendedErrorCode.Temporary(%d) = %v, want %v", tt.code, got, tt.want) t.Errorf("ExtendedErrorCode.Temporary(%d) = %v, want %v", tt.code, got, tt.want)
} }
@@ -115,7 +115,7 @@ func TestError_Timeout(t *testing.T) {
} }
} }
{ {
err := ExtendedErrorCode(tt.code) err := xErrorCode(tt.code)
if got := err.Timeout(); got != tt.want { if got := err.Timeout(); got != tt.want {
t.Errorf("Error.Timeout(%d) = %v, want %v", tt.code, got, tt.want) t.Errorf("Error.Timeout(%d) = %v, want %v", tt.code, got, tt.want)
} }
@@ -156,12 +156,12 @@ func Test_ExtendedErrorCode_Error(t *testing.T) {
defer db.Close() defer db.Close()
// Test all extended error codes. // Test all extended error codes.
for i := 0; i == int(ExtendedErrorCode(i)); i++ { for i := 0; i == int(xErrorCode(i)); i++ {
want := "sqlite3: " want := "sqlite3: "
ptr := ptr_t(db.call("sqlite3_errstr", stk_t(i))) ptr := ptr_t(db.call("sqlite3_errstr", stk_t(i)))
want += util.ReadString(db.mod, ptr, _MAX_NAME) want += util.ReadString(db.mod, ptr, _MAX_NAME)
got := ExtendedErrorCode(i).Error() got := xErrorCode(i).Error()
if got != want { if got != want {
t.Fatalf("got %q, want %q, with %d", got, want, i) t.Fatalf("got %q, want %q, with %d", got, want, i)
} }

View File

@@ -75,7 +75,7 @@ func ErrorCodeString(rc uint32) string {
return "sqlite3: unable to open database file" return "sqlite3: unable to open database file"
case PROTOCOL: case PROTOCOL:
return "sqlite3: locking protocol" return "sqlite3: locking protocol"
case FORMAT: case EMPTY:
break break
case SCHEMA: case SCHEMA:
return "sqlite3: database schema has changed" return "sqlite3: database schema has changed"
@@ -91,7 +91,7 @@ func ErrorCodeString(rc uint32) string {
break break
case AUTH: case AUTH:
return "sqlite3: authorization denied" return "sqlite3: authorization denied"
case EMPTY: case FORMAT:
break break
case RANGE: case RANGE:
return "sqlite3: column index out of range" return "sqlite3: column index out of range"

View File

@@ -135,11 +135,10 @@ func ReadString(mod api.Module, ptr Ptr_t, maxlen int64) string {
panic(RangeErr) panic(RangeErr)
} }
} }
if i := bytes.IndexByte(buf, 0); i < 0 { if i := bytes.IndexByte(buf, 0); i >= 0 {
panic(NoNulErr)
} else {
return string(buf[:i]) return string(buf[:i])
} }
panic(NoNulErr)
} }
func WriteBytes(mod api.Module, ptr Ptr_t, b []byte) { func WriteBytes(mod api.Module, ptr Ptr_t, b []byte) {

View File

@@ -120,33 +120,33 @@ func (sqlt *sqlite) error(rc res_t, handle ptr_t, sql ...string) error {
return nil return nil
} }
err := Error{code: rc} if ErrorCode(rc) == NOMEM || xErrorCode(rc) == IOERR_NOMEM {
if err.Code() == NOMEM || err.ExtendedCode() == IOERR_NOMEM {
panic(util.OOMErr) panic(util.OOMErr)
} }
if ptr := ptr_t(sqlt.call("sqlite3_errstr", stk_t(rc))); ptr != 0 {
err.str = util.ReadString(sqlt.mod, ptr, _MAX_NAME)
}
if handle != 0 { if handle != 0 {
var msg, query string
if ptr := ptr_t(sqlt.call("sqlite3_errmsg", stk_t(handle))); ptr != 0 { if ptr := ptr_t(sqlt.call("sqlite3_errmsg", stk_t(handle))); ptr != 0 {
err.msg = util.ReadString(sqlt.mod, ptr, _MAX_LENGTH) msg = util.ReadString(sqlt.mod, ptr, _MAX_LENGTH)
switch {
case msg == "not an error":
msg = ""
case msg == util.ErrorCodeString(uint32(rc))[len("sqlite3: "):]:
msg = ""
}
} }
if len(sql) != 0 { if len(sql) != 0 {
if i := int32(sqlt.call("sqlite3_error_offset", stk_t(handle))); i != -1 { if i := int32(sqlt.call("sqlite3_error_offset", stk_t(handle))); i != -1 {
err.sql = sql[0][i:] query = sql[0][i:]
} }
} }
}
switch err.msg { if msg != "" || query != "" {
case err.str, "not an error": return &Error{code: rc, msg: msg, sql: query}
err.msg = "" }
} }
return &err return xErrorCode(rc)
} }
func (sqlt *sqlite) getfn(name string) api.Function { func (sqlt *sqlite) getfn(name string) api.Function {