From 167025f47a6dbe380b37c2724d6d22723a967b4e Mon Sep 17 00:00:00 2001 From: Nuno Cruces Date: Fri, 27 Sep 2024 15:45:51 +0100 Subject: [PATCH] On demand collations. --- ext/unicode/unicode.go | 9 +++++++ ext/unicode/unicode_test.go | 53 ++++++++++++++++++++++++++++++++++++- func.go | 6 ++++- 3 files changed, 66 insertions(+), 2 deletions(-) diff --git a/ext/unicode/unicode.go b/ext/unicode/unicode.go index 1d1861b..0aabd88 100644 --- a/ext/unicode/unicode.go +++ b/ext/unicode/unicode.go @@ -65,6 +65,15 @@ func RegisterCollation(db *sqlite3.Conn, locale, name string) error { return db.CreateCollation(name, collate.New(tag).Compare) } +// RegisterCollationsNeeded registers Unicode collation sequences on demand for a database connection. +func RegisterCollationsNeeded(db *sqlite3.Conn) error { + return db.CollationNeeded(func(db *sqlite3.Conn, name string) { + if tag, err := language.Parse(name); err == nil { + db.CreateCollation(name, collate.New(tag).Compare) + } + }) +} + func upper(ctx sqlite3.Context, arg ...sqlite3.Value) { if len(arg) == 1 { ctx.ResultRawText(bytes.ToUpper(arg[0].RawText())) diff --git a/ext/unicode/unicode_test.go b/ext/unicode/unicode_test.go index 172b02a..3eb585a 100644 --- a/ext/unicode/unicode_test.go +++ b/ext/unicode/unicode_test.go @@ -92,7 +92,7 @@ func TestRegister_collation(t *testing.T) { t.Fatal(err) } - err = db.Exec(`SELECT icu_load_collation('fr_FR', 'french')`) + err = db.Exec(`SELECT icu_load_collation('fr-FR', 'french')`) if err != nil { t.Fatal(err) } @@ -127,6 +127,57 @@ func TestRegister_collation(t *testing.T) { } } +func TestRegisterCollationsNeeded(t *testing.T) { + t.Parallel() + + db, err := sqlite3.Open(":memory:") + if err != nil { + t.Fatal(err) + } + defer db.Close() + + RegisterCollationsNeeded(db) + + err = db.Exec(`CREATE TABLE words (word VARCHAR(10))`) + if err != nil { + t.Fatal(err) + } + + err = db.Exec(`INSERT INTO words (word) VALUES ('côte'), ('cote'), ('coter'), ('coté'), ('cotée'), ('côté')`) + if err != nil { + t.Fatal(err) + } + + stmt, _, err := db.Prepare(`SELECT word FROM words ORDER BY word COLLATE fr_FR`) + if err != nil { + t.Fatal(err) + } + defer stmt.Close() + + got, want := []string{}, []string{"cote", "coté", "côte", "côté", "cotée", "coter"} + + for stmt.Step() { + got = append(got, stmt.ColumnText(0)) + } + if err := stmt.Err(); err != nil { + t.Fatal(err) + } + + if !reflect.DeepEqual(got, want) { + t.Error("not equal") + } + + err = stmt.Close() + if err != nil { + t.Fatal(err) + } + + err = db.Close() + if err != nil { + t.Fatal(err) + } +} + func TestRegister_error(t *testing.T) { t.Parallel() diff --git a/func.go b/func.go index ab486e7..4eac249 100644 --- a/func.go +++ b/func.go @@ -33,7 +33,11 @@ func (c *Conn) CollationNeeded(cb func(db *Conn, name string)) error { // one or more unknown collating sequences. func (c Conn) AnyCollationNeeded() error { r := c.call("sqlite3_anycollseq_init", uint64(c.handle), 0, 0) - return c.error(r) + if err := c.error(r); err != nil { + return err + } + c.collation = nil + return nil } // CreateCollation defines a new collating sequence.