Unknown collations.

This commit is contained in:
Nuno Cruces
2023-07-04 11:16:29 +01:00
parent 78ac2386f6
commit d3730341f0
4 changed files with 72 additions and 4 deletions

10
func.go
View File

@@ -8,6 +8,16 @@ import (
"github.com/tetratelabs/wazero/api"
)
// AnyCollationNeeded registers a fake collating function
// for any unknown collating sequence.
// The fake collating function works like BINARY.
//
// This extension can be used to load schemas that contain
// one or more unknown collating sequences.
func (c *Conn) AnyCollationNeeded() {
c.call(c.api.anyCollation, uint64(c.handle), 0, 0)
}
// CreateCollation defines a new collating sequence.
//
// https://www.sqlite.org/c3ref/create_collation.html

View File

@@ -8,6 +8,7 @@ import (
type handleKey struct{}
type handleState struct {
handles []any
empty int
}
func NewContext(ctx context.Context) (context.Context, io.Closer) {
@@ -24,6 +25,7 @@ func (s *handleState) Close() (err error) {
}
}
s.handles = nil
s.empty = 0
return err
}
@@ -42,6 +44,7 @@ func DelHandle(ctx context.Context, id uint32) error {
s := ctx.Value(handleKey{}).(*handleState)
a := s.handles[^id]
s.handles[^id] = nil
s.empty++
if c, ok := a.(io.Closer); ok {
return c.Close()
}
@@ -55,10 +58,13 @@ func AddHandle(ctx context.Context, a any) (id uint32) {
s := ctx.Value(handleKey{}).(*handleState)
// Find an empty slot.
for id, h := range s.handles {
if h == nil {
s.handles[id] = a
return ^uint32(id)
if s.empty > cap(s.handles)-len(s.handles) {
for id, h := range s.handles {
if h == nil {
s.empty--
s.handles[id] = a
return ^uint32(id)
}
}
}

View File

@@ -158,6 +158,7 @@ func newSQLite(mod api.Module) (sqlt *sqlite, err error) {
changes: getFun("sqlite3_changes64"),
lastRowid: getFun("sqlite3_last_insert_rowid"),
autocommit: getFun("sqlite3_get_autocommit"),
anyCollation: getFun("sqlite3_anycollseq_init"),
createCollation: getFun("sqlite3_create_collation_go"),
createFunction: getFun("sqlite3_create_function_go"),
createAggregate: getFun("sqlite3_create_aggregate_function_go"),
@@ -377,6 +378,7 @@ type sqliteAPI struct {
changes api.Function
lastRowid api.Function
autocommit api.Function
anyCollation api.Function
createCollation api.Function
createFunction api.Function
createAggregate api.Function

View File

@@ -136,3 +136,53 @@ func TestCreateFunction(t *testing.T) {
t.Errorf("got %v, want sqlite3.FULL", err)
}
}
func TestAnyCollationNeeded(t *testing.T) {
t.Parallel()
db, err := sqlite3.Open(":memory:")
if err != nil {
t.Fatal(err)
}
defer db.Close()
err = db.Exec(`CREATE TABLE IF NOT EXISTS users (id INT, name VARCHAR(10))`)
if err != nil {
t.Fatal(err)
}
err = db.Exec(`INSERT INTO users (id, name) VALUES (0, 'go'), (1, 'zig'), (2, 'whatever')`)
if err != nil {
t.Fatal(err)
}
db.AnyCollationNeeded()
stmt, _, err := db.Prepare(`SELECT id, name FROM users ORDER BY name COLLATE silly`)
if err != nil {
t.Fatal(err)
}
defer stmt.Close()
row := 0
ids := []int{0, 2, 1}
names := []string{"go", "whatever", "zig"}
for ; stmt.Step(); row++ {
id := stmt.ColumnInt(0)
name := stmt.ColumnText(1)
if id != ids[row] {
t.Errorf("got %d, want %d", id, ids[row])
}
if name != names[row] {
t.Errorf("got %q, want %q", name, names[row])
}
}
if row != 3 {
t.Errorf("got %d, want %d", row, len(ids))
}
if err := stmt.Err(); err != nil {
t.Fatal(err)
}
}