Filename API (#82)

Also remove VFSParams.
This commit is contained in:
Nuno Cruces
2024-04-23 11:43:14 +01:00
committed by GitHub
parent 7f6446ad31
commit 3fb0eeec51
9 changed files with 277 additions and 121 deletions

15
conn.go
View File

@@ -10,6 +10,7 @@ import (
"time"
"github.com/ncruces/go-sqlite3/internal/util"
"github.com/ncruces/go-sqlite3/vfs"
"github.com/tetratelabs/wazero/api"
)
@@ -216,6 +217,20 @@ func (c *Conn) DBName(n int) string {
return util.ReadString(c.mod, ptr, _MAX_NAME)
}
// Filename returns the filename for a database.
//
// https://sqlite.org/c3ref/db_filename.html
func (c *Conn) Filename(schema string) *vfs.Filename {
var ptr uint32
if schema != "" {
defer c.arena.mark()()
ptr = c.arena.string(schema)
}
r := c.call("sqlite3_db_filename", uint64(c.handle), uint64(ptr))
return vfs.OpenFilename(c.ctx, c.mod, uint32(r), vfs.OPEN_MAIN_DB)
}
// ReadOnly determines if a database is read-only.
//
// https://sqlite.org/c3ref/db_readonly.html

View File

@@ -53,6 +53,7 @@ sqlite3_create_module_go
sqlite3_create_window_function_go
sqlite3_database_file_object
sqlite3_db_config
sqlite3_db_filename
sqlite3_db_name
sqlite3_db_readonly
sqlite3_db_release_memory
@@ -62,6 +63,9 @@ sqlite3_errmsg
sqlite3_error_offset
sqlite3_errstr
sqlite3_exec
sqlite3_filename_database
sqlite3_filename_journal
sqlite3_filename_wal
sqlite3_finalize
sqlite3_get_autocommit
sqlite3_get_auxdata

Binary file not shown.

View File

@@ -416,6 +416,48 @@ func TestConn_SetLastInsertRowID(t *testing.T) {
}
}
func TestConn_Filename(t *testing.T) {
t.Parallel()
file := filepath.Join(t.TempDir(), "test.db")
db, err := sqlite3.Open(file)
if err != nil {
t.Fatal(err)
}
defer db.Close()
n := db.Filename("")
if n.String() != file {
t.Errorf("got %v", n)
}
if n.Database() != file {
t.Errorf("got %v", n)
}
if n.DatabaseFile() == nil {
t.Errorf("got %v", n)
}
n = db.Filename("xpto")
if n != nil {
t.Errorf("got %v", n)
}
if n.String() != "" {
t.Errorf("got %v", n)
}
if n.Database() != "" {
t.Errorf("got %v", n)
}
if n.Journal() != "" {
t.Errorf("got %v", n)
}
if n.WAL() != "" {
t.Errorf("got %v", n)
}
if n.DatabaseFile() != nil {
t.Errorf("got %v", n)
}
}
func TestConn_ReadOnly(t *testing.T) {
t.Parallel()

View File

@@ -4,7 +4,6 @@ import (
"encoding/binary"
"encoding/hex"
"io"
"net/url"
"github.com/ncruces/go-sqlite3"
"github.com/ncruces/go-sqlite3/internal/util"
@@ -18,38 +17,38 @@ type hbshVFS struct {
}
func (h *hbshVFS) Open(name string, flags vfs.OpenFlag) (vfs.File, vfs.OpenFlag, error) {
return h.OpenParams(name, flags, nil)
return nil, 0, sqlite3.CANTOPEN
}
func (h *hbshVFS) OpenParams(name string, flags vfs.OpenFlag, params url.Values) (file vfs.File, _ vfs.OpenFlag, err error) {
func (h *hbshVFS) OpenFilename(name *vfs.Filename, flags vfs.OpenFlag) (file vfs.File, _ vfs.OpenFlag, err error) {
var hbsh *hbsh.HBSH
// Encrypt everything except super journals.
if flags&vfs.OPEN_SUPER_JOURNAL == 0 {
var key []byte
if name == "" {
key = h.hbsh.KDF("") // Temporary files get a random key.
} else if t, ok := params["key"]; ok {
key = []byte(t[0])
} else if t, ok := params["hexkey"]; ok {
key, _ = hex.DecodeString(t[0])
} else if t, ok := params["textkey"]; ok {
key = h.hbsh.KDF(t[0])
}
if hbsh = h.hbsh.HBSH(key); hbsh == nil {
// Can't open without a valid key.
return nil, flags, sqlite3.CANTOPEN
if f, ok := name.DatabaseFile().(*hbshFile); ok {
hbsh = f.hbsh
} else {
var key []byte
if params := name.URIParameters(); name == nil {
key = h.hbsh.KDF("") // Temporary files get a random key.
} else if t, ok := params["key"]; ok {
key = []byte(t[0])
} else if t, ok := params["hexkey"]; ok {
key, _ = hex.DecodeString(t[0])
} else if t, ok := params["textkey"]; ok {
key = h.hbsh.KDF(t[0])
}
if hbsh = h.hbsh.HBSH(key); hbsh == nil {
// Can't open without a valid key.
return nil, flags, sqlite3.CANTOPEN
}
}
}
if h, ok := h.VFS.(vfs.VFSParams); ok {
delete(params, "key")
delete(params, "hexkey")
delete(params, "textkey")
file, flags, err = h.OpenParams(name, flags, params)
if h, ok := h.VFS.(vfs.VFSFilename); ok {
file, flags, err = h.OpenFilename(name, flags)
} else {
file, flags, err = h.Open(name, flags)
file, flags, err = h.Open(name.String(), flags)
}
if err != nil || hbsh == nil || flags&vfs.OPEN_MEMORY != 0 {
// Error, or no encryption (super journals, memory files).

View File

@@ -3,7 +3,6 @@ package vfs
import (
"context"
"net/url"
"github.com/tetratelabs/wazero/api"
)
@@ -20,22 +19,13 @@ type VFS interface {
FullPathname(name string) (string, error)
}
// VFSParams extends VFS with the ability to handle URI parameters
// through the OpenParams method.
// VFSFilename extends VFS with the ability to use Filename
// objects for opening files.
//
// https://sqlite.org/c3ref/uri_boolean.html
type VFSParams interface {
// https://sqlite.org/c3ref/filename.html
type VFSFilename interface {
VFS
OpenParams(name string, flags OpenFlag, params url.Values) (File, OpenFlag, error)
}
// VFSJournal extends VFS with the ability to open journals
// that need a reference to their corresponding database files.
//
// https://sqlite.org/c3ref/database_file_object.html
type VFSJournal interface {
VFS
OpenJournal(name string, flags OpenFlag, db File) (File, OpenFlag, error)
OpenFilename(name *Filename, flags OpenFlag) (File, OpenFlag, error)
}
// A File represents an open file in the OS interface layer.

View File

@@ -4,7 +4,6 @@ import (
"errors"
"io"
"io/fs"
"net/url"
"os"
"path/filepath"
"runtime"
@@ -70,10 +69,10 @@ func (vfsOS) Access(name string, flags AccessFlag) (bool, error) {
}
func (vfsOS) Open(name string, flags OpenFlag) (File, OpenFlag, error) {
return vfsOS{}.OpenParams(name, flags, nil)
return nil, 0, _CANTOPEN
}
func (vfsOS) OpenParams(name string, flags OpenFlag, params url.Values) (File, OpenFlag, error) {
func (vfsOS) OpenFilename(name *Filename, flags OpenFlag) (File, OpenFlag, error) {
var oflags int
if flags&OPEN_EXCLUSIVE != 0 {
oflags |= os.O_EXCL
@@ -90,10 +89,10 @@ func (vfsOS) OpenParams(name string, flags OpenFlag, params url.Values) (File, O
var err error
var f *os.File
if name == "" {
if name == nil {
f, err = os.CreateTemp("", "*.db")
} else {
f, err = osutil.OpenFile(name, oflags, 0666)
f, err = osutil.OpenFile(name.String(), oflags, 0666)
}
if err != nil {
if errors.Is(err, syscall.EISDIR) {
@@ -102,7 +101,7 @@ func (vfsOS) OpenParams(name string, flags OpenFlag, params url.Values) (File, O
return nil, flags, err
}
if modeof := params.Get("modeof"); modeof != "" {
if modeof := name.URIParameter("modeof"); modeof != "" {
if err = osSetMode(f, modeof); err != nil {
f.Close()
return nil, flags, _IOERR_FSTAT

166
vfs/filename.go Normal file
View File

@@ -0,0 +1,166 @@
package vfs
import (
"context"
"net/url"
"github.com/ncruces/go-sqlite3/internal/util"
"github.com/tetratelabs/wazero/api"
)
// Filename is used by SQLite to pass filenames
// to the Open method of a VFS.
//
// https://sqlite.org/c3ref/filename.html
type Filename struct {
ctx context.Context
mod api.Module
zPath uint32
flags OpenFlag
stack [2]uint64
}
// OpenFilename is an internal API users should not call directly.
func OpenFilename(ctx context.Context, mod api.Module, id uint32, flags OpenFlag) *Filename {
if id == 0 {
return nil
}
return &Filename{
ctx: ctx,
mod: mod,
zPath: id,
flags: flags,
}
}
// String returns this filename as a string.
func (n *Filename) String() string {
if n == nil || n.zPath == 0 {
return ""
}
return util.ReadString(n.mod, n.zPath, _MAX_PATHNAME)
}
// Database returns the name of the corresponding database file.
//
// https://sqlite.org/c3ref/filename_database.html
func (n *Filename) Database() string {
return n.path("sqlite3_filename_database")
}
// Journal returns the name of the corresponding rollback journal file.
//
// https://sqlite.org/c3ref/filename_database.html
func (n *Filename) Journal() string {
return n.path("sqlite3_filename_journal")
}
// Journal returns the name of the corresponding WAL file.
//
// https://sqlite.org/c3ref/filename_database.html
func (n *Filename) WAL() string {
return n.path("sqlite3_filename_wal")
}
func (n *Filename) path(method string) string {
if n == nil || n.zPath == 0 {
return ""
}
n.stack[0] = uint64(n.zPath)
fn := n.mod.ExportedFunction(method)
if err := fn.CallWithStack(n.ctx, n.stack[:]); err != nil {
panic(err)
}
return util.ReadString(n.mod, uint32(n.stack[0]), _MAX_PATHNAME)
}
// DatabaseFile returns the main database [File] corresponding to a journal.
//
// https://sqlite.org/c3ref/database_file_object.html
func (n *Filename) DatabaseFile() File {
if n == nil || n.zPath == 0 {
return nil
}
if n.flags&(OPEN_MAIN_DB|OPEN_MAIN_JOURNAL|OPEN_WAL) == 0 {
return nil
}
n.stack[0] = uint64(n.zPath)
fn := n.mod.ExportedFunction("sqlite3_database_file_object")
if err := fn.CallWithStack(n.ctx, n.stack[:]); err != nil {
panic(err)
}
file, _ := vfsFileGet(n.ctx, n.mod, uint32(n.stack[0])).(File)
return file
}
// URIParameter returns the value of a URI parameter.
//
// https://sqlite.org/c3ref/uri_boolean.html
func (n *Filename) URIParameter(key string) string {
if n == nil || n.zPath == 0 {
return ""
}
uriKey := n.mod.ExportedFunction("sqlite3_uri_key")
uriParam := n.mod.ExportedFunction("sqlite3_uri_parameter")
for i := 0; ; i++ {
n.stack[1] = uint64(i)
n.stack[0] = uint64(n.zPath)
if err := uriKey.CallWithStack(n.ctx, n.stack[:]); err != nil {
panic(err)
}
if n.stack[0] == 0 {
return ""
}
if key != util.ReadString(n.mod, uint32(n.stack[0]), _MAX_NAME) {
continue
}
n.stack[1] = n.stack[0]
n.stack[0] = uint64(n.zPath)
if err := uriParam.CallWithStack(n.ctx, n.stack[:]); err != nil {
panic(err)
}
return util.ReadString(n.mod, uint32(n.stack[0]), _MAX_NAME)
}
}
// URIParameters obtains values for URI parameters.
//
// https://sqlite.org/c3ref/uri_boolean.html
func (n *Filename) URIParameters() url.Values {
if n == nil || n.zPath == 0 {
return nil
}
var params url.Values
uriKey := n.mod.ExportedFunction("sqlite3_uri_key")
uriParam := n.mod.ExportedFunction("sqlite3_uri_parameter")
for i := 0; ; i++ {
n.stack[1] = uint64(i)
n.stack[0] = uint64(n.zPath)
if err := uriKey.CallWithStack(n.ctx, n.stack[:]); err != nil {
panic(err)
}
if n.stack[0] == 0 {
return params
}
key := util.ReadString(n.mod, uint32(n.stack[0]), _MAX_NAME)
if params.Has(key) {
continue
}
n.stack[1] = n.stack[0]
n.stack[0] = uint64(n.zPath)
if err := uriParam.CallWithStack(n.ctx, n.stack[:]); err != nil {
panic(err)
}
if params == nil {
params = url.Values{}
}
params.Set(key, util.ReadString(n.mod, uint32(n.stack[0]), _MAX_NAME))
}
}

View File

@@ -4,7 +4,6 @@ import (
"context"
"crypto/rand"
"io"
"net/url"
"reflect"
"sync"
"time"
@@ -141,15 +140,9 @@ func vfsOpen(ctx context.Context, mod api.Module, pVfs, zPath, pFile uint32, fla
var file File
var err error
var parsed bool
var params url.Values
if jfs, ok := vfs.(VFSJournal); ok && flags&(OPEN_WAL|OPEN_MAIN_JOURNAL) != 0 {
db := vfsDatabaseFileObject(ctx, mod, zPath)
file, flags, err = jfs.OpenJournal(path, flags, db)
} else if pfs, ok := vfs.(VFSParams); ok {
parsed = true
params = vfsURIParameters(ctx, mod, zPath, flags)
file, flags, err = pfs.OpenParams(path, flags, params)
if ffs, ok := vfs.(VFSFilename); ok {
name := OpenFilename(ctx, mod, zPath, flags)
file, flags, err = ffs.OpenFilename(name, flags)
} else {
file, flags, err = vfs.Open(path, flags)
}
@@ -159,10 +152,8 @@ func vfsOpen(ctx context.Context, mod api.Module, pVfs, zPath, pFile uint32, fla
}
if file, ok := file.(FilePowersafeOverwrite); ok {
if !parsed {
params = vfsURIParameters(ctx, mod, zPath, flags)
}
if b, ok := util.ParseBool(params.Get("psow")); ok {
name := OpenFilename(ctx, mod, zPath, flags)
if b, ok := util.ParseBool(name.URIParameter("psow")); ok {
file.SetPowersafeOverwrite(b)
}
}
@@ -190,7 +181,7 @@ func vfsClose(ctx context.Context, mod api.Module, pFile uint32) _ErrorCode {
}
func vfsRead(ctx context.Context, mod api.Module, pFile, zBuf uint32, iAmt int32, iOfst int64) _ErrorCode {
file := vfsFileGet(ctx, mod, pFile)
file := vfsFileGet(ctx, mod, pFile).(File)
buf := util.View(mod, zBuf, uint64(iAmt))
n, err := file.ReadAt(buf, iOfst)
@@ -205,7 +196,7 @@ func vfsRead(ctx context.Context, mod api.Module, pFile, zBuf uint32, iAmt int32
}
func vfsWrite(ctx context.Context, mod api.Module, pFile, zBuf uint32, iAmt int32, iOfst int64) _ErrorCode {
file := vfsFileGet(ctx, mod, pFile)
file := vfsFileGet(ctx, mod, pFile).(File)
buf := util.View(mod, zBuf, uint64(iAmt))
_, err := file.WriteAt(buf, iOfst)
@@ -216,38 +207,38 @@ func vfsWrite(ctx context.Context, mod api.Module, pFile, zBuf uint32, iAmt int3
}
func vfsTruncate(ctx context.Context, mod api.Module, pFile uint32, nByte int64) _ErrorCode {
file := vfsFileGet(ctx, mod, pFile)
file := vfsFileGet(ctx, mod, pFile).(File)
err := file.Truncate(nByte)
return vfsErrorCode(err, _IOERR_TRUNCATE)
}
func vfsSync(ctx context.Context, mod api.Module, pFile uint32, flags SyncFlag) _ErrorCode {
file := vfsFileGet(ctx, mod, pFile)
file := vfsFileGet(ctx, mod, pFile).(File)
err := file.Sync(flags)
return vfsErrorCode(err, _IOERR_FSYNC)
}
func vfsFileSize(ctx context.Context, mod api.Module, pFile, pSize uint32) _ErrorCode {
file := vfsFileGet(ctx, mod, pFile)
file := vfsFileGet(ctx, mod, pFile).(File)
size, err := file.Size()
util.WriteUint64(mod, pSize, uint64(size))
return vfsErrorCode(err, _IOERR_SEEK)
}
func vfsLock(ctx context.Context, mod api.Module, pFile uint32, eLock LockLevel) _ErrorCode {
file := vfsFileGet(ctx, mod, pFile)
file := vfsFileGet(ctx, mod, pFile).(File)
err := file.Lock(eLock)
return vfsErrorCode(err, _IOERR_LOCK)
}
func vfsUnlock(ctx context.Context, mod api.Module, pFile uint32, eLock LockLevel) _ErrorCode {
file := vfsFileGet(ctx, mod, pFile)
file := vfsFileGet(ctx, mod, pFile).(File)
err := file.Unlock(eLock)
return vfsErrorCode(err, _IOERR_UNLOCK)
}
func vfsCheckReservedLock(ctx context.Context, mod api.Module, pFile, pResOut uint32) _ErrorCode {
file := vfsFileGet(ctx, mod, pFile)
file := vfsFileGet(ctx, mod, pFile).(File)
locked, err := file.CheckReservedLock()
var res uint32
@@ -260,7 +251,7 @@ func vfsCheckReservedLock(ctx context.Context, mod api.Module, pFile, pResOut ui
}
func vfsFileControl(ctx context.Context, mod api.Module, pFile uint32, op _FcntlOpcode, pArg uint32) _ErrorCode {
file := vfsFileGet(ctx, mod, pFile)
file := vfsFileGet(ctx, mod, pFile).(File)
switch op {
case _FCNTL_LOCKSTATE:
@@ -351,12 +342,12 @@ func vfsFileControl(ctx context.Context, mod api.Module, pFile uint32, op _Fcntl
}
func vfsSectorSize(ctx context.Context, mod api.Module, pFile uint32) uint32 {
file := vfsFileGet(ctx, mod, pFile)
file := vfsFileGet(ctx, mod, pFile).(File)
return uint32(file.SectorSize())
}
func vfsDeviceCharacteristics(ctx context.Context, mod api.Module, pFile uint32) DeviceCharacteristic {
file := vfsFileGet(ctx, mod, pFile)
file := vfsFileGet(ctx, mod, pFile).(File)
return file.DeviceCharacteristics()
}
@@ -389,56 +380,6 @@ func vfsShmUnmap(ctx context.Context, mod api.Module, pFile, bDelete uint32) _Er
return _OK
}
func vfsURIParameters(ctx context.Context, mod api.Module, zPath uint32, flags OpenFlag) url.Values {
switch {
case flags&(OPEN_URI|OPEN_MAIN_DB) == OPEN_URI|OPEN_MAIN_DB:
// database file with URI
case flags&(OPEN_WAL|OPEN_MAIN_JOURNAL) != 0:
// journal or WAL file
default:
return nil
}
var stack [2]uint64
var params url.Values
uriKey := mod.ExportedFunction("sqlite3_uri_key")
uriParam := mod.ExportedFunction("sqlite3_uri_parameter")
for i := 0; ; i++ {
stack[1] = uint64(i)
stack[0] = uint64(zPath)
if err := uriKey.CallWithStack(ctx, stack[:]); err != nil {
panic(err)
}
if stack[0] == 0 {
return params
}
key := util.ReadString(mod, uint32(stack[0]), _MAX_NAME)
if params.Has(key) {
continue
}
stack[1] = stack[0]
stack[0] = uint64(zPath)
if err := uriParam.CallWithStack(ctx, stack[:]); err != nil {
panic(err)
}
if params == nil {
params = url.Values{}
}
params.Set(key, util.ReadString(mod, uint32(stack[0]), _MAX_NAME))
}
}
func vfsDatabaseFileObject(ctx context.Context, mod api.Module, zPath uint32) File {
stack := [...]uint64{uint64(zPath)}
fn := mod.ExportedFunction("sqlite3_database_file_object")
if err := fn.CallWithStack(ctx, stack[:]); err != nil {
panic(err)
}
return vfsFileGet(ctx, mod, uint32(stack[0]))
}
func vfsGet(mod api.Module, pVfs uint32) VFS {
var name string
if pVfs != 0 {
@@ -457,10 +398,10 @@ func vfsFileRegister(ctx context.Context, mod api.Module, pFile uint32, file Fil
util.WriteUint32(mod, pFile+fileHandleOffset, id)
}
func vfsFileGet(ctx context.Context, mod api.Module, pFile uint32) File {
func vfsFileGet(ctx context.Context, mod api.Module, pFile uint32) any {
const fileHandleOffset = 4
id := util.ReadUint32(mod, pFile+fileHandleOffset)
return util.GetHandle(ctx, id).(File)
return util.GetHandle(ctx, id)
}
func vfsFileClose(ctx context.Context, mod api.Module, pFile uint32) error {