This commit is contained in:
Nuno Cruces
2025-03-10 14:54:34 +00:00
parent 9e7a0a875d
commit 1ed954e96f
3 changed files with 53 additions and 32 deletions

View File

@@ -111,22 +111,24 @@ loop2:
return glob.String()
}
func load(ctx sqlite3.Context, i int, expr string) (*regexp.Regexp, error) {
func load(ctx sqlite3.Context, arg []sqlite3.Value, i int) (*regexp.Regexp, error) {
re, ok := ctx.GetAuxData(i).(*regexp.Regexp)
if !ok {
r, err := regexp.Compile(expr)
if err != nil {
return nil, err
re, ok = arg[i].Pointer().(*regexp.Regexp)
if !ok {
r, err := regexp.Compile(arg[i].Text())
if err != nil {
return nil, err
}
re = r
}
re = r
ctx.SetAuxData(0, r)
ctx.SetAuxData(i, re)
}
return re, nil
}
func regex(ctx sqlite3.Context, arg ...sqlite3.Value) {
_ = arg[1] // bounds check
re, err := load(ctx, 0, arg[0].Text())
re, err := load(ctx, arg, 0)
if err != nil {
ctx.ResultError(err)
return // notest
@@ -136,18 +138,17 @@ func regex(ctx sqlite3.Context, arg ...sqlite3.Value) {
}
func regexLike(ctx sqlite3.Context, arg ...sqlite3.Value) {
re, err := load(ctx, 1, arg[1].Text())
re, err := load(ctx, arg, 1)
if err != nil {
ctx.ResultError(err)
return // notest
}
text := arg[0].RawText()
ctx.ResultBool(re.Match(text))
}
func regexCount(ctx sqlite3.Context, arg ...sqlite3.Value) {
re, err := load(ctx, 1, arg[1].Text())
re, err := load(ctx, arg, 1)
if err != nil {
ctx.ResultError(err)
return // notest
@@ -162,7 +163,7 @@ func regexCount(ctx sqlite3.Context, arg ...sqlite3.Value) {
}
func regexSubstr(ctx sqlite3.Context, arg ...sqlite3.Value) {
re, err := load(ctx, 1, arg[1].Text())
re, err := load(ctx, arg, 1)
if err != nil {
ctx.ResultError(err)
return // notest
@@ -187,7 +188,7 @@ func regexSubstr(ctx sqlite3.Context, arg ...sqlite3.Value) {
}
func regexInstr(ctx sqlite3.Context, arg ...sqlite3.Value) {
re, err := load(ctx, 1, arg[1].Text())
re, err := load(ctx, arg, 1)
if err != nil {
ctx.ResultError(err)
return // notest
@@ -215,16 +216,14 @@ func regexInstr(ctx sqlite3.Context, arg ...sqlite3.Value) {
}
func regexReplace(ctx sqlite3.Context, arg ...sqlite3.Value) {
_ = arg[2] // bounds check
re, err := load(ctx, 1, arg[1].Text())
re, err := load(ctx, arg, 1)
if err != nil {
ctx.ResultError(err)
return // notest
}
text := arg[0].RawText()
repl := arg[2].RawText()
text := arg[0].RawText()
var pos, n int
if len(arg) > 3 {
pos = arg[3].Int()

View File

@@ -6,6 +6,7 @@ import (
"strings"
"testing"
"github.com/ncruces/go-sqlite3"
"github.com/ncruces/go-sqlite3/driver"
_ "github.com/ncruces/go-sqlite3/embed"
_ "github.com/ncruces/go-sqlite3/internal/testcfg"
@@ -104,6 +105,27 @@ func TestRegister_errors(t *testing.T) {
}
}
func TestRegister_pointer(t *testing.T) {
t.Parallel()
tmp := memdb.TestDB(t)
db, err := driver.Open(tmp, Register)
if err != nil {
t.Fatal(err)
}
defer db.Close()
var got int
err = db.QueryRow(`SELECT regexp_count('ABCABCAXYaxy', ?, 1)`,
sqlite3.Pointer(regexp.MustCompile(`(?i)A.`))).Scan(&got)
if err != nil {
t.Fatal(err)
}
if got != 4 {
t.Errorf("got %d, want %d", got, 4)
}
}
func TestGlobPrefix(t *testing.T) {
tests := []struct {
re string

View File

@@ -113,9 +113,8 @@ func upper(ctx sqlite3.Context, arg ...sqlite3.Value) {
ctx.ResultError(err)
return // notest
}
c := cases.Upper(t)
ctx.SetAuxData(1, c)
cs = c
cs = cases.Upper(t)
ctx.SetAuxData(1, cs)
}
ctx.ResultRawText(cs.Bytes(arg[0].RawText()))
}
@@ -132,9 +131,8 @@ func lower(ctx sqlite3.Context, arg ...sqlite3.Value) {
ctx.ResultError(err)
return // notest
}
c := cases.Lower(t)
ctx.SetAuxData(1, c)
cs = c
cs = cases.Lower(t)
ctx.SetAuxData(1, cs)
}
ctx.ResultRawText(cs.Bytes(arg[0].RawText()))
}
@@ -151,9 +149,8 @@ func initcap(ctx sqlite3.Context, arg ...sqlite3.Value) {
ctx.ResultError(err)
return // notest
}
c := cases.Title(t)
ctx.SetAuxData(1, c)
cs = c
cs = cases.Title(t)
ctx.SetAuxData(1, cs)
}
ctx.ResultRawText(cs.Bytes(arg[0].RawText()))
}
@@ -200,13 +197,16 @@ func normalize(ctx sqlite3.Context, arg ...sqlite3.Value) {
func regex(ctx sqlite3.Context, arg ...sqlite3.Value) {
re, ok := ctx.GetAuxData(0).(*regexp.Regexp)
if !ok {
r, err := regexp.Compile(arg[0].Text())
if err != nil {
ctx.ResultError(err)
return // notest
re, ok = arg[0].Pointer().(*regexp.Regexp)
if !ok {
r, err := regexp.Compile(arg[0].Text())
if err != nil {
ctx.ResultError(err)
return // notest
}
re = r
}
re = r
ctx.SetAuxData(0, r)
ctx.SetAuxData(0, re)
}
ctx.ResultBool(re.Match(arg[1].RawText()))
}