From b645721d1071c5cda24508ea85f32b0c5c74f992 Mon Sep 17 00:00:00 2001 From: Nuno Cruces Date: Mon, 24 Mar 2025 22:38:22 +0000 Subject: [PATCH] IP/CIDR functions. (#246) --- ext/ipaddr/ipaddr.go | 113 ++++++++++++++++++++++++++++++++++++++ ext/ipaddr/ipaddr_test.go | 88 +++++++++++++++++++++++++++++ ext/unicode/unicode.go | 24 ++++---- ext/uuid/uuid.go | 12 ++-- 4 files changed, 219 insertions(+), 18 deletions(-) create mode 100644 ext/ipaddr/ipaddr.go create mode 100644 ext/ipaddr/ipaddr_test.go diff --git a/ext/ipaddr/ipaddr.go b/ext/ipaddr/ipaddr.go new file mode 100644 index 0000000..f45443b --- /dev/null +++ b/ext/ipaddr/ipaddr.go @@ -0,0 +1,113 @@ +// Package ipaddr provides functions to manipulate IPs and CIDRs. +// +// It provides the following functions: +// - ipcontains(prefix, ip) +// - ipoverlaps(prefix1, prefix2) +// - ipfamily(ip/prefix) +// - iphost(ip/prefix) +// - ipmasklen(prefix) +// - ipnetwork(prefix) +package ipaddr + +import ( + "errors" + "net/netip" + + "github.com/ncruces/go-sqlite3" +) + +// Register IP/CIDR functions for a database connection. +func Register(db *sqlite3.Conn) error { + const flags = sqlite3.DETERMINISTIC | sqlite3.INNOCUOUS + return errors.Join( + db.CreateFunction("ipcontains", 2, flags, contains), + db.CreateFunction("ipoverlaps", 2, flags, overlaps), + db.CreateFunction("ipfamily", 1, flags, family), + db.CreateFunction("iphost", 1, flags, host), + db.CreateFunction("ipmasklen", 1, flags, masklen), + db.CreateFunction("ipnetwork", 1, flags, network)) +} + +func contains(ctx sqlite3.Context, arg ...sqlite3.Value) { + prefix, err := netip.ParsePrefix(arg[0].Text()) + if err != nil { + ctx.ResultError(err) + return // notest + } + addr, err := netip.ParseAddr(arg[1].Text()) + if err != nil { + ctx.ResultError(err) + return // notest + } + ctx.ResultBool(prefix.Contains(addr)) +} + +func overlaps(ctx sqlite3.Context, arg ...sqlite3.Value) { + prefix1, err := netip.ParsePrefix(arg[0].Text()) + if err != nil { + ctx.ResultError(err) + return // notest + } + prefix2, err := netip.ParsePrefix(arg[0].Text()) + if err != nil { + ctx.ResultError(err) + return // notest + } + ctx.ResultBool(prefix1.Overlaps(prefix2)) +} + +func family(ctx sqlite3.Context, arg ...sqlite3.Value) { + addr, err := addr(arg[0].Text()) + if err != nil { + ctx.ResultError(err) + return // notest + } + switch { + case addr.Is4(): + ctx.ResultInt(4) + case addr.Is6(): + ctx.ResultInt(6) + } +} + +func host(ctx sqlite3.Context, arg ...sqlite3.Value) { + addr, err := addr(arg[0].Text()) + if err != nil { + ctx.ResultError(err) + return // notest + } + buf, _ := addr.MarshalText() + ctx.ResultRawText(buf) +} + +func masklen(ctx sqlite3.Context, arg ...sqlite3.Value) { + prefix, err := netip.ParsePrefix(arg[0].Text()) + if err != nil { + ctx.ResultError(err) + return // notest + } + ctx.ResultInt(prefix.Bits()) +} + +func network(ctx sqlite3.Context, arg ...sqlite3.Value) { + prefix, err := netip.ParsePrefix(arg[0].Text()) + if err != nil { + ctx.ResultError(err) + return // notest + } + buf, _ := prefix.Masked().MarshalText() + ctx.ResultRawText(buf) +} + +func addr(text string) (netip.Addr, error) { + addr, err := netip.ParseAddr(text) + if err != nil { + if prefix, err := netip.ParsePrefix(text); err == nil { + return prefix.Addr(), nil + } + if addrpt, err := netip.ParseAddrPort(text); err == nil { + return addrpt.Addr(), nil + } + } + return addr, err +} diff --git a/ext/ipaddr/ipaddr_test.go b/ext/ipaddr/ipaddr_test.go new file mode 100644 index 0000000..a93ed24 --- /dev/null +++ b/ext/ipaddr/ipaddr_test.go @@ -0,0 +1,88 @@ +package ipaddr_test + +import ( + "testing" + + "github.com/ncruces/go-sqlite3/driver" + _ "github.com/ncruces/go-sqlite3/embed" + "github.com/ncruces/go-sqlite3/ext/ipaddr" + _ "github.com/ncruces/go-sqlite3/internal/testcfg" + "github.com/ncruces/go-sqlite3/vfs/memdb" +) + +func TestRegister(t *testing.T) { + t.Parallel() + tmp := memdb.TestDB(t) + + db, err := driver.Open(tmp, ipaddr.Register) + if err != nil { + t.Fatal(err) + } + defer db.Close() + + var got string + + err = db.QueryRow(`SELECT ipfamily('::1')`).Scan(&got) + if err != nil { + t.Fatal(err) + } + if got != "6" { + t.Fatalf("got %s", got) + } + + err = db.QueryRow(`SELECT ipfamily('[::1]:80')`).Scan(&got) + if err != nil { + t.Fatal(err) + } + if got != "6" { + t.Fatalf("got %s", got) + } + + err = db.QueryRow(`SELECT ipfamily('192.168.1.5/24')`).Scan(&got) + if err != nil { + t.Fatal(err) + } + if got != "4" { + t.Fatalf("got %s", got) + } + + err = db.QueryRow(`SELECT iphost('192.168.1.5/24')`).Scan(&got) + if err != nil { + t.Fatal(err) + } + if got != "192.168.1.5" { + t.Fatalf("got %s", got) + } + + err = db.QueryRow(`SELECT ipmasklen('192.168.1.5/24')`).Scan(&got) + if err != nil { + t.Fatal(err) + } + if got != "24" { + t.Fatalf("got %s", got) + } + + err = db.QueryRow(`SELECT ipnetwork('192.168.1.5/24')`).Scan(&got) + if err != nil { + t.Fatal(err) + } + if got != "192.168.1.0/24" { + t.Fatalf("got %s", got) + } + + err = db.QueryRow(`SELECT ipcontains('192.168.1.0/24', '192.168.1.5')`).Scan(&got) + if err != nil { + t.Fatal(err) + } + if got != "1" { + t.Fatalf("got %s", got) + } + + err = db.QueryRow(`SELECT ipoverlaps('192.168.1.0/24', '192.168.1.5/32')`).Scan(&got) + if err != nil { + t.Fatal(err) + } + if got != "1" { + t.Fatalf("got %s", got) + } +} diff --git a/ext/unicode/unicode.go b/ext/unicode/unicode.go index bf4b0ca..6abe62e 100644 --- a/ext/unicode/unicode.go +++ b/ext/unicode/unicode.go @@ -1,22 +1,22 @@ // Package unicode provides an alternative to the SQLite ICU extension. // // Like the [ICU extension], it provides Unicode aware: -// - upper() and lower() functions, -// - LIKE and REGEXP operators, -// - collation sequences. +// - upper() and lower() functions +// - LIKE and REGEXP operators +// - collation sequences // // Like PostgreSQL, it also provides: -// - initcap(), -// - casefold(), -// - normalize(), -// - unaccent(). +// - initcap() +// - casefold() +// - normalize() +// - unaccent() // // The implementations are not 100% compatible: -// - upper(), lower(), initcap() casefold() use [strings.ToUpper], [strings.ToLower], [strings.Title] and [cases]; -// - normalize(), unaccent() use [transform] and [unicode.Mn]; -// - the LIKE operator follows [strings.EqualFold] rules; -// - the REGEXP operator uses Go [regexp/syntax]; -// - collation sequences use [collate]. +// - upper(), lower(), initcap() casefold() use [strings.ToUpper], [strings.ToLower], [strings.Title] and [cases] +// - normalize(), unaccent() use [transform] and [unicode.Mn] +// - the LIKE operator follows [strings.EqualFold] rules +// - the REGEXP operator uses Go [regexp/syntax] +// - collation sequences use [collate] // // Expect subtle differences (e.g.) in the handling of Turkish case folding. // diff --git a/ext/uuid/uuid.go b/ext/uuid/uuid.go index a62a039..febf4de 100644 --- a/ext/uuid/uuid.go +++ b/ext/uuid/uuid.go @@ -18,17 +18,17 @@ import ( // Register registers the SQL functions: // // - uuid([ version [, domain/namespace, [ id/data ]]]): -// to generate a UUID as a string, +// to generate a UUID as a string // - uuid_str(u): -// to convert a UUID into a well-formed UUID string, +// to convert a UUID into a well-formed UUID string // - uuid_blob(u): -// to convert a UUID into a 16-byte blob, +// to convert a UUID into a 16-byte blob // - uuid_extract_version(u): -// to extract the version of a RFC 4122 UUID, +// to extract the version of a RFC 4122 UUID // - uuid_extract_timestamp(u): -// to extract the timestamp of a version 1/2/6/7 UUID, +// to extract the timestamp of a version 1/2/6/7 UUID // - gen_random_uuid(u): -// to generate a version 4 (random) UUID. +// to generate a version 4 (random) UUID func Register(db *sqlite3.Conn) error { const flags = sqlite3.DETERMINISTIC | sqlite3.INNOCUOUS return errors.Join(