Optimize errors a bit.

This commit is contained in:
Nuno Cruces
2025-03-28 11:51:45 +00:00
parent e5c285b783
commit 41dc46af7e
5 changed files with 28 additions and 37 deletions

View File

@@ -2,7 +2,6 @@ package sqlite3
import (
"errors"
"strconv"
"strings"
"github.com/ncruces/go-sqlite3/internal/util"
@@ -12,7 +11,6 @@ import (
//
// https://sqlite.org/c3ref/errcode.html
type Error struct {
str string
msg string
sql string
code res_t
@@ -29,19 +27,13 @@ func (e *Error) Code() ErrorCode {
//
// https://sqlite.org/rescode.html
func (e *Error) ExtendedCode() ExtendedErrorCode {
return ExtendedErrorCode(e.code)
return xErrorCode(e.code)
}
// Error implements the error interface.
func (e *Error) Error() string {
var b strings.Builder
b.WriteString("sqlite3: ")
if e.str != "" {
b.WriteString(e.str)
} else {
b.WriteString(strconv.Itoa(int(e.code)))
}
b.WriteString(util.ErrorCodeString(uint32(e.code)))
if e.msg != "" {
b.WriteString(": ")
@@ -103,12 +95,12 @@ func (e ErrorCode) Error() string {
// Temporary returns true for [BUSY] errors.
func (e ErrorCode) Temporary() bool {
return e == BUSY
return e == BUSY || e == INTERRUPT
}
// ExtendedCode returns the extended error code for this error.
func (e ErrorCode) ExtendedCode() ExtendedErrorCode {
return ExtendedErrorCode(e)
return xErrorCode(e)
}
// Error implements the error interface.
@@ -133,7 +125,7 @@ func (e ExtendedErrorCode) As(err any) bool {
// Temporary returns true for [BUSY] errors.
func (e ExtendedErrorCode) Temporary() bool {
return ErrorCode(e) == BUSY
return ErrorCode(e) == BUSY || ErrorCode(e) == INTERRUPT
}
// Timeout returns true for [BUSY_TIMEOUT] errors.

View File

@@ -43,7 +43,7 @@ func TestError(t *testing.T) {
if !errors.Is(err, xErrorCode(0x8080)) {
t.Errorf("want true")
}
if s := err.Error(); s != "sqlite3: 32896" {
if s := err.Error(); s != "sqlite3: unknown error" {
t.Errorf("got %q", s)
}
if ok := errors.As(err.ExtendedCode(), &ecode); !ok || ecode != ErrorCode(0x80) {
@@ -83,7 +83,7 @@ func TestError_Temporary(t *testing.T) {
}
}
{
err := ExtendedErrorCode(tt.code)
err := xErrorCode(tt.code)
if got := err.Temporary(); got != tt.want {
t.Errorf("ExtendedErrorCode.Temporary(%d) = %v, want %v", tt.code, got, tt.want)
}
@@ -115,7 +115,7 @@ func TestError_Timeout(t *testing.T) {
}
}
{
err := ExtendedErrorCode(tt.code)
err := xErrorCode(tt.code)
if got := err.Timeout(); got != tt.want {
t.Errorf("Error.Timeout(%d) = %v, want %v", tt.code, got, tt.want)
}
@@ -156,12 +156,12 @@ func Test_ExtendedErrorCode_Error(t *testing.T) {
defer db.Close()
// Test all extended error codes.
for i := 0; i == int(ExtendedErrorCode(i)); i++ {
for i := 0; i == int(xErrorCode(i)); i++ {
want := "sqlite3: "
ptr := ptr_t(db.call("sqlite3_errstr", stk_t(i)))
want += util.ReadString(db.mod, ptr, _MAX_NAME)
got := ExtendedErrorCode(i).Error()
got := xErrorCode(i).Error()
if got != want {
t.Fatalf("got %q, want %q, with %d", got, want, i)
}

View File

@@ -75,7 +75,7 @@ func ErrorCodeString(rc uint32) string {
return "sqlite3: unable to open database file"
case PROTOCOL:
return "sqlite3: locking protocol"
case FORMAT:
case EMPTY:
break
case SCHEMA:
return "sqlite3: database schema has changed"
@@ -91,7 +91,7 @@ func ErrorCodeString(rc uint32) string {
break
case AUTH:
return "sqlite3: authorization denied"
case EMPTY:
case FORMAT:
break
case RANGE:
return "sqlite3: column index out of range"

View File

@@ -135,11 +135,10 @@ func ReadString(mod api.Module, ptr Ptr_t, maxlen int64) string {
panic(RangeErr)
}
}
if i := bytes.IndexByte(buf, 0); i < 0 {
panic(NoNulErr)
} else {
if i := bytes.IndexByte(buf, 0); i >= 0 {
return string(buf[:i])
}
panic(NoNulErr)
}
func WriteBytes(mod api.Module, ptr Ptr_t, b []byte) {

View File

@@ -120,33 +120,33 @@ func (sqlt *sqlite) error(rc res_t, handle ptr_t, sql ...string) error {
return nil
}
err := Error{code: rc}
if err.Code() == NOMEM || err.ExtendedCode() == IOERR_NOMEM {
if ErrorCode(rc) == NOMEM || xErrorCode(rc) == IOERR_NOMEM {
panic(util.OOMErr)
}
if ptr := ptr_t(sqlt.call("sqlite3_errstr", stk_t(rc))); ptr != 0 {
err.str = util.ReadString(sqlt.mod, ptr, _MAX_NAME)
}
if handle != 0 {
var msg, query string
if ptr := ptr_t(sqlt.call("sqlite3_errmsg", stk_t(handle))); ptr != 0 {
err.msg = util.ReadString(sqlt.mod, ptr, _MAX_LENGTH)
msg = util.ReadString(sqlt.mod, ptr, _MAX_LENGTH)
switch {
case msg == "not an error":
msg = ""
case msg == util.ErrorCodeString(uint32(rc))[len("sqlite3: "):]:
msg = ""
}
}
if len(sql) != 0 {
if i := int32(sqlt.call("sqlite3_error_offset", stk_t(handle))); i != -1 {
err.sql = sql[0][i:]
}
query = sql[0][i:]
}
}
switch err.msg {
case err.str, "not an error":
err.msg = ""
if msg != "" || query != "" {
return &Error{code: rc, msg: msg, sql: query}
}
return &err
}
return xErrorCode(rc)
}
func (sqlt *sqlite) getfn(name string) api.Function {