VFS URI parameters.

This commit is contained in:
Nuno Cruces
2023-05-25 15:46:15 +01:00
parent 5639fc1ff8
commit 254d473546
6 changed files with 133 additions and 11 deletions

View File

@@ -37,11 +37,13 @@ sqlite3_blob_bytes
sqlite3_blob_read
sqlite3_blob_write
sqlite3_blob_reopen
sqlite3_get_autocommit
sqlite3_last_insert_rowid
sqlite3_changes64
sqlite3_backup_init
sqlite3_backup_step
sqlite3_backup_finish
sqlite3_backup_remaining
sqlite3_backup_pagecount
sqlite3_backup_pagecount
sqlite3_uri_parameter
sqlite3_uri_key
sqlite3_changes64
sqlite3_last_insert_rowid
sqlite3_get_autocommit

22
internal/util/bool.go Normal file
View File

@@ -0,0 +1,22 @@
package util
import "strings"
func ParseBool(s string) (b, ok bool) {
if len(s) == 0 {
return false, false
}
if s[0] == '0' {
return false, true
}
if '1' <= s[0] && s[0] <= '9' {
return true, true
}
switch strings.ToLower(s) {
case "true", "yes", "on":
return true, true
case "false", "no", "off":
return false, true
}
return false, false
}

View File

@@ -0,0 +1,28 @@
package util
import "testing"
func TestParseBool(t *testing.T) {
tests := []struct {
str string
val bool
ok bool
}{
{"", false, false},
{"0", false, true},
{"1", true, true},
{"9", true, true},
{"T", false, false},
{"true", true, true},
{"FALSE", false, true},
{"false?", false, false},
}
for _, tt := range tests {
t.Run(tt.str, func(t *testing.T) {
gotVal, gotOK := ParseBool(tt.str)
if gotVal != tt.val || gotOK != tt.ok {
t.Errorf("ParseBool(%q) = (%v, %v) want (%v, %v)", tt.str, gotVal, gotOK, tt.val, tt.ok)
}
})
}
}

View File

@@ -139,9 +139,6 @@ func newModule(mod api.Module) (m *module, err error) {
columnText: getFun("sqlite3_column_text"),
columnBlob: getFun("sqlite3_column_blob"),
columnBytes: getFun("sqlite3_column_bytes"),
autocommit: getFun("sqlite3_get_autocommit"),
lastRowid: getFun("sqlite3_last_insert_rowid"),
changes: getFun("sqlite3_changes64"),
blobOpen: getFun("sqlite3_blob_open"),
blobClose: getFun("sqlite3_blob_close"),
blobReopen: getFun("sqlite3_blob_reopen"),
@@ -153,6 +150,9 @@ func newModule(mod api.Module) (m *module, err error) {
backupFinish: getFun("sqlite3_backup_finish"),
backupRemaining: getFun("sqlite3_backup_remaining"),
backupPageCount: getFun("sqlite3_backup_pagecount"),
changes: getFun("sqlite3_changes64"),
lastRowid: getFun("sqlite3_last_insert_rowid"),
autocommit: getFun("sqlite3_get_autocommit"),
}
if err != nil {
return nil, err
@@ -334,9 +334,6 @@ type sqliteAPI struct {
columnText api.Function
columnBlob api.Function
columnBytes api.Function
autocommit api.Function
lastRowid api.Function
changes api.Function
blobOpen api.Function
blobClose api.Function
blobReopen api.Function
@@ -348,5 +345,8 @@ type sqliteAPI struct {
backupFinish api.Function
backupRemaining api.Function
backupPageCount api.Function
changes api.Function
lastRowid api.Function
autocommit api.Function
destructor uint32
}

View File

@@ -1,6 +1,8 @@
// Package sqlite3vfs wraps the C SQLite VFS API.
package sqlite3vfs
import "net/url"
// A VFS defines the interface between the SQLite core and the underlying operating system.
//
// Use sqlite3.ErrorCode or sqlite3.ExtendedErrorCode to return specific error codes to SQLite.
@@ -13,6 +15,15 @@ type VFS interface {
FullPathname(name string) (string, error)
}
// VFSParams extends [VFS] to with the ability to handle URI parameters
// through the OpenParams method.
//
// https://www.sqlite.org/c3ref/uri_boolean.html
type VFSParams interface {
VFS
OpenParams(name string, flags OpenFlag, params url.Values) (File, OpenFlag, error)
}
// A File represents an open file in the OS interface layer.
//
// Use sqlite3.ErrorCode or sqlite3.ExtendedErrorCode to return specific error codes to SQLite.

View File

@@ -4,6 +4,7 @@ import (
"context"
"crypto/rand"
"io"
"net/url"
"reflect"
"time"
@@ -172,7 +173,25 @@ func vfsOpen(ctx context.Context, mod api.Module, pVfs, zPath, pFile uint32, fla
path = util.ReadString(mod, zPath, _MAX_PATHNAME)
}
file, flags, err := vfs.Open(path, flags)
var file File
var err error
var params url.Values
if pfs, ok := vfs.(VFSParams); ok {
params = vfsURIParameters(ctx, mod, zPath, flags)
file, flags, err = pfs.OpenParams(path, flags, params)
} else {
file, flags, err = vfs.Open(path, flags)
}
if file, ok := file.(FilePowersafeOverwrite); ok {
if params == nil {
params = vfsURIParameters(ctx, mod, zPath, flags)
}
if b, ok := util.ParseBool(params.Get("psow")); ok {
file.SetPowersafeOverwrite(b)
}
}
if err != nil {
return vfsErrorCode(err, _CANTOPEN)
}
@@ -339,6 +358,46 @@ func vfsDeviceCharacteristics(ctx context.Context, mod api.Module, pFile uint32)
return file.DeviceCharacteristics()
}
func vfsURIParameters(ctx context.Context, mod api.Module, zPath uint32, flags OpenFlag) url.Values {
if flags&OPEN_URI == 0 {
return nil
}
uriParam := mod.ExportedFunction("sqlite3_uri_parameter")
uriKey := mod.ExportedFunction("sqlite3_uri_key")
if uriParam == nil || uriKey == nil {
return nil
}
var stack [2]uint64
var params url.Values
for i := 0; ; i++ {
stack[1] = uint64(i)
stack[0] = uint64(zPath)
if err := uriKey.CallWithStack(ctx, stack[:]); err != nil {
panic(err)
}
if stack[0] == 0 {
return params
}
key := util.ReadString(mod, uint32(stack[0]), _MAX_STRING)
if params.Has(key) {
continue
}
stack[1] = stack[0]
stack[0] = uint64(zPath)
if err := uriParam.CallWithStack(ctx, stack[:]); err != nil {
panic(err)
}
if params == nil {
params = url.Values{}
}
params.Set(key, util.ReadString(mod, uint32(stack[0]), _MAX_STRING))
}
}
func vfsGet(mod api.Module, pVfs uint32) VFS {
if pVfs == 0 {
return vfsOS{}