From ec9533b13f12537f92ff5f44907cc90e874f65f6 Mon Sep 17 00:00:00 2001 From: Nuno Cruces Date: Fri, 2 Jun 2023 02:31:15 +0100 Subject: [PATCH] Implement modeof. --- tests/conn_test.go | 39 +++++++++++++++++++++++++++++++++++++++ vfs/file.go | 11 +++++++++++ vfs/os_unix.go | 13 +++++++++++++ vfs/os_windows.go | 9 +++++++++ vfs/vfs.go | 6 ++++-- 5 files changed, 76 insertions(+), 2 deletions(-) diff --git a/tests/conn_test.go b/tests/conn_test.go index 745f84e..88ae11e 100644 --- a/tests/conn_test.go +++ b/tests/conn_test.go @@ -3,6 +3,8 @@ package tests import ( "context" "errors" + "os" + "path/filepath" "reflect" "strings" "testing" @@ -35,6 +37,43 @@ func TestConn_Open_notfound(t *testing.T) { } } +func TestConn_Open_modeof(t *testing.T) { + t.Parallel() + + dir := t.TempDir() + file := filepath.Join(dir, "test.db") + mode := filepath.Join(dir, "modeof.txt") + + fd, err := os.OpenFile(mode, os.O_CREATE, 0624) + if err != nil { + t.Fatal(err) + } + fi, err := fd.Stat() + if err != nil { + t.Fatal(err) + } + fd.Close() + + db, err := sqlite3.Open("file:" + file + "?modeof=" + mode) + if err != nil { + t.Fatal(err) + } + di, err := os.Stat(file) + if err != nil { + t.Fatal(err) + } + db.Close() + + if di.Mode() != fi.Mode() { + t.Errorf("got %v, want %v", di.Mode(), fi.Mode()) + } + + _, err = sqlite3.Open("file:" + file + "?modeof=" + mode + "2") + if err == nil { + t.Fatal("want error") + } +} + func TestConn_Close(t *testing.T) { var conn *sqlite3.Conn conn.Close() diff --git a/vfs/file.go b/vfs/file.go index 3108296..84d1368 100644 --- a/vfs/file.go +++ b/vfs/file.go @@ -4,6 +4,7 @@ import ( "errors" "io" "io/fs" + "net/url" "os" "path/filepath" "runtime" @@ -68,6 +69,10 @@ func (vfsOS) Access(name string, flags AccessFlag) (bool, error) { } func (vfsOS) Open(name string, flags OpenFlag) (File, OpenFlag, error) { + return vfsOS{}.OpenParams(name, flags, nil) +} + +func (vfsOS) OpenParams(name string, flags OpenFlag, params url.Values) (File, OpenFlag, error) { var oflags int if flags&OPEN_EXCLUSIVE != 0 { oflags |= os.O_EXCL @@ -96,6 +101,12 @@ func (vfsOS) Open(name string, flags OpenFlag) (File, OpenFlag, error) { return nil, flags, err } + if modeof := params.Get("modeof"); modeof != "" { + if err = osSetMode(f, modeof); err != nil { + f.Close() + return nil, flags, _IOERR_FSTAT + } + } if flags&OPEN_DELETEONCLOSE != 0 { os.Remove(f.Name()) } diff --git a/vfs/os_unix.go b/vfs/os_unix.go index 04ab39f..345bbbf 100644 --- a/vfs/os_unix.go +++ b/vfs/os_unix.go @@ -5,6 +5,7 @@ package vfs import ( "io/fs" "os" + "syscall" "time" "golang.org/x/sys/unix" @@ -25,6 +26,18 @@ func osAccess(path string, flags AccessFlag) error { return unix.Access(path, access) } +func osSetMode(file *os.File, modeof string) error { + fi, err := os.Stat(modeof) + if err != nil { + return err + } + file.Chmod(fi.Mode()) + if sys, ok := fi.Sys().(*syscall.Stat_t); ok { + file.Chown(int(sys.Uid), int(sys.Gid)) + } + return nil +} + func osGetSharedLock(file *os.File, timeout time.Duration) _ErrorCode { // Test the PENDING lock before acquiring a new SHARED lock. if pending, _ := osCheckLock(file, _PENDING_BYTE, 1); pending { diff --git a/vfs/os_windows.go b/vfs/os_windows.go index b004d4e..e942916 100644 --- a/vfs/os_windows.go +++ b/vfs/os_windows.go @@ -47,6 +47,15 @@ func osAccess(path string, flags AccessFlag) error { return nil } +func osSetMode(file *os.File, modeof string) error { + fi, err := os.Stat(modeof) + if err != nil { + return err + } + file.Chmod(fi.Mode()) + return nil +} + func osGetSharedLock(file *os.File, timeout time.Duration) _ErrorCode { // Acquire the PENDING lock temporarily before acquiring a new SHARED lock. rc := osReadLock(file, _PENDING_BYTE, 1, timeout) diff --git a/vfs/vfs.go b/vfs/vfs.go index cd432e6..eaa6dfd 100644 --- a/vfs/vfs.go +++ b/vfs/vfs.go @@ -175,8 +175,10 @@ func vfsOpen(ctx context.Context, mod api.Module, pVfs, zPath, pFile uint32, fla var file File var err error + var parsed bool var params url.Values if pfs, ok := vfs.(VFSParams); ok { + parsed = true params = vfsURIParameters(ctx, mod, zPath, flags) file, flags, err = pfs.OpenParams(path, flags, params) } else { @@ -184,7 +186,7 @@ func vfsOpen(ctx context.Context, mod api.Module, pVfs, zPath, pFile uint32, fla } if file, ok := file.(FilePowersafeOverwrite); ok { - if params == nil { + if !parsed { params = vfsURIParameters(ctx, mod, zPath, flags) } if b, ok := util.ParseBool(params.Get("psow")); ok { @@ -338,9 +340,9 @@ func vfsFileControl(ctx context.Context, mod api.Module, pFile uint32, op _Fcntl } // Consider also implementing these opcodes (in use by SQLite): + // _FCNTL_PDB // _FCNTL_BUSYHANDLER // _FCNTL_COMMIT_PHASETWO - // _FCNTL_PDB // _FCNTL_PRAGMA // _FCNTL_SYNC return _NOTFOUND