diff --git a/func.go b/func.go index d58a144..205943e 100644 --- a/func.go +++ b/func.go @@ -8,6 +8,16 @@ import ( "github.com/tetratelabs/wazero/api" ) +// AnyCollationNeeded registers a fake collating function +// for any unknown collating sequence. +// The fake collating function works like BINARY. +// +// This extension can be used to load schemas that contain +// one or more unknown collating sequences. +func (c *Conn) AnyCollationNeeded() { + c.call(c.api.anyCollation, uint64(c.handle), 0, 0) +} + // CreateCollation defines a new collating sequence. // // https://www.sqlite.org/c3ref/create_collation.html diff --git a/internal/util/handle.go b/internal/util/handle.go index 6aa1dc7..a1e5675 100644 --- a/internal/util/handle.go +++ b/internal/util/handle.go @@ -8,6 +8,7 @@ import ( type handleKey struct{} type handleState struct { handles []any + empty int } func NewContext(ctx context.Context) (context.Context, io.Closer) { @@ -24,6 +25,7 @@ func (s *handleState) Close() (err error) { } } s.handles = nil + s.empty = 0 return err } @@ -42,6 +44,7 @@ func DelHandle(ctx context.Context, id uint32) error { s := ctx.Value(handleKey{}).(*handleState) a := s.handles[^id] s.handles[^id] = nil + s.empty++ if c, ok := a.(io.Closer); ok { return c.Close() } @@ -55,10 +58,13 @@ func AddHandle(ctx context.Context, a any) (id uint32) { s := ctx.Value(handleKey{}).(*handleState) // Find an empty slot. - for id, h := range s.handles { - if h == nil { - s.handles[id] = a - return ^uint32(id) + if s.empty > cap(s.handles)-len(s.handles) { + for id, h := range s.handles { + if h == nil { + s.empty-- + s.handles[id] = a + return ^uint32(id) + } } } diff --git a/sqlite.go b/sqlite.go index e92525b..115adf7 100644 --- a/sqlite.go +++ b/sqlite.go @@ -158,6 +158,7 @@ func newSQLite(mod api.Module) (sqlt *sqlite, err error) { changes: getFun("sqlite3_changes64"), lastRowid: getFun("sqlite3_last_insert_rowid"), autocommit: getFun("sqlite3_get_autocommit"), + anyCollation: getFun("sqlite3_anycollseq_init"), createCollation: getFun("sqlite3_create_collation_go"), createFunction: getFun("sqlite3_create_function_go"), createAggregate: getFun("sqlite3_create_aggregate_function_go"), @@ -377,6 +378,7 @@ type sqliteAPI struct { changes api.Function lastRowid api.Function autocommit api.Function + anyCollation api.Function createCollation api.Function createFunction api.Function createAggregate api.Function diff --git a/tests/func_test.go b/tests/func_test.go index 1e408ef..b2e6768 100644 --- a/tests/func_test.go +++ b/tests/func_test.go @@ -136,3 +136,53 @@ func TestCreateFunction(t *testing.T) { t.Errorf("got %v, want sqlite3.FULL", err) } } + +func TestAnyCollationNeeded(t *testing.T) { + t.Parallel() + + db, err := sqlite3.Open(":memory:") + if err != nil { + t.Fatal(err) + } + defer db.Close() + + err = db.Exec(`CREATE TABLE IF NOT EXISTS users (id INT, name VARCHAR(10))`) + if err != nil { + t.Fatal(err) + } + + err = db.Exec(`INSERT INTO users (id, name) VALUES (0, 'go'), (1, 'zig'), (2, 'whatever')`) + if err != nil { + t.Fatal(err) + } + + db.AnyCollationNeeded() + + stmt, _, err := db.Prepare(`SELECT id, name FROM users ORDER BY name COLLATE silly`) + if err != nil { + t.Fatal(err) + } + defer stmt.Close() + + row := 0 + ids := []int{0, 2, 1} + names := []string{"go", "whatever", "zig"} + for ; stmt.Step(); row++ { + id := stmt.ColumnInt(0) + name := stmt.ColumnText(1) + + if id != ids[row] { + t.Errorf("got %d, want %d", id, ids[row]) + } + if name != names[row] { + t.Errorf("got %q, want %q", name, names[row]) + } + } + if row != 3 { + t.Errorf("got %d, want %d", row, len(ids)) + } + + if err := stmt.Err(); err != nil { + t.Fatal(err) + } +}