From 254d47354695c420abc0686a436ad59a8a000b49 Mon Sep 17 00:00:00 2001 From: Nuno Cruces Date: Thu, 25 May 2023 15:46:15 +0100 Subject: [PATCH] VFS URI parameters. --- embed/exports.txt | 10 ++++--- internal/util/bool.go | 22 ++++++++++++++ internal/util/bool_test.go | 28 +++++++++++++++++ module.go | 12 ++++---- sqlite3vfs/api.go | 11 +++++++ sqlite3vfs/vfs.go | 61 +++++++++++++++++++++++++++++++++++++- 6 files changed, 133 insertions(+), 11 deletions(-) create mode 100644 internal/util/bool.go create mode 100644 internal/util/bool_test.go diff --git a/embed/exports.txt b/embed/exports.txt index 82974bf..2a07fc0 100644 --- a/embed/exports.txt +++ b/embed/exports.txt @@ -37,11 +37,13 @@ sqlite3_blob_bytes sqlite3_blob_read sqlite3_blob_write sqlite3_blob_reopen -sqlite3_get_autocommit -sqlite3_last_insert_rowid -sqlite3_changes64 sqlite3_backup_init sqlite3_backup_step sqlite3_backup_finish sqlite3_backup_remaining -sqlite3_backup_pagecount \ No newline at end of file +sqlite3_backup_pagecount +sqlite3_uri_parameter +sqlite3_uri_key +sqlite3_changes64 +sqlite3_last_insert_rowid +sqlite3_get_autocommit \ No newline at end of file diff --git a/internal/util/bool.go b/internal/util/bool.go new file mode 100644 index 0000000..8427f30 --- /dev/null +++ b/internal/util/bool.go @@ -0,0 +1,22 @@ +package util + +import "strings" + +func ParseBool(s string) (b, ok bool) { + if len(s) == 0 { + return false, false + } + if s[0] == '0' { + return false, true + } + if '1' <= s[0] && s[0] <= '9' { + return true, true + } + switch strings.ToLower(s) { + case "true", "yes", "on": + return true, true + case "false", "no", "off": + return false, true + } + return false, false +} diff --git a/internal/util/bool_test.go b/internal/util/bool_test.go new file mode 100644 index 0000000..b89b2e8 --- /dev/null +++ b/internal/util/bool_test.go @@ -0,0 +1,28 @@ +package util + +import "testing" + +func TestParseBool(t *testing.T) { + tests := []struct { + str string + val bool + ok bool + }{ + {"", false, false}, + {"0", false, true}, + {"1", true, true}, + {"9", true, true}, + {"T", false, false}, + {"true", true, true}, + {"FALSE", false, true}, + {"false?", false, false}, + } + for _, tt := range tests { + t.Run(tt.str, func(t *testing.T) { + gotVal, gotOK := ParseBool(tt.str) + if gotVal != tt.val || gotOK != tt.ok { + t.Errorf("ParseBool(%q) = (%v, %v) want (%v, %v)", tt.str, gotVal, gotOK, tt.val, tt.ok) + } + }) + } +} diff --git a/module.go b/module.go index 5871247..70d6270 100644 --- a/module.go +++ b/module.go @@ -139,9 +139,6 @@ func newModule(mod api.Module) (m *module, err error) { columnText: getFun("sqlite3_column_text"), columnBlob: getFun("sqlite3_column_blob"), columnBytes: getFun("sqlite3_column_bytes"), - autocommit: getFun("sqlite3_get_autocommit"), - lastRowid: getFun("sqlite3_last_insert_rowid"), - changes: getFun("sqlite3_changes64"), blobOpen: getFun("sqlite3_blob_open"), blobClose: getFun("sqlite3_blob_close"), blobReopen: getFun("sqlite3_blob_reopen"), @@ -153,6 +150,9 @@ func newModule(mod api.Module) (m *module, err error) { backupFinish: getFun("sqlite3_backup_finish"), backupRemaining: getFun("sqlite3_backup_remaining"), backupPageCount: getFun("sqlite3_backup_pagecount"), + changes: getFun("sqlite3_changes64"), + lastRowid: getFun("sqlite3_last_insert_rowid"), + autocommit: getFun("sqlite3_get_autocommit"), } if err != nil { return nil, err @@ -334,9 +334,6 @@ type sqliteAPI struct { columnText api.Function columnBlob api.Function columnBytes api.Function - autocommit api.Function - lastRowid api.Function - changes api.Function blobOpen api.Function blobClose api.Function blobReopen api.Function @@ -348,5 +345,8 @@ type sqliteAPI struct { backupFinish api.Function backupRemaining api.Function backupPageCount api.Function + changes api.Function + lastRowid api.Function + autocommit api.Function destructor uint32 } diff --git a/sqlite3vfs/api.go b/sqlite3vfs/api.go index 7bdc3b9..265384c 100644 --- a/sqlite3vfs/api.go +++ b/sqlite3vfs/api.go @@ -1,6 +1,8 @@ // Package sqlite3vfs wraps the C SQLite VFS API. package sqlite3vfs +import "net/url" + // A VFS defines the interface between the SQLite core and the underlying operating system. // // Use sqlite3.ErrorCode or sqlite3.ExtendedErrorCode to return specific error codes to SQLite. @@ -13,6 +15,15 @@ type VFS interface { FullPathname(name string) (string, error) } +// VFSParams extends [VFS] to with the ability to handle URI parameters +// through the OpenParams method. +// +// https://www.sqlite.org/c3ref/uri_boolean.html +type VFSParams interface { + VFS + OpenParams(name string, flags OpenFlag, params url.Values) (File, OpenFlag, error) +} + // A File represents an open file in the OS interface layer. // // Use sqlite3.ErrorCode or sqlite3.ExtendedErrorCode to return specific error codes to SQLite. diff --git a/sqlite3vfs/vfs.go b/sqlite3vfs/vfs.go index 5e79aef..b568957 100644 --- a/sqlite3vfs/vfs.go +++ b/sqlite3vfs/vfs.go @@ -4,6 +4,7 @@ import ( "context" "crypto/rand" "io" + "net/url" "reflect" "time" @@ -172,7 +173,25 @@ func vfsOpen(ctx context.Context, mod api.Module, pVfs, zPath, pFile uint32, fla path = util.ReadString(mod, zPath, _MAX_PATHNAME) } - file, flags, err := vfs.Open(path, flags) + var file File + var err error + var params url.Values + if pfs, ok := vfs.(VFSParams); ok { + params = vfsURIParameters(ctx, mod, zPath, flags) + file, flags, err = pfs.OpenParams(path, flags, params) + } else { + file, flags, err = vfs.Open(path, flags) + } + + if file, ok := file.(FilePowersafeOverwrite); ok { + if params == nil { + params = vfsURIParameters(ctx, mod, zPath, flags) + } + if b, ok := util.ParseBool(params.Get("psow")); ok { + file.SetPowersafeOverwrite(b) + } + } + if err != nil { return vfsErrorCode(err, _CANTOPEN) } @@ -339,6 +358,46 @@ func vfsDeviceCharacteristics(ctx context.Context, mod api.Module, pFile uint32) return file.DeviceCharacteristics() } +func vfsURIParameters(ctx context.Context, mod api.Module, zPath uint32, flags OpenFlag) url.Values { + if flags&OPEN_URI == 0 { + return nil + } + + uriParam := mod.ExportedFunction("sqlite3_uri_parameter") + uriKey := mod.ExportedFunction("sqlite3_uri_key") + if uriParam == nil || uriKey == nil { + return nil + } + + var stack [2]uint64 + var params url.Values + + for i := 0; ; i++ { + stack[1] = uint64(i) + stack[0] = uint64(zPath) + if err := uriKey.CallWithStack(ctx, stack[:]); err != nil { + panic(err) + } + if stack[0] == 0 { + return params + } + key := util.ReadString(mod, uint32(stack[0]), _MAX_STRING) + if params.Has(key) { + continue + } + + stack[1] = stack[0] + stack[0] = uint64(zPath) + if err := uriParam.CallWithStack(ctx, stack[:]); err != nil { + panic(err) + } + if params == nil { + params = url.Values{} + } + params.Set(key, util.ReadString(mod, uint32(stack[0]), _MAX_STRING)) + } +} + func vfsGet(mod api.Module, pVfs uint32) VFS { if pVfs == 0 { return vfsOS{}