mirror of
https://github.com/ncruces/go-sqlite3.git
synced 2026-01-12 05:59:14 +00:00
IP/CIDR functions. (#246)
This commit is contained in:
113
ext/ipaddr/ipaddr.go
Normal file
113
ext/ipaddr/ipaddr.go
Normal file
@@ -0,0 +1,113 @@
|
||||
// Package ipaddr provides functions to manipulate IPs and CIDRs.
|
||||
//
|
||||
// It provides the following functions:
|
||||
// - ipcontains(prefix, ip)
|
||||
// - ipoverlaps(prefix1, prefix2)
|
||||
// - ipfamily(ip/prefix)
|
||||
// - iphost(ip/prefix)
|
||||
// - ipmasklen(prefix)
|
||||
// - ipnetwork(prefix)
|
||||
package ipaddr
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/netip"
|
||||
|
||||
"github.com/ncruces/go-sqlite3"
|
||||
)
|
||||
|
||||
// Register IP/CIDR functions for a database connection.
|
||||
func Register(db *sqlite3.Conn) error {
|
||||
const flags = sqlite3.DETERMINISTIC | sqlite3.INNOCUOUS
|
||||
return errors.Join(
|
||||
db.CreateFunction("ipcontains", 2, flags, contains),
|
||||
db.CreateFunction("ipoverlaps", 2, flags, overlaps),
|
||||
db.CreateFunction("ipfamily", 1, flags, family),
|
||||
db.CreateFunction("iphost", 1, flags, host),
|
||||
db.CreateFunction("ipmasklen", 1, flags, masklen),
|
||||
db.CreateFunction("ipnetwork", 1, flags, network))
|
||||
}
|
||||
|
||||
func contains(ctx sqlite3.Context, arg ...sqlite3.Value) {
|
||||
prefix, err := netip.ParsePrefix(arg[0].Text())
|
||||
if err != nil {
|
||||
ctx.ResultError(err)
|
||||
return // notest
|
||||
}
|
||||
addr, err := netip.ParseAddr(arg[1].Text())
|
||||
if err != nil {
|
||||
ctx.ResultError(err)
|
||||
return // notest
|
||||
}
|
||||
ctx.ResultBool(prefix.Contains(addr))
|
||||
}
|
||||
|
||||
func overlaps(ctx sqlite3.Context, arg ...sqlite3.Value) {
|
||||
prefix1, err := netip.ParsePrefix(arg[0].Text())
|
||||
if err != nil {
|
||||
ctx.ResultError(err)
|
||||
return // notest
|
||||
}
|
||||
prefix2, err := netip.ParsePrefix(arg[0].Text())
|
||||
if err != nil {
|
||||
ctx.ResultError(err)
|
||||
return // notest
|
||||
}
|
||||
ctx.ResultBool(prefix1.Overlaps(prefix2))
|
||||
}
|
||||
|
||||
func family(ctx sqlite3.Context, arg ...sqlite3.Value) {
|
||||
addr, err := addr(arg[0].Text())
|
||||
if err != nil {
|
||||
ctx.ResultError(err)
|
||||
return // notest
|
||||
}
|
||||
switch {
|
||||
case addr.Is4():
|
||||
ctx.ResultInt(4)
|
||||
case addr.Is6():
|
||||
ctx.ResultInt(6)
|
||||
}
|
||||
}
|
||||
|
||||
func host(ctx sqlite3.Context, arg ...sqlite3.Value) {
|
||||
addr, err := addr(arg[0].Text())
|
||||
if err != nil {
|
||||
ctx.ResultError(err)
|
||||
return // notest
|
||||
}
|
||||
buf, _ := addr.MarshalText()
|
||||
ctx.ResultRawText(buf)
|
||||
}
|
||||
|
||||
func masklen(ctx sqlite3.Context, arg ...sqlite3.Value) {
|
||||
prefix, err := netip.ParsePrefix(arg[0].Text())
|
||||
if err != nil {
|
||||
ctx.ResultError(err)
|
||||
return // notest
|
||||
}
|
||||
ctx.ResultInt(prefix.Bits())
|
||||
}
|
||||
|
||||
func network(ctx sqlite3.Context, arg ...sqlite3.Value) {
|
||||
prefix, err := netip.ParsePrefix(arg[0].Text())
|
||||
if err != nil {
|
||||
ctx.ResultError(err)
|
||||
return // notest
|
||||
}
|
||||
buf, _ := prefix.Masked().MarshalText()
|
||||
ctx.ResultRawText(buf)
|
||||
}
|
||||
|
||||
func addr(text string) (netip.Addr, error) {
|
||||
addr, err := netip.ParseAddr(text)
|
||||
if err != nil {
|
||||
if prefix, err := netip.ParsePrefix(text); err == nil {
|
||||
return prefix.Addr(), nil
|
||||
}
|
||||
if addrpt, err := netip.ParseAddrPort(text); err == nil {
|
||||
return addrpt.Addr(), nil
|
||||
}
|
||||
}
|
||||
return addr, err
|
||||
}
|
||||
88
ext/ipaddr/ipaddr_test.go
Normal file
88
ext/ipaddr/ipaddr_test.go
Normal file
@@ -0,0 +1,88 @@
|
||||
package ipaddr_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/ncruces/go-sqlite3/driver"
|
||||
_ "github.com/ncruces/go-sqlite3/embed"
|
||||
"github.com/ncruces/go-sqlite3/ext/ipaddr"
|
||||
_ "github.com/ncruces/go-sqlite3/internal/testcfg"
|
||||
"github.com/ncruces/go-sqlite3/vfs/memdb"
|
||||
)
|
||||
|
||||
func TestRegister(t *testing.T) {
|
||||
t.Parallel()
|
||||
tmp := memdb.TestDB(t)
|
||||
|
||||
db, err := driver.Open(tmp, ipaddr.Register)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
var got string
|
||||
|
||||
err = db.QueryRow(`SELECT ipfamily('::1')`).Scan(&got)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if got != "6" {
|
||||
t.Fatalf("got %s", got)
|
||||
}
|
||||
|
||||
err = db.QueryRow(`SELECT ipfamily('[::1]:80')`).Scan(&got)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if got != "6" {
|
||||
t.Fatalf("got %s", got)
|
||||
}
|
||||
|
||||
err = db.QueryRow(`SELECT ipfamily('192.168.1.5/24')`).Scan(&got)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if got != "4" {
|
||||
t.Fatalf("got %s", got)
|
||||
}
|
||||
|
||||
err = db.QueryRow(`SELECT iphost('192.168.1.5/24')`).Scan(&got)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if got != "192.168.1.5" {
|
||||
t.Fatalf("got %s", got)
|
||||
}
|
||||
|
||||
err = db.QueryRow(`SELECT ipmasklen('192.168.1.5/24')`).Scan(&got)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if got != "24" {
|
||||
t.Fatalf("got %s", got)
|
||||
}
|
||||
|
||||
err = db.QueryRow(`SELECT ipnetwork('192.168.1.5/24')`).Scan(&got)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if got != "192.168.1.0/24" {
|
||||
t.Fatalf("got %s", got)
|
||||
}
|
||||
|
||||
err = db.QueryRow(`SELECT ipcontains('192.168.1.0/24', '192.168.1.5')`).Scan(&got)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if got != "1" {
|
||||
t.Fatalf("got %s", got)
|
||||
}
|
||||
|
||||
err = db.QueryRow(`SELECT ipoverlaps('192.168.1.0/24', '192.168.1.5/32')`).Scan(&got)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if got != "1" {
|
||||
t.Fatalf("got %s", got)
|
||||
}
|
||||
}
|
||||
@@ -1,22 +1,22 @@
|
||||
// Package unicode provides an alternative to the SQLite ICU extension.
|
||||
//
|
||||
// Like the [ICU extension], it provides Unicode aware:
|
||||
// - upper() and lower() functions,
|
||||
// - LIKE and REGEXP operators,
|
||||
// - collation sequences.
|
||||
// - upper() and lower() functions
|
||||
// - LIKE and REGEXP operators
|
||||
// - collation sequences
|
||||
//
|
||||
// Like PostgreSQL, it also provides:
|
||||
// - initcap(),
|
||||
// - casefold(),
|
||||
// - normalize(),
|
||||
// - unaccent().
|
||||
// - initcap()
|
||||
// - casefold()
|
||||
// - normalize()
|
||||
// - unaccent()
|
||||
//
|
||||
// The implementations are not 100% compatible:
|
||||
// - upper(), lower(), initcap() casefold() use [strings.ToUpper], [strings.ToLower], [strings.Title] and [cases];
|
||||
// - normalize(), unaccent() use [transform] and [unicode.Mn];
|
||||
// - the LIKE operator follows [strings.EqualFold] rules;
|
||||
// - the REGEXP operator uses Go [regexp/syntax];
|
||||
// - collation sequences use [collate].
|
||||
// - upper(), lower(), initcap() casefold() use [strings.ToUpper], [strings.ToLower], [strings.Title] and [cases]
|
||||
// - normalize(), unaccent() use [transform] and [unicode.Mn]
|
||||
// - the LIKE operator follows [strings.EqualFold] rules
|
||||
// - the REGEXP operator uses Go [regexp/syntax]
|
||||
// - collation sequences use [collate]
|
||||
//
|
||||
// Expect subtle differences (e.g.) in the handling of Turkish case folding.
|
||||
//
|
||||
|
||||
@@ -18,17 +18,17 @@ import (
|
||||
// Register registers the SQL functions:
|
||||
//
|
||||
// - uuid([ version [, domain/namespace, [ id/data ]]]):
|
||||
// to generate a UUID as a string,
|
||||
// to generate a UUID as a string
|
||||
// - uuid_str(u):
|
||||
// to convert a UUID into a well-formed UUID string,
|
||||
// to convert a UUID into a well-formed UUID string
|
||||
// - uuid_blob(u):
|
||||
// to convert a UUID into a 16-byte blob,
|
||||
// to convert a UUID into a 16-byte blob
|
||||
// - uuid_extract_version(u):
|
||||
// to extract the version of a RFC 4122 UUID,
|
||||
// to extract the version of a RFC 4122 UUID
|
||||
// - uuid_extract_timestamp(u):
|
||||
// to extract the timestamp of a version 1/2/6/7 UUID,
|
||||
// to extract the timestamp of a version 1/2/6/7 UUID
|
||||
// - gen_random_uuid(u):
|
||||
// to generate a version 4 (random) UUID.
|
||||
// to generate a version 4 (random) UUID
|
||||
func Register(db *sqlite3.Conn) error {
|
||||
const flags = sqlite3.DETERMINISTIC | sqlite3.INNOCUOUS
|
||||
return errors.Join(
|
||||
|
||||
Reference in New Issue
Block a user