GORM driver.

This commit is contained in:
Nuno Cruces
2023-06-06 12:37:54 +01:00
parent fbbbe5a631
commit f07e82e361
12 changed files with 1358 additions and 0 deletions

View File

@@ -17,6 +17,8 @@ provides a [`database/sql`](https://pkg.go.dev/database/sql) driver
embeds a build of SQLite into your application. embeds a build of SQLite into your application.
- Package [`github.com/ncruces/go-sqlite3/vfs`](https://pkg.go.dev/github.com/ncruces/go-sqlite3/vfs) - Package [`github.com/ncruces/go-sqlite3/vfs`](https://pkg.go.dev/github.com/ncruces/go-sqlite3/vfs)
wraps the [C SQLite VFS API](https://www.sqlite.org/vfs.html) and provides a pure Go implementation. wraps the [C SQLite VFS API](https://www.sqlite.org/vfs.html) and provides a pure Go implementation.
- Package [`github.com/ncruces/go-sqlite3/gormlite`](https://pkg.go.dev/github.com/ncruces/go-sqlite3/gormlite)
provides a [GORM](https://gorm.io) driver.
### Caveats ### Caveats

6
go.mod
View File

@@ -8,6 +8,12 @@ require (
github.com/tetratelabs/wazero v1.2.0 github.com/tetratelabs/wazero v1.2.0
golang.org/x/sync v0.2.0 golang.org/x/sync v0.2.0
golang.org/x/sys v0.8.0 golang.org/x/sys v0.8.0
gorm.io/gorm v1.25.1
)
require (
github.com/jinzhu/inflection v1.0.0 // indirect
github.com/jinzhu/now v1.1.5 // indirect
) )
retract v0.4.0 // tagged from the wrong branch retract v0.4.0 // tagged from the wrong branch

6
go.sum
View File

@@ -1,3 +1,7 @@
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ=
github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
github.com/ncruces/julianday v0.1.5 h1:hDJ9ejiMp3DHsoZ5KW4c1lwfMjbARS7u/gbYcd0FBZk= github.com/ncruces/julianday v0.1.5 h1:hDJ9ejiMp3DHsoZ5KW4c1lwfMjbARS7u/gbYcd0FBZk=
github.com/ncruces/julianday v0.1.5/go.mod h1:Dusn2KvZrrovOMJuOt0TNXL6tB7U2E8kvza5fFc9G7g= github.com/ncruces/julianday v0.1.5/go.mod h1:Dusn2KvZrrovOMJuOt0TNXL6tB7U2E8kvza5fFc9G7g=
github.com/psanford/httpreadat v0.1.0 h1:VleW1HS2zO7/4c7c7zNl33fO6oYACSagjJIyMIwZLUE= github.com/psanford/httpreadat v0.1.0 h1:VleW1HS2zO7/4c7c7zNl33fO6oYACSagjJIyMIwZLUE=
@@ -8,3 +12,5 @@ golang.org/x/sync v0.2.0 h1:PUR+T4wwASmuSTYdKjYHI5TD22Wy5ogLU5qZCOLxBrI=
golang.org/x/sync v0.2.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.2.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.8.0 h1:EBmGv8NaZBZTWvrbjNoL6HVt+IVy3QDQpJs7VRIw3tU= golang.org/x/sys v0.8.0 h1:EBmGv8NaZBZTWvrbjNoL6HVt+IVy3QDQpJs7VRIw3tU=
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
gorm.io/gorm v1.25.1 h1:nsSALe5Pr+cM3V1qwwQ7rOkw+6UeLrX5O4v3llhHa64=
gorm.io/gorm v1.25.1/go.mod h1:L4uxeKpfBml98NYqVqwAdmV1a2nBtAec/cf3fpucW/k=

22
gormlite/LICENSE Normal file
View File

@@ -0,0 +1,22 @@
MIT License
Copyright (c) 2023 Nuno Cruces
Copyright (c) 2023 Jinzhu <wosmvp@gmail.com>
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

24
gormlite/README.md Normal file
View File

@@ -0,0 +1,24 @@
# GORM SQLite Driver
## Usage
```go
import (
_ "github.com/ncruces/go-sqlite3/embed"
"github.com/ncruces/go-sqlite3/gormlite"
"gorm.io/gorm"
)
db, err := gorm.Open(gormlite.Open("gorm.db"), &gorm.Config{})
```
Checkout [https://gorm.io](https://gorm.io) for details.
### Foreign-key constraint activation
Foreign-key constraint is disabled by default in SQLite. To activate it, use connection URL parameter:
```go
db, err := gorm.Open(gormlite.Open(
"file:gorm.db?_pragma=busy_timeout(10000)&_pragma=locking_mode(normal)&_pragma=foreign_keys(1)"),
&gorm.Config{})
```

231
gormlite/ddlmod.go Normal file
View File

@@ -0,0 +1,231 @@
package gormlite
import (
"database/sql"
"errors"
"fmt"
"regexp"
"strconv"
"strings"
"gorm.io/gorm/migrator"
)
var (
sqliteSeparator = "`|\"|'|\t"
indexRegexp = regexp.MustCompile(fmt.Sprintf(`(?is)CREATE(?: UNIQUE)? INDEX [%v]?[\w\d-]+[%v]? ON (.*)$`, sqliteSeparator, sqliteSeparator))
tableRegexp = regexp.MustCompile(fmt.Sprintf(`(?is)(CREATE TABLE [%v]?[\w\d-]+[%v]?)(?:\s*\((.*)\))?`, sqliteSeparator, sqliteSeparator))
separatorRegexp = regexp.MustCompile(fmt.Sprintf("[%v]", sqliteSeparator))
columnsRegexp = regexp.MustCompile(fmt.Sprintf(`[(,][%v]?(\w+)[%v]?`, sqliteSeparator, sqliteSeparator))
columnRegexp = regexp.MustCompile(fmt.Sprintf(`^[%v]?([\w\d]+)[%v]?\s+([\w\(\)\d]+)(.*)$`, sqliteSeparator, sqliteSeparator))
defaultValueRegexp = regexp.MustCompile(`(?i) DEFAULT \(?(.+)?\)?( |COLLATE|GENERATED|$)`)
regRealDataType = regexp.MustCompile(`[^\d](\d+)[^\d]?`)
)
func getAllColumns(s string) []string {
allMatches := columnsRegexp.FindAllStringSubmatch(s, -1)
columns := make([]string, 0, len(allMatches))
for _, matches := range allMatches {
if len(matches) > 1 {
columns = append(columns, matches[1])
}
}
return columns
}
type ddl struct {
head string
fields []string
columns []migrator.ColumnType
}
func parseDDL(strs ...string) (*ddl, error) {
var result ddl
for _, str := range strs {
if sections := tableRegexp.FindStringSubmatch(str); len(sections) > 0 {
var (
ddlBody = sections[2]
ddlBodyRunes = []rune(ddlBody)
bracketLevel int
quote rune
buf string
)
ddlBodyRunesLen := len(ddlBodyRunes)
result.head = sections[1]
for idx := 0; idx < ddlBodyRunesLen; idx++ {
var (
next rune = 0
c = ddlBodyRunes[idx]
)
if idx+1 < ddlBodyRunesLen {
next = ddlBodyRunes[idx+1]
}
if sc := string(c); separatorRegexp.MatchString(sc) {
if c == next {
buf += sc // Skip escaped quote
idx++
} else if quote > 0 {
quote = 0
} else {
quote = c
}
} else if quote == 0 {
if c == '(' {
bracketLevel++
} else if c == ')' {
bracketLevel--
} else if bracketLevel == 0 {
if c == ',' {
result.fields = append(result.fields, strings.TrimSpace(buf))
buf = ""
continue
}
}
}
if bracketLevel < 0 {
return nil, errors.New("invalid DDL, unbalanced brackets")
}
buf += string(c)
}
if bracketLevel != 0 {
return nil, errors.New("invalid DDL, unbalanced brackets")
}
if buf != "" {
result.fields = append(result.fields, strings.TrimSpace(buf))
}
for _, f := range result.fields {
fUpper := strings.ToUpper(f)
if strings.HasPrefix(fUpper, "CHECK") ||
strings.HasPrefix(fUpper, "CONSTRAINT") {
continue
}
if strings.HasPrefix(fUpper, "PRIMARY KEY") {
for _, name := range getAllColumns(f) {
for idx, column := range result.columns {
if column.NameValue.String == name {
column.PrimaryKeyValue = sql.NullBool{Bool: true, Valid: true}
result.columns[idx] = column
break
}
}
}
} else if matches := columnRegexp.FindStringSubmatch(f); len(matches) > 0 {
columnType := migrator.ColumnType{
NameValue: sql.NullString{String: matches[1], Valid: true},
DataTypeValue: sql.NullString{String: matches[2], Valid: true},
ColumnTypeValue: sql.NullString{String: matches[2], Valid: true},
PrimaryKeyValue: sql.NullBool{Valid: true},
UniqueValue: sql.NullBool{Valid: true},
NullableValue: sql.NullBool{Valid: true},
DefaultValueValue: sql.NullString{Valid: false},
}
matchUpper := strings.ToUpper(matches[3])
if strings.Contains(matchUpper, " NOT NULL") {
columnType.NullableValue = sql.NullBool{Bool: false, Valid: true}
} else if strings.Contains(matchUpper, " NULL") {
columnType.NullableValue = sql.NullBool{Bool: true, Valid: true}
}
if strings.Contains(matchUpper, " UNIQUE") {
columnType.UniqueValue = sql.NullBool{Bool: true, Valid: true}
}
if strings.Contains(matchUpper, " PRIMARY") {
columnType.PrimaryKeyValue = sql.NullBool{Bool: true, Valid: true}
}
if defaultMatches := defaultValueRegexp.FindStringSubmatch(matches[3]); len(defaultMatches) > 1 {
if strings.ToLower(defaultMatches[1]) != "null" {
columnType.DefaultValueValue = sql.NullString{String: strings.Trim(defaultMatches[1], `"`), Valid: true}
}
}
// data type length
matches := regRealDataType.FindAllStringSubmatch(columnType.DataTypeValue.String, -1)
if len(matches) == 1 && len(matches[0]) == 2 {
size, _ := strconv.Atoi(matches[0][1])
columnType.LengthValue = sql.NullInt64{Valid: true, Int64: int64(size)}
columnType.DataTypeValue.String = strings.TrimSuffix(columnType.DataTypeValue.String, matches[0][0])
}
result.columns = append(result.columns, columnType)
}
}
} else if matches := indexRegexp.FindStringSubmatch(str); len(matches) > 0 {
for _, column := range getAllColumns(matches[1]) {
for idx, c := range result.columns {
if c.NameValue.String == column {
c.UniqueValue = sql.NullBool{Bool: true, Valid: true}
result.columns[idx] = c
}
}
}
} else {
return nil, errors.New("invalid DDL")
}
}
return &result, nil
}
func (d *ddl) compile() string {
if len(d.fields) == 0 {
return d.head
}
return fmt.Sprintf("%s (%s)", d.head, strings.Join(d.fields, ","))
}
func (d *ddl) addConstraint(name string, sql string) {
reg := regexp.MustCompile("^CONSTRAINT [\"`]?" + regexp.QuoteMeta(name) + "[\"` ]")
for i := 0; i < len(d.fields); i++ {
if reg.MatchString(d.fields[i]) {
d.fields[i] = sql
return
}
}
d.fields = append(d.fields, sql)
}
func (d *ddl) removeConstraint(name string) bool {
reg := regexp.MustCompile("^CONSTRAINT [\"`]?" + regexp.QuoteMeta(name) + "[\"` ]")
for i := 0; i < len(d.fields); i++ {
if reg.MatchString(d.fields[i]) {
d.fields = append(d.fields[:i], d.fields[i+1:]...)
return true
}
}
return false
}
func (d *ddl) getColumns() []string {
res := []string{}
for _, f := range d.fields {
fUpper := strings.ToUpper(f)
if strings.HasPrefix(fUpper, "PRIMARY KEY") ||
strings.HasPrefix(fUpper, "CHECK") ||
strings.HasPrefix(fUpper, "CONSTRAINT") ||
strings.Contains(fUpper, "GENERATED ALWAYS AS") {
continue
}
reg := regexp.MustCompile("^[\"`']?([\\w\\d]+)[\"`']?")
match := reg.FindStringSubmatch(f)
if match != nil {
res = append(res, "`"+match[1]+"`")
}
}
return res
}

334
gormlite/ddlmod_test.go Normal file
View File

@@ -0,0 +1,334 @@
package gormlite
import (
"database/sql"
"testing"
"gorm.io/gorm/migrator"
"gorm.io/gorm/utils/tests"
)
func TestParseDDL(t *testing.T) {
params := []struct {
name string
sql []string
nFields int
columns []migrator.ColumnType
}{
{"with_fk", []string{
"CREATE TABLE `notes` (`id` integer NOT NULL,`text` varchar(500) DEFAULT \"hello\",`age` integer DEFAULT 18,`user_id` integer,PRIMARY KEY (`id`),CONSTRAINT `fk_users_notes` FOREIGN KEY (`user_id`) REFERENCES `users`(`id`))",
"CREATE UNIQUE INDEX `idx_profiles_refer` ON `profiles`(`text`)",
}, 6, []migrator.ColumnType{
{NameValue: sql.NullString{String: "id", Valid: true}, DataTypeValue: sql.NullString{String: "integer", Valid: true}, ColumnTypeValue: sql.NullString{String: "integer", Valid: true}, PrimaryKeyValue: sql.NullBool{Bool: true, Valid: true}, NullableValue: sql.NullBool{Valid: true}, UniqueValue: sql.NullBool{Valid: true}, DefaultValueValue: sql.NullString{Valid: false}},
{NameValue: sql.NullString{String: "text", Valid: true}, DataTypeValue: sql.NullString{String: "varchar", Valid: true}, LengthValue: sql.NullInt64{Int64: 500, Valid: true}, ColumnTypeValue: sql.NullString{String: "varchar(500)", Valid: true}, DefaultValueValue: sql.NullString{String: "hello", Valid: true}, NullableValue: sql.NullBool{Valid: true}, UniqueValue: sql.NullBool{Bool: true, Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}},
{NameValue: sql.NullString{String: "age", Valid: true}, DataTypeValue: sql.NullString{String: "integer", Valid: true}, ColumnTypeValue: sql.NullString{String: "integer", Valid: true}, DefaultValueValue: sql.NullString{String: "18", Valid: true}, NullableValue: sql.NullBool{Valid: true}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}},
{NameValue: sql.NullString{String: "user_id", Valid: true}, DataTypeValue: sql.NullString{String: "integer", Valid: true}, ColumnTypeValue: sql.NullString{String: "integer", Valid: true}, DefaultValueValue: sql.NullString{Valid: false}, NullableValue: sql.NullBool{Valid: true}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}},
},
},
{"with_check", []string{"CREATE TABLE Persons (ID int NOT NULL,LastName varchar(255) NOT NULL,FirstName varchar(255),Age int,CHECK (Age>=18),CHECK (FirstName<>'John'))"}, 6, []migrator.ColumnType{
{NameValue: sql.NullString{String: "ID", Valid: true}, DataTypeValue: sql.NullString{String: "int", Valid: true}, ColumnTypeValue: sql.NullString{String: "int", Valid: true}, NullableValue: sql.NullBool{Valid: true}, DefaultValueValue: sql.NullString{Valid: false}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}},
{NameValue: sql.NullString{String: "LastName", Valid: true}, DataTypeValue: sql.NullString{String: "varchar", Valid: true}, LengthValue: sql.NullInt64{Int64: 255, Valid: true}, ColumnTypeValue: sql.NullString{String: "varchar(255)", Valid: true}, NullableValue: sql.NullBool{Bool: false, Valid: true}, DefaultValueValue: sql.NullString{Valid: false}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}},
{NameValue: sql.NullString{String: "FirstName", Valid: true}, DataTypeValue: sql.NullString{String: "varchar", Valid: true}, LengthValue: sql.NullInt64{Int64: 255, Valid: true}, ColumnTypeValue: sql.NullString{String: "varchar(255)", Valid: true}, DefaultValueValue: sql.NullString{Valid: false}, NullableValue: sql.NullBool{Valid: true}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}},
{NameValue: sql.NullString{String: "Age", Valid: true}, DataTypeValue: sql.NullString{String: "int", Valid: true}, ColumnTypeValue: sql.NullString{String: "int", Valid: true}, DefaultValueValue: sql.NullString{Valid: false}, NullableValue: sql.NullBool{Valid: true}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}},
}},
{"lowercase", []string{"create table test (ID int NOT NULL)"}, 1, []migrator.ColumnType{
{NameValue: sql.NullString{String: "ID", Valid: true}, DataTypeValue: sql.NullString{String: "int", Valid: true}, ColumnTypeValue: sql.NullString{String: "int", Valid: true}, NullableValue: sql.NullBool{Bool: false, Valid: true}, DefaultValueValue: sql.NullString{Valid: false}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}},
},
},
{"no brackets", []string{"create table test"}, 0, nil},
{"with_special_characters", []string{
"CREATE TABLE `test` (`text` varchar(10) DEFAULT \"测试, \")",
}, 1, []migrator.ColumnType{
{NameValue: sql.NullString{String: "text", Valid: true}, DataTypeValue: sql.NullString{String: "varchar", Valid: true}, LengthValue: sql.NullInt64{Int64: 10, Valid: true}, ColumnTypeValue: sql.NullString{String: "varchar(10)", Valid: true}, DefaultValueValue: sql.NullString{String: "测试, ", Valid: true}, NullableValue: sql.NullBool{Valid: true}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}},
},
},
{
"table_name_with_dash",
[]string{
"CREATE TABLE `test-a` (`id` int NOT NULL)",
"CREATE UNIQUE INDEX `idx_test-a_id` ON `test-a`(`id`)",
},
1,
[]migrator.ColumnType{
{
NameValue: sql.NullString{String: "id", Valid: true},
DataTypeValue: sql.NullString{String: "int", Valid: true},
ColumnTypeValue: sql.NullString{String: "int", Valid: true},
NullableValue: sql.NullBool{Bool: false, Valid: true},
DefaultValueValue: sql.NullString{Valid: false},
UniqueValue: sql.NullBool{Bool: true, Valid: true},
PrimaryKeyValue: sql.NullBool{Valid: true},
},
},
},
{
"unique index",
[]string{
"CREATE TABLE `test-b` (`field` integer NOT NULL)",
"CREATE UNIQUE INDEX `idx_uq` ON `test-b`(`field`) WHERE field = 0",
},
1,
[]migrator.ColumnType{
{
NameValue: sql.NullString{String: "field", Valid: true},
DataTypeValue: sql.NullString{String: "integer", Valid: true},
ColumnTypeValue: sql.NullString{String: "integer", Valid: true},
PrimaryKeyValue: sql.NullBool{Bool: false, Valid: true},
UniqueValue: sql.NullBool{Bool: true, Valid: true},
NullableValue: sql.NullBool{Bool: false, Valid: true},
},
},
},
}
for _, p := range params {
t.Run(p.name, func(t *testing.T) {
ddl, err := parseDDL(p.sql...)
if err != nil {
panic(err.Error())
}
tests.AssertEqual(t, p.sql[0], ddl.compile())
if len(ddl.fields) != p.nFields {
t.Fatalf("fields length doesn't match: expect: %v, got %v", p.nFields, len(ddl.fields))
}
tests.AssertEqual(t, ddl.columns, p.columns)
})
}
}
func TestParseDDL_Whitespaces(t *testing.T) {
testColumns := []migrator.ColumnType{
{
NameValue: sql.NullString{String: "id", Valid: true},
DataTypeValue: sql.NullString{String: "integer", Valid: true},
ColumnTypeValue: sql.NullString{String: "integer", Valid: true},
NullableValue: sql.NullBool{Bool: false, Valid: true},
DefaultValueValue: sql.NullString{Valid: false},
UniqueValue: sql.NullBool{Bool: true, Valid: true},
PrimaryKeyValue: sql.NullBool{Bool: true, Valid: true},
},
{
NameValue: sql.NullString{String: "dark_mode", Valid: true},
DataTypeValue: sql.NullString{String: "numeric", Valid: true},
ColumnTypeValue: sql.NullString{String: "numeric", Valid: true},
NullableValue: sql.NullBool{Valid: true},
DefaultValueValue: sql.NullString{String: "true", Valid: true},
UniqueValue: sql.NullBool{Bool: false, Valid: true},
PrimaryKeyValue: sql.NullBool{Bool: false, Valid: true},
},
}
params := []struct {
name string
sql []string
nFields int
columns []migrator.ColumnType
}{
{
"with_newline",
[]string{"CREATE TABLE `users`\n(\nid integer primary key unique,\ndark_mode numeric DEFAULT true)"},
2,
testColumns,
},
{
"with_newline_2",
[]string{"CREATE TABLE `users` (\n\nid integer primary key unique,\ndark_mode numeric DEFAULT true)"},
2,
testColumns,
},
{
"with_missing_space",
[]string{"CREATE TABLE `users`(id integer primary key unique, dark_mode numeric DEFAULT true)"},
2,
testColumns,
},
{
"with_many_spaces",
[]string{"CREATE TABLE `users` (id integer primary key unique, dark_mode numeric DEFAULT true)"},
2,
testColumns,
},
}
for _, p := range params {
t.Run(p.name, func(t *testing.T) {
ddl, err := parseDDL(p.sql...)
if err != nil {
panic(err.Error())
}
if len(ddl.fields) != p.nFields {
t.Fatalf("fields length doesn't match: expect: %v, got %v", p.nFields, len(ddl.fields))
}
tests.AssertEqual(t, ddl.columns, p.columns)
})
}
}
func TestParseDDL_error(t *testing.T) {
params := []struct {
name string
sql string
}{
{"invalid_cmd", "CREATE TABLE"},
{"unbalanced_brackets", "CREATE TABLE test (ID int NOT NULL,Name varchar(255)"},
{"unbalanced_brackets2", "CREATE TABLE test (ID int NOT NULL,Name varchar(255)))"},
}
for _, p := range params {
t.Run(p.name, func(t *testing.T) {
_, err := parseDDL(p.sql)
if err == nil {
t.Fail()
}
})
}
}
func TestAddConstraint(t *testing.T) {
params := []struct {
name string
fields []string
cName string
sql string
expect []string
}{
{
name: "add_new",
fields: []string{"`id` integer NOT NULL"},
cName: "fk_users_notes",
sql: "CONSTRAINT `fk_users_notes` FOREIGN KEY (`user_id`) REFERENCES `users`(`id`))",
expect: []string{"`id` integer NOT NULL", "CONSTRAINT `fk_users_notes` FOREIGN KEY (`user_id`) REFERENCES `users`(`id`))"},
},
{
name: "update",
fields: []string{"`id` integer NOT NULL", "CONSTRAINT `fk_users_notes` FOREIGN KEY (`user_id`) REFERENCES `users`(`id`))"},
cName: "fk_users_notes",
sql: "CONSTRAINT `fk_users_notes` FOREIGN KEY (`user_id`) REFERENCES `users`(`id`)) ON UPDATE CASCADE ON DELETE CASCADE",
expect: []string{"`id` integer NOT NULL", "CONSTRAINT `fk_users_notes` FOREIGN KEY (`user_id`) REFERENCES `users`(`id`)) ON UPDATE CASCADE ON DELETE CASCADE"},
},
{
name: "add_check",
fields: []string{"`id` integer NOT NULL"},
cName: "name_checker",
sql: "CONSTRAINT `name_checker` CHECK (`name` <> 'jinzhu')",
expect: []string{"`id` integer NOT NULL", "CONSTRAINT `name_checker` CHECK (`name` <> 'jinzhu')"},
},
{
name: "update_check",
fields: []string{"`id` integer NOT NULL", "CONSTRAINT `name_checker` CHECK (`name` <> 'thetadev')"},
cName: "name_checker",
sql: "CONSTRAINT `name_checker` CHECK (`name` <> 'jinzhu')",
expect: []string{"`id` integer NOT NULL", "CONSTRAINT `name_checker` CHECK (`name` <> 'jinzhu')"},
},
}
for _, p := range params {
t.Run(p.name, func(t *testing.T) {
testDDL := ddl{fields: p.fields}
testDDL.addConstraint(p.cName, p.sql)
tests.AssertEqual(t, p.expect, testDDL.fields)
})
}
}
func TestRemoveConstraint(t *testing.T) {
params := []struct {
name string
fields []string
cName string
success bool
expect []string
}{
{
name: "fk",
fields: []string{"`id` integer NOT NULL", "CONSTRAINT `fk_users_notes` FOREIGN KEY (`user_id`) REFERENCES `users`(`id`))"},
cName: "fk_users_notes",
success: true,
expect: []string{"`id` integer NOT NULL"},
},
{
name: "check",
fields: []string{"CONSTRAINT `name_checker` CHECK (`name` <> 'thetadev')", "`id` integer NOT NULL"},
cName: "name_checker",
success: true,
expect: []string{"`id` integer NOT NULL"},
},
{
name: "none",
fields: []string{"CONSTRAINT `name_checker` CHECK (`name` <> 'thetadev')", "`id` integer NOT NULL"},
cName: "nothing",
success: false,
expect: []string{"CONSTRAINT `name_checker` CHECK (`name` <> 'thetadev')", "`id` integer NOT NULL"},
},
}
for _, p := range params {
t.Run(p.name, func(t *testing.T) {
testDDL := ddl{fields: p.fields}
success := testDDL.removeConstraint(p.cName)
tests.AssertEqual(t, p.success, success)
tests.AssertEqual(t, p.expect, testDDL.fields)
})
}
}
func TestGetColumns(t *testing.T) {
params := []struct {
name string
ddl string
columns []string
}{
{
name: "with_fk",
ddl: "CREATE TABLE `notes` (`id` integer NOT NULL,`text` varchar(500),`user_id` integer,PRIMARY KEY (`id`),CONSTRAINT `fk_users_notes` FOREIGN KEY (`user_id`) REFERENCES `users`(`id`))",
columns: []string{"`id`", "`text`", "`user_id`"},
},
{
name: "with_check",
ddl: "CREATE TABLE Persons (ID int NOT NULL,LastName varchar(255) NOT NULL,FirstName varchar(255),Age int,CHECK (Age>=18),CHECK (FirstName!='John'))",
columns: []string{"`ID`", "`LastName`", "`FirstName`", "`Age`"},
},
{
name: "with_escaped_quote",
ddl: "CREATE TABLE Persons (ID int NOT NULL,LastName varchar(255) NOT NULL DEFAULT \"\",FirstName varchar(255))",
columns: []string{"`ID`", "`LastName`", "`FirstName`"},
},
{
name: "with_generated_column",
ddl: "CREATE TABLE Persons (ID int NOT NULL,LastName varchar(255) NOT NULL,FirstName varchar(255),FullName varchar(255) GENERATED ALWAYS AS (FirstName || ' ' || LastName))",
columns: []string{"`ID`", "`LastName`", "`FirstName`"},
},
{
name: "with_new_line",
ddl: `CREATE TABLE "tb_sys_role_menu__temp" (
"id" integer PRIMARY KEY AUTOINCREMENT,
"created_at" datetime NOT NULL,
"updated_at" datetime NOT NULL,
"created_by" integer NOT NULL DEFAULT 0,
"updated_by" integer NOT NULL DEFAULT 0,
"role_id" integer NOT NULL,
"menu_id" bigint NOT NULL
)`,
columns: []string{"`id`", "`created_at`", "`updated_at`", "`created_by`", "`updated_by`", "`role_id`", "`menu_id`"},
},
}
for _, p := range params {
t.Run(p.name, func(t *testing.T) {
testDDL, err := parseDDL(p.ddl)
if err != nil {
panic(err.Error())
}
cols := testDDL.getColumns()
tests.AssertEqual(t, p.columns, cols)
})
}
}

View File

@@ -0,0 +1,15 @@
package gormlite
import (
"errors"
"github.com/ncruces/go-sqlite3"
"gorm.io/gorm"
)
func (dialector Dialector) Translate(err error) error {
if errors.Is(err, sqlite3.CONSTRAINT_UNIQUE) {
return gorm.ErrDuplicatedKey
}
return err
}

7
gormlite/errors.go Normal file
View File

@@ -0,0 +1,7 @@
package gormlite
import "errors"
var (
ErrConstraintsNotImplemented = errors.New("constraints not implemented on sqlite, consider using DisableForeignKeyConstraintWhenMigrating, more details https://github.com/go-gorm/gorm/wiki/GORM-V2-Release-Note-Draft#all-new-migrator")
)

428
gormlite/migrator.go Normal file
View File

@@ -0,0 +1,428 @@
package gormlite
import (
"database/sql"
"fmt"
"regexp"
"strings"
"gorm.io/gorm"
"gorm.io/gorm/clause"
"gorm.io/gorm/migrator"
"gorm.io/gorm/schema"
)
type Migrator struct {
migrator.Migrator
}
func (m *Migrator) RunWithoutForeignKey(fc func() error) error {
var enabled int
m.DB.Raw("PRAGMA foreign_keys").Scan(&enabled)
if enabled == 1 {
m.DB.Exec("PRAGMA foreign_keys = OFF")
defer m.DB.Exec("PRAGMA foreign_keys = ON")
}
return fc()
}
func (m Migrator) HasTable(value interface{}) bool {
var count int
m.Migrator.RunWithValue(value, func(stmt *gorm.Statement) error {
return m.DB.Raw("SELECT count(*) FROM sqlite_master WHERE type='table' AND name=?", stmt.Table).Row().Scan(&count)
})
return count > 0
}
func (m Migrator) DropTable(values ...interface{}) error {
return m.RunWithoutForeignKey(func() error {
values = m.ReorderModels(values, false)
tx := m.DB.Session(&gorm.Session{})
for i := len(values) - 1; i >= 0; i-- {
if err := m.RunWithValue(values[i], func(stmt *gorm.Statement) error {
return tx.Exec("DROP TABLE IF EXISTS ?", clause.Table{Name: stmt.Table}).Error
}); err != nil {
return err
}
}
return nil
})
}
func (m Migrator) GetTables() (tableList []string, err error) {
return tableList, m.DB.Raw("SELECT name FROM sqlite_master where type=?", "table").Scan(&tableList).Error
}
func (m Migrator) HasColumn(value interface{}, name string) bool {
var count int
m.Migrator.RunWithValue(value, func(stmt *gorm.Statement) error {
if stmt.Schema != nil {
if field := stmt.Schema.LookUpField(name); field != nil {
name = field.DBName
}
}
if name != "" {
m.DB.Raw(
"SELECT count(*) FROM sqlite_master WHERE type = ? AND tbl_name = ? AND (sql LIKE ? OR sql LIKE ? OR sql LIKE ? OR sql LIKE ? OR sql LIKE ?)",
"table", stmt.Table, `%"`+name+`" %`, `%`+name+` %`, "%`"+name+"`%", "%["+name+"]%", "%\t"+name+"\t%",
).Row().Scan(&count)
}
return nil
})
return count > 0
}
func (m Migrator) AlterColumn(value interface{}, name string) error {
return m.RunWithoutForeignKey(func() error {
return m.recreateTable(value, nil, func(rawDDL string, stmt *gorm.Statement) (sql string, sqlArgs []interface{}, err error) {
if field := stmt.Schema.LookUpField(name); field != nil {
// lookup field from table definition, ddl might looks like `'name' int,` or `'name' int)`
reg, err := regexp.Compile("(`|'|\"| )" + field.DBName + "(`|'|\"| ) .*?(,|\\)\\s*$)")
if err != nil {
return "", nil, err
}
createSQL := reg.ReplaceAllString(rawDDL, fmt.Sprintf("`%v` ?$3", field.DBName))
if createSQL == rawDDL {
return "", nil, fmt.Errorf("failed to look up field %v from DDL %v", field.DBName, rawDDL)
}
return createSQL, []interface{}{m.FullDataTypeOf(field)}, nil
}
return "", nil, fmt.Errorf("failed to alter field with name %v", name)
})
})
}
// ColumnTypes return columnTypes []gorm.ColumnType and execErr error
func (m Migrator) ColumnTypes(value interface{}) ([]gorm.ColumnType, error) {
columnTypes := make([]gorm.ColumnType, 0)
execErr := m.RunWithValue(value, func(stmt *gorm.Statement) (err error) {
var (
sqls []string
sqlDDL *ddl
)
if err := m.DB.Raw("SELECT sql FROM sqlite_master WHERE type IN ? AND tbl_name = ? AND sql IS NOT NULL order by type = ? desc", []string{"table", "index"}, stmt.Table, "table").Scan(&sqls).Error; err != nil {
return err
}
if sqlDDL, err = parseDDL(sqls...); err != nil {
return err
}
rows, err := m.DB.Session(&gorm.Session{}).Table(stmt.Table).Limit(1).Rows()
if err != nil {
return err
}
defer func() {
err = rows.Close()
}()
var rawColumnTypes []*sql.ColumnType
rawColumnTypes, err = rows.ColumnTypes()
if err != nil {
return err
}
for _, c := range rawColumnTypes {
columnType := migrator.ColumnType{SQLColumnType: c}
for _, column := range sqlDDL.columns {
if column.NameValue.String == c.Name() {
column.SQLColumnType = c
columnType = column
break
}
}
columnTypes = append(columnTypes, columnType)
}
return err
})
return columnTypes, execErr
}
func (m Migrator) DropColumn(value interface{}, name string) error {
return m.recreateTable(value, nil, func(rawDDL string, stmt *gorm.Statement) (sql string, sqlArgs []interface{}, err error) {
if field := stmt.Schema.LookUpField(name); field != nil {
name = field.DBName
}
reg, err := regexp.Compile("(`|'|\"| |\\[)" + name + "(`|'|\"| |\\]) .*?,")
if err != nil {
return "", nil, err
}
createSQL := reg.ReplaceAllString(rawDDL, "")
return createSQL, nil, nil
})
}
func (m Migrator) CreateConstraint(value interface{}, name string) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
constraint, chk, table := m.GuessConstraintAndTable(stmt, name)
return m.recreateTable(value, &table,
func(rawDDL string, stmt *gorm.Statement) (sql string, sqlArgs []interface{}, err error) {
var (
constraintName string
constraintSql string
constraintValues []interface{}
)
if constraint != nil {
constraintName = constraint.Name
constraintSql, constraintValues = buildConstraint(constraint)
} else if chk != nil {
constraintName = chk.Name
constraintSql = "CONSTRAINT ? CHECK (?)"
constraintValues = []interface{}{clause.Column{Name: chk.Name}, clause.Expr{SQL: chk.Constraint}}
} else {
return "", nil, nil
}
createDDL, err := parseDDL(rawDDL)
if err != nil {
return "", nil, err
}
createDDL.addConstraint(constraintName, constraintSql)
createSQL := createDDL.compile()
return createSQL, constraintValues, nil
})
})
}
func (m Migrator) DropConstraint(value interface{}, name string) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
constraint, chk, table := m.GuessConstraintAndTable(stmt, name)
if constraint != nil {
name = constraint.Name
} else if chk != nil {
name = chk.Name
}
return m.recreateTable(value, &table,
func(rawDDL string, stmt *gorm.Statement) (sql string, sqlArgs []interface{}, err error) {
createDDL, err := parseDDL(rawDDL)
if err != nil {
return "", nil, err
}
createDDL.removeConstraint(name)
createSQL := createDDL.compile()
return createSQL, nil, nil
})
})
}
func (m Migrator) HasConstraint(value interface{}, name string) bool {
var count int64
m.RunWithValue(value, func(stmt *gorm.Statement) error {
constraint, chk, table := m.GuessConstraintAndTable(stmt, name)
if constraint != nil {
name = constraint.Name
} else if chk != nil {
name = chk.Name
}
m.DB.Raw(
"SELECT count(*) FROM sqlite_master WHERE type = ? AND tbl_name = ? AND (sql LIKE ? OR sql LIKE ? OR sql LIKE ? OR sql LIKE ? OR sql LIKE ?)",
"table", table, `%CONSTRAINT "`+name+`" %`, `%CONSTRAINT `+name+` %`, "%CONSTRAINT `"+name+"`%", "%CONSTRAINT ["+name+"]%", "%CONSTRAINT \t"+name+"\t%",
).Row().Scan(&count)
return nil
})
return count > 0
}
func (m Migrator) CurrentDatabase() (name string) {
var null interface{}
m.DB.Raw("PRAGMA database_list").Row().Scan(&null, &name, &null)
return
}
func (m Migrator) BuildIndexOptions(opts []schema.IndexOption, stmt *gorm.Statement) (results []interface{}) {
for _, opt := range opts {
str := stmt.Quote(opt.DBName)
if opt.Expression != "" {
str = opt.Expression
}
if opt.Collate != "" {
str += " COLLATE " + opt.Collate
}
if opt.Sort != "" {
str += " " + opt.Sort
}
results = append(results, clause.Expr{SQL: str})
}
return
}
func (m Migrator) CreateIndex(value interface{}, name string) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
if stmt.Schema != nil {
if idx := stmt.Schema.LookIndex(name); idx != nil {
opts := m.BuildIndexOptions(idx.Fields, stmt)
values := []interface{}{clause.Column{Name: idx.Name}, clause.Table{Name: stmt.Table}, opts}
createIndexSQL := "CREATE "
if idx.Class != "" {
createIndexSQL += idx.Class + " "
}
createIndexSQL += "INDEX ?"
if idx.Type != "" {
createIndexSQL += " USING " + idx.Type
}
createIndexSQL += " ON ??"
if idx.Where != "" {
createIndexSQL += " WHERE " + idx.Where
}
return m.DB.Exec(createIndexSQL, values...).Error
}
}
return fmt.Errorf("failed to create index with name %v", name)
})
}
func (m Migrator) HasIndex(value interface{}, name string) bool {
var count int
m.RunWithValue(value, func(stmt *gorm.Statement) error {
if stmt.Schema != nil {
if idx := stmt.Schema.LookIndex(name); idx != nil {
name = idx.Name
}
}
if name != "" {
m.DB.Raw(
"SELECT count(*) FROM sqlite_master WHERE type = ? AND tbl_name = ? AND name = ?", "index", stmt.Table, name,
).Row().Scan(&count)
}
return nil
})
return count > 0
}
func (m Migrator) RenameIndex(value interface{}, oldName, newName string) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
var sql string
m.DB.Raw("SELECT sql FROM sqlite_master WHERE type = ? AND tbl_name = ? AND name = ?", "index", stmt.Table, oldName).Row().Scan(&sql)
if sql != "" {
return m.DB.Exec(strings.Replace(sql, oldName, newName, 1)).Error
}
return fmt.Errorf("failed to find index with name %v", oldName)
})
}
func (m Migrator) DropIndex(value interface{}, name string) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
if stmt.Schema != nil {
if idx := stmt.Schema.LookIndex(name); idx != nil {
name = idx.Name
}
}
return m.DB.Exec("DROP INDEX ?", clause.Column{Name: name}).Error
})
}
func buildConstraint(constraint *schema.Constraint) (sql string, results []interface{}) {
sql = "CONSTRAINT ? FOREIGN KEY ? REFERENCES ??"
if constraint.OnDelete != "" {
sql += " ON DELETE " + constraint.OnDelete
}
if constraint.OnUpdate != "" {
sql += " ON UPDATE " + constraint.OnUpdate
}
var foreignKeys, references []interface{}
for _, field := range constraint.ForeignKeys {
foreignKeys = append(foreignKeys, clause.Column{Name: field.DBName})
}
for _, field := range constraint.References {
references = append(references, clause.Column{Name: field.DBName})
}
results = append(results, clause.Table{Name: constraint.Name}, foreignKeys, clause.Table{Name: constraint.ReferenceSchema.Table}, references)
return
}
func (m Migrator) getRawDDL(table string) (string, error) {
var createSQL string
m.DB.Raw("SELECT sql FROM sqlite_master WHERE type = ? AND tbl_name = ? AND name = ?", "table", table, table).Row().Scan(&createSQL)
if m.DB.Error != nil {
return "", m.DB.Error
}
return createSQL, nil
}
func (m Migrator) recreateTable(value interface{}, tablePtr *string,
getCreateSQL func(rawDDL string, stmt *gorm.Statement) (sql string, sqlArgs []interface{}, err error)) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
table := stmt.Table
if tablePtr != nil {
table = *tablePtr
}
rawDDL, err := m.getRawDDL(table)
if err != nil {
return err
}
newTableName := table + "__temp"
createSQL, sqlArgs, err := getCreateSQL(rawDDL, stmt)
if err != nil {
return err
}
if createSQL == "" {
return nil
}
tableReg, err := regexp.Compile("\\s*('|`|\")?\\b" + table + "\\b('|`|\")?\\s*")
if err != nil {
return err
}
createSQL = tableReg.ReplaceAllString(createSQL, fmt.Sprintf(" `%v` ", newTableName))
createDDL, err := parseDDL(createSQL)
if err != nil {
return err
}
columns := createDDL.getColumns()
return m.DB.Transaction(func(tx *gorm.DB) error {
if err := tx.Exec(createSQL, sqlArgs...).Error; err != nil {
return err
}
queries := []string{
fmt.Sprintf("INSERT INTO `%v`(%v) SELECT %v FROM `%v`", newTableName, strings.Join(columns, ","), strings.Join(columns, ","), table),
fmt.Sprintf("DROP TABLE `%v`", table),
fmt.Sprintf("ALTER TABLE `%v` RENAME TO `%v`", newTableName, table),
}
for _, query := range queries {
if err := tx.Exec(query).Error; err != nil {
return err
}
}
return nil
})
})
}

219
gormlite/sqlite.go Normal file
View File

@@ -0,0 +1,219 @@
// Package gormlite provides a GORM driver for SQLite.
package gormlite
import (
"context"
"database/sql"
"strconv"
"strings"
"gorm.io/gorm"
"gorm.io/gorm/callbacks"
"gorm.io/gorm/clause"
"gorm.io/gorm/logger"
"gorm.io/gorm/migrator"
"gorm.io/gorm/schema"
_ "github.com/ncruces/go-sqlite3/driver"
)
type Dialector struct {
DSN string
Conn gorm.ConnPool
}
func Open(dsn string) gorm.Dialector {
return &Dialector{DSN: dsn}
}
func (dialector Dialector) Name() string {
return "sqlite"
}
func (dialector Dialector) Initialize(db *gorm.DB) (err error) {
if dialector.Conn != nil {
db.ConnPool = dialector.Conn
} else {
conn, err := sql.Open("sqlite3", dialector.DSN)
if err != nil {
return err
}
db.ConnPool = conn
}
var version string
if err := db.ConnPool.QueryRowContext(context.Background(), "select sqlite_version()").Scan(&version); err != nil {
return err
}
// https://www.sqlite.org/releaselog/3_35_0.html
if compareVersion(version, "3.35.0") >= 0 {
callbacks.RegisterDefaultCallbacks(db, &callbacks.Config{
CreateClauses: []string{"INSERT", "VALUES", "ON CONFLICT", "RETURNING"},
UpdateClauses: []string{"UPDATE", "SET", "WHERE", "RETURNING"},
DeleteClauses: []string{"DELETE", "FROM", "WHERE", "RETURNING"},
LastInsertIDReversed: true,
})
} else {
callbacks.RegisterDefaultCallbacks(db, &callbacks.Config{
LastInsertIDReversed: true,
})
}
for k, v := range dialector.ClauseBuilders() {
db.ClauseBuilders[k] = v
}
return
}
func (dialector Dialector) ClauseBuilders() map[string]clause.ClauseBuilder {
return map[string]clause.ClauseBuilder{
"INSERT": func(c clause.Clause, builder clause.Builder) {
if insert, ok := c.Expression.(clause.Insert); ok {
if stmt, ok := builder.(*gorm.Statement); ok {
stmt.WriteString("INSERT ")
if insert.Modifier != "" {
stmt.WriteString(insert.Modifier)
stmt.WriteByte(' ')
}
stmt.WriteString("INTO ")
if insert.Table.Name == "" {
stmt.WriteQuoted(stmt.Table)
} else {
stmt.WriteQuoted(insert.Table)
}
return
}
}
c.Build(builder)
},
"LIMIT": func(c clause.Clause, builder clause.Builder) {
if limit, ok := c.Expression.(clause.Limit); ok {
var lmt = -1
if limit.Limit != nil && *limit.Limit >= 0 {
lmt = *limit.Limit
}
if lmt >= 0 || limit.Offset > 0 {
builder.WriteString("LIMIT ")
builder.WriteString(strconv.Itoa(lmt))
}
if limit.Offset > 0 {
builder.WriteString(" OFFSET ")
builder.WriteString(strconv.Itoa(limit.Offset))
}
}
},
"FOR": func(c clause.Clause, builder clause.Builder) {
if _, ok := c.Expression.(clause.Locking); ok {
// SQLite3 does not support row-level locking.
return
}
c.Build(builder)
},
}
}
func (dialector Dialector) DefaultValueOf(field *schema.Field) clause.Expression {
if field.AutoIncrement {
return clause.Expr{SQL: "NULL"}
}
// doesn't work, will raise error
return clause.Expr{SQL: "DEFAULT"}
}
func (dialector Dialector) Migrator(db *gorm.DB) gorm.Migrator {
return Migrator{migrator.Migrator{Config: migrator.Config{
DB: db,
Dialector: dialector,
CreateIndexAfterCreateTable: true,
}}}
}
func (dialector Dialector) BindVarTo(writer clause.Writer, stmt *gorm.Statement, v interface{}) {
writer.WriteByte('?')
}
func (dialector Dialector) QuoteTo(writer clause.Writer, str string) {
writer.WriteByte('`')
if strings.Contains(str, ".") {
for idx, str := range strings.Split(str, ".") {
if idx > 0 {
writer.WriteString(".`")
}
writer.WriteString(str)
writer.WriteByte('`')
}
} else {
writer.WriteString(str)
writer.WriteByte('`')
}
}
func (dialector Dialector) Explain(sql string, vars ...interface{}) string {
return logger.ExplainSQL(sql, nil, `"`, vars...)
}
func (dialector Dialector) DataTypeOf(field *schema.Field) string {
switch field.DataType {
case schema.Bool:
return "numeric"
case schema.Int, schema.Uint:
if field.AutoIncrement && !field.PrimaryKey {
// https://www.sqlite.org/autoinc.html
return "integer PRIMARY KEY AUTOINCREMENT"
} else {
return "integer"
}
case schema.Float:
return "real"
case schema.String:
return "text"
case schema.Time:
// Distinguish between schema.Time and tag time
if val, ok := field.TagSettings["TYPE"]; ok {
return val
} else {
return "datetime"
}
case schema.Bytes:
return "blob"
}
return string(field.DataType)
}
func (dialectopr Dialector) SavePoint(tx *gorm.DB, name string) error {
tx.Exec("SAVEPOINT " + name)
return nil
}
func (dialectopr Dialector) RollbackTo(tx *gorm.DB, name string) error {
tx.Exec("ROLLBACK TO SAVEPOINT " + name)
return nil
}
func compareVersion(version1, version2 string) int {
n, m := len(version1), len(version2)
i, j := 0, 0
for i < n || j < m {
x := 0
for ; i < n && version1[i] != '.'; i++ {
x = x*10 + int(version1[i]-'0')
}
i++
y := 0
for ; j < m && version2[j] != '.'; j++ {
y = y*10 + int(version2[j]-'0')
}
j++
if x > y {
return 1
}
if x < y {
return -1
}
}
return 0
}

64
gormlite/sqlite_test.go Normal file
View File

@@ -0,0 +1,64 @@
package gormlite
import (
"fmt"
"testing"
"gorm.io/gorm"
_ "github.com/ncruces/go-sqlite3/embed"
)
func TestDialector(t *testing.T) {
// This is the DSN of the in-memory SQLite database for these tests.
const InMemoryDSN = "file:testdatabase?mode=memory&cache=shared"
rows := []struct {
description string
dialector *Dialector
openSuccess bool
query string
querySuccess bool
}{
{
description: "Default driver",
dialector: &Dialector{
DSN: InMemoryDSN,
},
openSuccess: true,
query: "SELECT 1",
querySuccess: true,
},
}
for rowIndex, row := range rows {
t.Run(fmt.Sprintf("%d/%s", rowIndex, row.description), func(t *testing.T) {
db, err := gorm.Open(row.dialector, &gorm.Config{})
if !row.openSuccess {
if err == nil {
t.Errorf("Expected Open to fail.")
}
return
}
if err != nil {
t.Errorf("Expected Open to succeed; got error: %v", err)
}
if db == nil {
t.Errorf("Expected db to be non-nil.")
}
if row.query != "" {
err = db.Exec(row.query).Error
if !row.querySuccess {
if err == nil {
t.Errorf("Expected query to fail.")
}
return
}
if err != nil {
t.Errorf("Expected query to succeed; got error: %v", err)
}
}
})
}
}