diff --git a/ext/unicode/unicode.go b/ext/unicode/unicode.go new file mode 100644 index 0000000..113fc23 --- /dev/null +++ b/ext/unicode/unicode.go @@ -0,0 +1,168 @@ +// Package unicode provides a replacement for the SQLite ICU extension. +// +// Provides Unicode aware: +// - upper and lower functions, +// - LIKE and REGEX operators, +// - collation sequences. +package unicode + +import ( + "bytes" + "regexp" + "strings" + "unicode/utf8" + + "github.com/ncruces/go-sqlite3" + "github.com/ncruces/go-sqlite3/internal/util" + "golang.org/x/text/cases" + "golang.org/x/text/collate" + "golang.org/x/text/language" +) + +// Register registers Unicode aware functions for a database connection. +func Register(db sqlite3.Conn) { + flags := sqlite3.DETERMINISTIC | sqlite3.INNOCUOUS + + db.CreateFunction("like", 2, flags, like) + db.CreateFunction("like", 3, flags, like) + db.CreateFunction("upper", 1, flags, upper) + db.CreateFunction("upper", 2, flags, upper) + db.CreateFunction("lower", 1, flags, lower) + db.CreateFunction("lower", 2, flags, lower) + db.CreateFunction("regexp", 2, flags, regex) + db.CreateFunction("icu_load_collation", 2, sqlite3.DIRECTONLY, + func(ctx sqlite3.Context, arg ...sqlite3.Value) { + name := arg[1].Text() + if name == "" { + return + } + + tag, err := language.Parse(arg[0].Text()) + if err != nil { + ctx.ResultError(err) + return + } + + err = db.CreateCollation(name, collate.New(tag).Compare) + if err != nil { + ctx.ResultError(err) + return + } + }) +} + +func upper(ctx sqlite3.Context, arg ...sqlite3.Value) { + if len(arg) == 1 { + ctx.ResultBlob(bytes.ToUpper(arg[0].RawBlob())) + return + } + cs, ok := ctx.GetAuxData(1).(cases.Caser) + if !ok { + t, err := language.Parse(arg[0].Text()) + if err != nil { + ctx.ResultError(err) + return + } + c := cases.Upper(t) + ctx.SetAuxData(1, c) + cs = c + } + ctx.ResultBlob(cs.Bytes(arg[0].RawBlob())) +} + +func lower(ctx sqlite3.Context, arg ...sqlite3.Value) { + if len(arg) == 1 { + ctx.ResultBlob(bytes.ToLower(arg[0].RawBlob())) + return + } + cs, ok := ctx.GetAuxData(1).(cases.Caser) + if !ok { + t, err := language.Parse(arg[0].Text()) + if err != nil { + ctx.ResultError(err) + return + } + c := cases.Lower(t) + ctx.SetAuxData(1, c) + cs = c + } + ctx.ResultBlob(cs.Bytes(arg[0].RawBlob())) +} + +func regex(ctx sqlite3.Context, arg ...sqlite3.Value) { + re, ok := ctx.GetAuxData(0).(*regexp.Regexp) + if !ok { + r, err := regexp.Compile(arg[0].Text()) + if err != nil { + ctx.ResultError(err) + return + } + re = r + ctx.SetAuxData(0, re) + } + ctx.ResultBool(re.Match(arg[1].RawBlob())) +} + +func like(ctx sqlite3.Context, arg ...sqlite3.Value) { + escape := rune(-1) + if len(arg) == 3 { + var size int + b := arg[2].RawBlob() + escape, size = utf8.DecodeRune(b) + if size != len(b) { + ctx.ResultError(util.ErrorString("ESCAPE expression must be a single character")) + return + } + } + + type likeData struct { + *regexp.Regexp + escape rune + } + + re, ok := ctx.GetAuxData(0).(likeData) + if !ok || re.escape != escape { + r, err := regexp.Compile(like2regex(arg[0].Text(), escape)) + if err != nil { + ctx.ResultError(err) + return + } + re = likeData{r, escape} + ctx.SetAuxData(0, re) + } + ctx.ResultBool(re.Match(arg[1].RawBlob())) +} + +func like2regex(pattern string, escape rune) string { + var re strings.Builder + start := 0 + literal := false + re.WriteString(`(?is)`) // case insensitive, . matches any character + for i, r := range pattern { + if start < 0 { + start = i + } + if literal { + literal = false + continue + } + var symbol string + switch r { + case '_': + symbol = `.` + case '%': + symbol = `.*` + case escape: + literal = true + default: + continue + } + re.WriteString(regexp.QuoteMeta(pattern[start:i])) + re.WriteString(symbol) + start = -1 + } + if start >= 0 { + re.WriteString(regexp.QuoteMeta(pattern[start:])) + } + return re.String() +} diff --git a/ext/unicode/unicode_test.go b/ext/unicode/unicode_test.go new file mode 100644 index 0000000..1679c80 --- /dev/null +++ b/ext/unicode/unicode_test.go @@ -0,0 +1,26 @@ +package unicode + +import "testing" + +func Test_like2regex(t *testing.T) { + tests := []struct { + pattern string + escape rune + want string + }{ + {`a`, -1, `(?is)a`}, + {`a.`, -1, `(?is)a\.`}, + {`a%`, -1, `(?is)a.*`}, + {`a\`, -1, `(?is)a\\`}, + {`a_b`, -1, `(?is)a.b`}, + {`a|b`, '|', `(?is)ab`}, + {`a|_`, '|', `(?is)a_`}, + } + for _, tt := range tests { + t.Run(tt.pattern, func(t *testing.T) { + if got := like2regex(tt.pattern, tt.escape); got != tt.want { + t.Errorf("like2regex() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/func_test.go b/func_test.go index e9f6489..4778b46 100644 --- a/func_test.go +++ b/func_test.go @@ -146,7 +146,7 @@ func ExampleContext_SetAuxData() { ctx.SetAuxData(0, r) re = r } - ctx.ResultBool(re.Match(arg[1].RawText())) + ctx.ResultBool(re.Match(arg[1].RawBlob())) }) if err != nil { log.Fatal(err) diff --git a/func_win_test.go b/func_win_test.go index 87c28e2..ef23bff 100644 --- a/func_win_test.go +++ b/func_win_test.go @@ -87,7 +87,7 @@ func (f *countASCII) isASCII(arg sqlite3.Value) bool { if arg.Type() != sqlite3.TEXT { return false } - for _, c := range arg.RawText() { + for _, c := range arg.RawBlob() { if c > unicode.MaxASCII { return false }