diff --git a/context.go b/context.go index 7e7f252..7e5a974 100644 --- a/context.go +++ b/context.go @@ -1,6 +1,7 @@ package sqlite3 import ( + "errors" "math" "time" @@ -119,3 +120,37 @@ func (c *Context) resultRFC3339Nano(value time.Time) { uint64(c.handle), uint64(ptr), uint64(len(buf)), uint64(c.c.api.destructor), _UTF8) } + +// ResultError sets the result of the function an error. +// +// https://www.sqlite.org/c3ref/result_blob.html +func (c *Context) ResultError(err error) { + if errors.Is(err, NOMEM) { + c.c.call(c.c.api.resultErrorMem, uint64(c.handle)) + return + } + + if errors.Is(err, TOOBIG) { + c.c.call(c.c.api.resultErrorBig, uint64(c.handle)) + return + } + + str := err.Error() + ptr := c.c.arena.string(str) + c.c.call(c.c.api.resultBlob, + uint64(c.handle), uint64(ptr), uint64(len(str))) + + var code uint64 + var ecode ErrorCode + var xcode xErrorCode + switch { + case errors.As(err, &xcode): + code = uint64(xcode) + case errors.As(err, &ecode): + code = uint64(ecode) + } + if code != 0 { + c.c.call(c.c.api.resultErrorCode, + uint64(c.handle), uint64(xcode)) + } +} diff --git a/error.go b/error.go index 957a744..c91dccd 100644 --- a/error.go +++ b/error.go @@ -68,6 +68,19 @@ func (e *Error) Is(err error) bool { return false } +// As converts this error to an [ErrorCode] or [ExtendedErrorCode]. +func (e *Error) As(err any) bool { + switch c := err.(type) { + case *ErrorCode: + *c = e.Code() + return true + case *ExtendedErrorCode: + *c = e.ExtendedCode() + return true + } + return false +} + // Temporary returns true for [BUSY] errors. func (e *Error) Temporary() bool { return e.Code() == BUSY @@ -104,6 +117,15 @@ func (e ExtendedErrorCode) Is(err error) bool { return ok && c == ErrorCode(e) } +// As converts this error to an [ErrorCode]. +func (e ExtendedErrorCode) As(err any) bool { + c, ok := err.(*ErrorCode) + if ok { + *c = ErrorCode(e) + } + return ok +} + // Temporary returns true for [BUSY] errors. func (e ExtendedErrorCode) Temporary() bool { return ErrorCode(e) == BUSY diff --git a/error_test.go b/error_test.go index 4cfe7a6..d4b9037 100644 --- a/error_test.go +++ b/error_test.go @@ -18,22 +18,36 @@ func Test_assertErr(t *testing.T) { func TestError(t *testing.T) { t.Parallel() - err := Error{code: 0x8080} - if rc := err.Code(); rc != 0x80 { - t.Errorf("got %#x, want 0x80", rc) + var ecode ErrorCode + var xcode xErrorCode + err := &Error{code: 0x8080} + if !errors.As(err, &err) { + t.Fatal("want true") } - if !errors.Is(&err, ErrorCode(0x80)) { + if ecode := err.Code(); ecode != 0x80 { + t.Errorf("got %#x, want 0x80", uint8(ecode)) + } + if ok := errors.As(err, &ecode); !ok || ecode != ErrorCode(0x80) { + t.Errorf("got %#x, want 0x80", uint8(ecode)) + } + if !errors.Is(err, ErrorCode(0x80)) { t.Errorf("want true") } - if rc := err.ExtendedCode(); rc != 0x8080 { - t.Errorf("got %#x, want 0x8080", rc) + if xcode := err.ExtendedCode(); xcode != 0x8080 { + t.Errorf("got %#x, want 0x8080", uint16(xcode)) } - if !errors.Is(&err, ExtendedErrorCode(0x8080)) { + if ok := errors.As(err, &xcode); !ok || xcode != xErrorCode(0x8080) { + t.Errorf("got %#x, want 0x8080", uint16(xcode)) + } + if !errors.Is(err, xErrorCode(0x8080)) { t.Errorf("want true") } if s := err.Error(); s != "sqlite3: 32896" { t.Errorf("got %q", s) } + if ok := errors.As(err.ExtendedCode(), &ecode); !ok || ecode != ErrorCode(0x80) { + t.Errorf("got %#x, want 0x80", uint8(ecode)) + } if !errors.Is(err.ExtendedCode(), ErrorCode(0x80)) { t.Errorf("want true") }