This commit is contained in:
Nuno Cruces
2024-07-23 13:28:09 +01:00
parent 06f58c35e3
commit 24288c0e26
5 changed files with 55 additions and 7 deletions

View File

@@ -2,6 +2,7 @@ package sqlite3
import (
"errors"
"fmt"
"strings"
"testing"
@@ -10,7 +11,7 @@ import (
func Test_assertErr(t *testing.T) {
err := util.AssertErr()
if s := err.Error(); !strings.HasPrefix(s, "sqlite3: assertion failed") || !strings.HasSuffix(s, "error_test.go:12)") {
if s := err.Error(); !strings.HasPrefix(s, "sqlite3: assertion failed") || !strings.HasSuffix(s, "error_test.go:13)") {
t.Errorf("got %q", s)
}
}
@@ -166,3 +167,32 @@ func Test_ExtendedErrorCode_Error(t *testing.T) {
}
}
}
func Test_errorCode(t *testing.T) {
tests := []struct {
arg error
wantMsg string
wantCode uint32
}{
{nil, "", _OK},
{ERROR, "", util.ERROR},
{IOERR, "", util.IOERR},
{IOERR_READ, "", util.IOERR_READ},
{&Error{code: util.ERROR}, "", util.ERROR},
{fmt.Errorf("%w", ERROR), ERROR.Error(), util.ERROR},
{fmt.Errorf("%w", IOERR), IOERR.Error(), util.IOERR},
{fmt.Errorf("%w", IOERR_READ), IOERR_READ.Error(), util.IOERR_READ},
{fmt.Errorf("error"), "error", util.ERROR},
}
for _, tt := range tests {
t.Run("", func(t *testing.T) {
gotMsg, gotCode := errorCode(tt.arg, ERROR)
if gotMsg != tt.wantMsg {
t.Errorf("errorCode() gotMsg = %q, want %q", gotMsg, tt.wantMsg)
}
if gotCode != uint32(tt.wantCode) {
t.Errorf("errorCode() gotCode = %d, want %d", gotCode, tt.wantCode)
}
})
}
}

View File

@@ -129,10 +129,10 @@ func connect(db *sqlite3.Conn, _, schema, table string, arg ...string) (_ *bloom
defer load.Close()
if !load.Step() {
if err = load.Err(); err == nil {
err = sqlite3.CORRUPT_VTAB
if err := load.Err(); err != nil {
return nil, err
}
return nil, err
return nil, sqlite3.CORRUPT_VTAB
}
t.bytes = load.ColumnInt64(0)

View File

@@ -123,12 +123,11 @@ func (t *table) BestIndex(idx *sqlite3.IndexInfo) error {
return nil
}
func (t *table) Open() (sqlite3.VTabCursor, error) {
func (t *table) Open() (_ sqlite3.VTabCursor, err error) {
stmt := t.stmt
if !t.inuse {
t.inuse = true
} else {
var err error
stmt, _, err = t.stmt.Conn().Prepare(t.sql)
if err != nil {
return nil, err

View File

@@ -5,6 +5,7 @@ package vfs
import (
"io"
"os"
"runtime"
"time"
"golang.org/x/sys/unix"
@@ -68,7 +69,7 @@ func osUnlock(file *os.File, start, len int64) _ErrorCode {
}
func osLock(file *os.File, typ int16, start, len int64, timeout time.Duration, def _ErrorCode) _ErrorCode {
lock := flocktimeout_t{fl: unix.Flock_t{
lock := &flocktimeout_t{fl: unix.Flock_t{
Type: typ,
Start: start,
Len: len,
@@ -82,6 +83,7 @@ func osLock(file *os.File, typ int16, start, len int64, timeout time.Duration, d
default:
lock.timeout = unix.NsecToTimespec(int64(timeout / time.Nanosecond))
err = unix.FcntlFlock(file.Fd(), _F_OFD_SETLKWTIMEOUT, &lock.fl)
runtime.KeepAlive(lock)
}
return osLockErrorCode(err, def)
}

View File

@@ -45,3 +45,20 @@ func TestRegister(t *testing.T) {
t.Error("want skip")
}
func TestRegister_os(t *testing.T) {
os := vfs.Find("os")
if os == nil {
t.Fail()
}
vfs.Register("os", testVFS{t})
if vfs.Find("os") != os {
t.Fail()
}
vfs.Unregister("os")
if vfs.Find("os") != os {
t.Fail()
}
}