diff --git a/sqlite3vfs/vfs.go b/sqlite3vfs/vfs.go index ef46850..be87025 100644 --- a/sqlite3vfs/vfs.go +++ b/sqlite3vfs/vfs.go @@ -15,6 +15,8 @@ import ( // ExportHostFunctions registers the required VFS host functions // with the provided env module. +// +// Users of the [github.com/ncruces/go-sqlite3] package need not call this directly. func ExportHostFunctions(env wazero.HostModuleBuilder) wazero.HostModuleBuilder { util.ExportFuncII(env, "go_vfs_find", vfsFind) util.ExportFuncIIJ(env, "go_localtime", vfsLocaltime) @@ -53,6 +55,8 @@ type vfsState struct { // // The returned [io.Closer] should be closed after the [api.Module] is closed, // to release any associated resources. +// +// Users of the [github.com/ncruces/go-sqlite3] package need not call this directly. func NewContext(ctx context.Context) (context.Context, io.Closer) { vfs := &vfsState{} return context.WithValue(ctx, vfsKey{}, vfs), vfs diff --git a/sqlite3vfs/vfs_api.go b/sqlite3vfs/vfs_api.go index 7849b58..6727910 100644 --- a/sqlite3vfs/vfs_api.go +++ b/sqlite3vfs/vfs_api.go @@ -1,12 +1,12 @@ // Package sqlite3vfs wraps the C SQLite VFS API. package sqlite3vfs -import ( - "sync" -) +import "sync" // A VFS defines the interface between the SQLite core and the underlying operating system. // +// Use sqlite3.ErrorCode or sqlite3.ExtendedErrorCode to return specific error codes. +// // https://www.sqlite.org/c3ref/vfs.html type VFS interface { Open(name string, flags OpenFlag) (File, OpenFlag, error) @@ -17,7 +17,9 @@ type VFS interface { // A File represents an open file in the OS interface layer. // -// https://www.sqlite.org/c3ref/file.html +// Use sqlite3.ErrorCode or sqlite3.ExtendedErrorCode to return specific error codes. +// In particular, sqlite3.BUSY is necessary to correctly implement lock methods. +// // https://www.sqlite.org/c3ref/io_methods.html type File interface { Close() error diff --git a/sqlite3vfs/vfs_lock_test.go b/sqlite3vfs/vfs_lock_test.go index cbd0a6e..7b98574 100644 --- a/sqlite3vfs/vfs_lock_test.go +++ b/sqlite3vfs/vfs_lock_test.go @@ -60,6 +60,13 @@ func Test_vfsLock(t *testing.T) { if got := util.ReadUint32(mod, pOutput); got != 0 { t.Error("file was locked") } + rc = vfsFileControl(ctx, mod, pFile2, _FCNTL_LOCKSTATE, pOutput) + if rc != _OK { + t.Fatal("returned", rc) + } + if got := util.ReadUint32(mod, pOutput); got != uint32(LOCK_NONE) { + t.Error("invalid lock state", got) + } rc = vfsLock(ctx, mod, pFile2, LOCK_SHARED) if rc != _OK { @@ -80,6 +87,13 @@ func Test_vfsLock(t *testing.T) { if got := util.ReadUint32(mod, pOutput); got != 0 { t.Error("file was locked") } + rc = vfsFileControl(ctx, mod, pFile2, _FCNTL_LOCKSTATE, pOutput) + if rc != _OK { + t.Fatal("returned", rc) + } + if got := util.ReadUint32(mod, pOutput); got != uint32(LOCK_SHARED) { + t.Error("invalid lock state", got) + } rc = vfsLock(ctx, mod, pFile2, LOCK_RESERVED) if rc != _OK { @@ -104,6 +118,13 @@ func Test_vfsLock(t *testing.T) { if got := util.ReadUint32(mod, pOutput); got == 0 { t.Error("file wasn't locked") } + rc = vfsFileControl(ctx, mod, pFile2, _FCNTL_LOCKSTATE, pOutput) + if rc != _OK { + t.Fatal("returned", rc) + } + if got := util.ReadUint32(mod, pOutput); got != uint32(LOCK_RESERVED) { + t.Error("invalid lock state", got) + } rc = vfsLock(ctx, mod, pFile2, LOCK_EXCLUSIVE) if rc != _OK { @@ -124,6 +145,13 @@ func Test_vfsLock(t *testing.T) { if got := util.ReadUint32(mod, pOutput); got == 0 { t.Error("file wasn't locked") } + rc = vfsFileControl(ctx, mod, pFile2, _FCNTL_LOCKSTATE, pOutput) + if rc != _OK { + t.Fatal("returned", rc) + } + if got := util.ReadUint32(mod, pOutput); got != uint32(LOCK_EXCLUSIVE) { + t.Error("invalid lock state", got) + } rc = vfsLock(ctx, mod, pFile1, LOCK_SHARED) if rc == _OK { @@ -144,6 +172,13 @@ func Test_vfsLock(t *testing.T) { if got := util.ReadUint32(mod, pOutput); got == 0 { t.Error("file wasn't locked") } + rc = vfsFileControl(ctx, mod, pFile1, _FCNTL_LOCKSTATE, pOutput) + if rc != _OK { + t.Fatal("returned", rc) + } + if got := util.ReadUint32(mod, pOutput); got != uint32(LOCK_NONE) { + t.Error("invalid lock state", got) + } rc = vfsUnlock(ctx, mod, pFile2, LOCK_SHARED) if rc != _OK { @@ -169,4 +204,16 @@ func Test_vfsLock(t *testing.T) { if rc != _OK { t.Fatal("returned", rc) } + rc = vfsFileControl(ctx, mod, pFile1, _FCNTL_LOCKSTATE, pOutput) + if rc != _OK { + t.Fatal("returned", rc) + } + if got := util.ReadUint32(mod, pOutput); got != uint32(LOCK_SHARED) { + t.Error("invalid lock state", got) + } + + rc = vfsFileControl(ctx, mod, pFile1, _FCNTL_LOCK_TIMEOUT, 1) + if rc != _OK { + t.Fatal("returned", rc) + } } diff --git a/sqlite3vfs/vfs_os_unix.go b/sqlite3vfs/vfs_os_unix.go index 8ed788a..3c2194f 100644 --- a/sqlite3vfs/vfs_os_unix.go +++ b/sqlite3vfs/vfs_os_unix.go @@ -28,7 +28,7 @@ func osAccess(path string, flags AccessFlag) error { func osGetSharedLock(file *os.File, timeout time.Duration) _ErrorCode { // Test the PENDING lock before acquiring a new SHARED lock. if pending, _ := osCheckLock(file, _PENDING_BYTE, 1); pending { - return _ErrorCode(_BUSY) + return _BUSY } // Acquire the SHARED lock. return osReadLock(file, _SHARED_FIRST, _SHARED_SIZE, timeout) @@ -78,9 +78,9 @@ func osLockErrorCode(err error, def _ErrorCode) _ErrorCode { unix.ENOLCK, unix.EDEADLK, unix.ETIMEDOUT: - return _ErrorCode(_BUSY) + return _BUSY case unix.EPERM: - return _ErrorCode(_PERM) + return _PERM } } return def diff --git a/sqlite3vfs/vfs_test.go b/sqlite3vfs/vfs_test.go index 061526a..a606930 100644 --- a/sqlite3vfs/vfs_test.go +++ b/sqlite3vfs/vfs_test.go @@ -5,6 +5,7 @@ import ( "context" "errors" "io/fs" + "math" "os" "path/filepath" "syscall" @@ -218,6 +219,11 @@ func Test_vfsFile(t *testing.T) { t.Fatal("returned", rc) } + // Check sector size. + if size := vfsSectorSize(ctx, mod, 4); size != _DEFAULT_SECTOR_SIZE { + t.Fatal("returned", size) + } + // Write stuff. text := "Hello world!" util.WriteString(mod, 16, text) @@ -274,3 +280,65 @@ func Test_vfsFile(t *testing.T) { t.Fatal("returned", rc) } } + +func Test_vfsFile_psow(t *testing.T) { + mod := util.NewMockModule(128) + ctx, vfs := NewContext(context.TODO()) + defer vfs.Close() + + // Open a temporary file. + rc := vfsOpen(ctx, mod, 0, 0, 4, OPEN_CREATE|OPEN_EXCLUSIVE|OPEN_READWRITE|OPEN_DELETEONCLOSE, 0) + if rc != _OK { + t.Fatal("returned", rc) + } + + // Read powersafe overwrite. + util.WriteUint32(mod, 16, math.MaxUint32) + rc = vfsFileControl(ctx, mod, 4, _FCNTL_POWERSAFE_OVERWRITE, 16) + if rc != _OK { + t.Fatal("returned", rc) + } + if got := util.ReadUint32(mod, 16); got == 0 { + t.Error("psow disabled") + } + + // Unset powersafe overwrite. + util.WriteUint32(mod, 16, 0) + rc = vfsFileControl(ctx, mod, 4, _FCNTL_POWERSAFE_OVERWRITE, 16) + if rc != _OK { + t.Fatal("returned", rc) + } + + // Read powersafe overwrite. + util.WriteUint32(mod, 16, math.MaxUint32) + rc = vfsFileControl(ctx, mod, 4, _FCNTL_POWERSAFE_OVERWRITE, 16) + if rc != _OK { + t.Fatal("returned", rc) + } + if got := util.ReadUint32(mod, 16); got != 0 { + t.Error("psow enabled") + } + + // Set powersafe overwrite. + util.WriteUint32(mod, 16, 1) + rc = vfsFileControl(ctx, mod, 4, _FCNTL_POWERSAFE_OVERWRITE, 16) + if rc != _OK { + t.Fatal("returned", rc) + } + + // Read powersafe overwrite. + util.WriteUint32(mod, 16, math.MaxUint32) + rc = vfsFileControl(ctx, mod, 4, _FCNTL_POWERSAFE_OVERWRITE, 16) + if rc != _OK { + t.Fatal("returned", rc) + } + if got := util.ReadUint32(mod, 16); got == 0 { + t.Error("psow disabled") + } + + // Close the file. + rc = vfsClose(ctx, mod, 4) + if rc != _OK { + t.Fatal("returned", rc) + } +} diff --git a/tests/conn_test.go b/tests/conn_test.go index a5baf1a..f6848d5 100644 --- a/tests/conn_test.go +++ b/tests/conn_test.go @@ -3,6 +3,7 @@ package tests import ( "context" "errors" + "reflect" "strings" "testing" @@ -58,6 +59,41 @@ func TestConn_Close_BUSY(t *testing.T) { } } +func TestConn_Pragma(t *testing.T) { + t.Parallel() + + db, err := sqlite3.Open("file::memory:?_pragma=busy_timeout(1000)") + if err != nil { + t.Fatal(err) + } + defer db.Close() + + got, err := db.Pragma("busy_timeout") + if err != nil { + t.Fatal(err) + } + want := []string{"1000"} + + if !reflect.DeepEqual(got, want) { + t.Errorf("got %v, want %v", got, want) + } + + var serr *sqlite3.Error + _, err = db.Pragma("+") + if err == nil { + t.Error("want: error") + } + if !errors.As(err, &serr) { + t.Fatalf("got %T, want sqlite3.Error", err) + } + if rc := serr.Code(); rc != sqlite3.ERROR { + t.Errorf("got %d, want sqlite3.ERROR", rc) + } + if got := err.Error(); got != `sqlite3: SQL logic error: near "+": syntax error` { + t.Error("got message:", got) + } +} + func TestConn_SetInterrupt(t *testing.T) { t.Parallel()