diff --git a/ext/fileio/coro.go b/ext/fileio/coro.go new file mode 100644 index 0000000..7ceca3d --- /dev/null +++ b/ext/fileio/coro.go @@ -0,0 +1,68 @@ +package fileio + +import ( + "fmt" + + "github.com/ncruces/go-sqlite3/internal/util" +) + +// Adapted from: https://research.swtch.com/coro + +const errCoroCanceled = util.ErrorString("coroutine canceled") + +func coroNew[In, Out any](f func(In, func(Out) In) Out) (resume func(In) (Out, bool), cancel func()) { + type msg[T any] struct { + panic any + val T + } + + cin := make(chan msg[In]) + cout := make(chan msg[Out]) + running := true + resume = func(in In) (out Out, ok bool) { + if !running { + return + } + cin <- msg[In]{val: in} + m := <-cout + if m.panic != nil { + panic(m.panic) + } + return m.val, running + } + cancel = func() { + if !running { + return + } + e := fmt.Errorf("%w", errCoroCanceled) + cin <- msg[In]{panic: e} + m := <-cout + if m.panic != nil && m.panic != e { + panic(m.panic) + } + } + yield := func(out Out) In { + cout <- msg[Out]{val: out} + m := <-cin + if m.panic != nil { + panic(m.panic) + } + return m.val + } + go func() { + defer func() { + if running { + running = false + cout <- msg[Out]{panic: recover()} + } + }() + var out Out + m := <-cin + if m.panic == nil { + out = f(m.val, yield) + } + running = false + cout <- msg[Out]{val: out} + }() + return resume, cancel +} diff --git a/ext/fileio/fsdir.go b/ext/fileio/fsdir.go index 995da76..4055b29 100644 --- a/ext/fileio/fsdir.go +++ b/ext/fileio/fsdir.go @@ -53,13 +53,12 @@ func (d fsdir) Open() (sqlite3.VTabCursor, error) { type cursor struct { fsdir - curr entry - next chan entry - done chan struct{} - base string - rowID int64 - eof bool - open bool + base string + resume func(struct{}) (entry, bool) + cancel func() + curr entry + eof bool + rowID int64 } type entry struct { @@ -69,11 +68,8 @@ type entry struct { } func (c *cursor) Close() error { - if c.open { - close(c.done) - s := <-c.next - c.open = false - return s.err + if c.cancel != nil { + c.cancel() } return nil } @@ -96,17 +92,25 @@ func (c *cursor) Filter(idxNum int, idxStr string, arg ...sqlite3.Value) error { c.base = base } - c.rowID = 0 + c.resume, c.cancel = coroNew(func(_ struct{}, yield func(entry) struct{}) entry { + walkDir := func(path string, d fs.DirEntry, err error) error { + yield(entry{d, err, path}) + return nil + } + if c.fsys != nil { + fs.WalkDir(c.fsys, root, walkDir) + } else { + filepath.WalkDir(root, walkDir) + } + return entry{} + }) c.eof = false - c.open = true - c.next = make(chan entry) - c.done = make(chan struct{}) - go c.WalkDir(root) + c.rowID = 0 return c.Next() } func (c *cursor) Next() error { - curr, ok := <-c.next + curr, ok := c.resume(struct{}{}) c.curr = curr c.eof = !ok c.rowID++ @@ -166,22 +170,3 @@ func (c *cursor) Column(ctx *sqlite3.Context, n int) error { } return nil } - -func (c *cursor) WalkDir(path string) { - defer close(c.next) - - if c.fsys != nil { - fs.WalkDir(c.fsys, path, c.WalkDirFunc) - } else { - filepath.WalkDir(path, c.WalkDirFunc) - } -} - -func (c *cursor) WalkDirFunc(path string, d fs.DirEntry, err error) error { - select { - case <-c.done: - return fs.SkipAll - case c.next <- entry{d, err, path}: - return nil - } -} diff --git a/ext/fileio/fsdir_test.go b/ext/fileio/fsdir_test.go index 1df5623..e199a34 100644 --- a/ext/fileio/fsdir_test.go +++ b/ext/fileio/fsdir_test.go @@ -28,7 +28,7 @@ func Test_fsdir(t *testing.T) { } defer db.Close() - rows, err := db.Query(`SELECT * FROM fsdir('.', '.') LIMIT 4`) + rows, err := db.Query(`SELECT * FROM fsdir('.', '.')`) if err != nil { t.Fatal(err) }