mirror of
https://github.com/ncruces/go-sqlite3.git
synced 2026-01-12 05:59:14 +00:00
Regular expression extension. (#114)
This commit is contained in:
77
ext/regexp/regexp.go
Normal file
77
ext/regexp/regexp.go
Normal file
@@ -0,0 +1,77 @@
|
||||
// Package regexp provides additional regular expression functions.
|
||||
//
|
||||
// It provides the following Unicode aware functions:
|
||||
// - regexp_like(),
|
||||
// - regexp_substr(),
|
||||
// - regexp_replace(),
|
||||
// - and a REGEXP operator.
|
||||
//
|
||||
// The implementation uses Go [regexp/syntax] for regular expressions.
|
||||
//
|
||||
// https://github.com/nalgeon/sqlean/blob/main/docs/regexp.md
|
||||
package regexp
|
||||
|
||||
import (
|
||||
"regexp"
|
||||
|
||||
"github.com/ncruces/go-sqlite3"
|
||||
)
|
||||
|
||||
// Register registers Unicode aware functions for a database connection.
|
||||
func Register(db *sqlite3.Conn) {
|
||||
flags := sqlite3.DETERMINISTIC | sqlite3.INNOCUOUS
|
||||
|
||||
db.CreateFunction("regexp", 2, flags, regex)
|
||||
db.CreateFunction("regexp_like", 2, flags, regexLike)
|
||||
db.CreateFunction("regexp_substr", 2, flags, regexSubstr)
|
||||
db.CreateFunction("regexp_replace", 3, flags, regexReplace)
|
||||
}
|
||||
|
||||
func load(ctx sqlite3.Context, i int, expr string) (*regexp.Regexp, error) {
|
||||
re, ok := ctx.GetAuxData(i).(*regexp.Regexp)
|
||||
if !ok {
|
||||
r, err := regexp.Compile(expr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
re = r
|
||||
ctx.SetAuxData(0, r)
|
||||
}
|
||||
return re, nil
|
||||
}
|
||||
|
||||
func regex(ctx sqlite3.Context, arg ...sqlite3.Value) {
|
||||
re, err := load(ctx, 0, arg[0].Text())
|
||||
if err != nil {
|
||||
ctx.ResultError(err)
|
||||
} else {
|
||||
ctx.ResultBool(re.Match(arg[1].RawText()))
|
||||
}
|
||||
}
|
||||
|
||||
func regexLike(ctx sqlite3.Context, arg ...sqlite3.Value) {
|
||||
re, err := load(ctx, 1, arg[1].Text())
|
||||
if err != nil {
|
||||
ctx.ResultError(err)
|
||||
} else {
|
||||
ctx.ResultBool(re.Match(arg[0].RawText()))
|
||||
}
|
||||
}
|
||||
|
||||
func regexSubstr(ctx sqlite3.Context, arg ...sqlite3.Value) {
|
||||
re, err := load(ctx, 1, arg[1].Text())
|
||||
if err != nil {
|
||||
ctx.ResultError(err)
|
||||
} else {
|
||||
ctx.ResultRawText(re.Find(arg[0].RawText()))
|
||||
}
|
||||
}
|
||||
|
||||
func regexReplace(ctx sqlite3.Context, arg ...sqlite3.Value) {
|
||||
re, err := load(ctx, 1, arg[1].Text())
|
||||
if err != nil {
|
||||
ctx.ResultError(err)
|
||||
} else {
|
||||
ctx.ResultRawText(re.ReplaceAll(arg[0].RawText(), arg[2].RawText()))
|
||||
}
|
||||
}
|
||||
75
ext/regexp/regexp_test.go
Normal file
75
ext/regexp/regexp_test.go
Normal file
@@ -0,0 +1,75 @@
|
||||
package regexp
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/ncruces/go-sqlite3"
|
||||
"github.com/ncruces/go-sqlite3/driver"
|
||||
_ "github.com/ncruces/go-sqlite3/embed"
|
||||
_ "github.com/ncruces/go-sqlite3/internal/testcfg"
|
||||
)
|
||||
|
||||
func TestRegister(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, err := driver.Open(":memory:", func(conn *sqlite3.Conn) error {
|
||||
Register(conn)
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
tests := []struct {
|
||||
test string
|
||||
want string
|
||||
}{
|
||||
{`'Hello' REGEXP 'elo'`, "0"},
|
||||
{`'Hello' REGEXP 'ell'`, "1"},
|
||||
{`'Hello' REGEXP 'el.'`, "1"},
|
||||
{`regexp_like('Hello', 'elo')`, "0"},
|
||||
{`regexp_like('Hello', 'ell')`, "1"},
|
||||
{`regexp_like('Hello', 'el.')`, "1"},
|
||||
{`regexp_substr('Hello', 'el.')`, "ell"},
|
||||
{`regexp_replace('Hello', 'llo', 'll')`, "Hell"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
var got string
|
||||
err := db.QueryRow(`SELECT ` + tt.test).Scan(&got)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if got != tt.want {
|
||||
t.Errorf("got %q, want %q", got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegister_errors(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, err := driver.Open(":memory:", func(conn *sqlite3.Conn) error {
|
||||
Register(conn)
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
tests := []string{
|
||||
`'' REGEXP ?`,
|
||||
`regexp_like('', ?)`,
|
||||
`regexp_substr('', ?)`,
|
||||
`regexp_replace('', ?, '')`,
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
err := db.QueryRow(`SELECT `+tt, `\`).Scan(nil)
|
||||
if err == nil {
|
||||
t.Fatal("want error")
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -111,7 +111,7 @@ func regex(ctx sqlite3.Context, arg ...sqlite3.Value) {
|
||||
return
|
||||
}
|
||||
re = r
|
||||
ctx.SetAuxData(0, re)
|
||||
ctx.SetAuxData(0, r)
|
||||
}
|
||||
ctx.ResultBool(re.Match(arg[1].RawText()))
|
||||
}
|
||||
|
||||
@@ -10,6 +10,8 @@ import (
|
||||
)
|
||||
|
||||
func Test_generate(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, err := driver.Open(":memory:", func(conn *sqlite3.Conn) error {
|
||||
Register(conn)
|
||||
return nil
|
||||
@@ -130,6 +132,8 @@ func Test_generate(t *testing.T) {
|
||||
}
|
||||
|
||||
func Test_convert(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, err := driver.Open(":memory:", func(conn *sqlite3.Conn) error {
|
||||
Register(conn)
|
||||
return nil
|
||||
|
||||
Reference in New Issue
Block a user