mirror of
https://github.com/ncruces/go-sqlite3.git
synced 2026-01-11 21:49:13 +00:00
GORM driver.
This commit is contained in:
@@ -17,6 +17,8 @@ provides a [`database/sql`](https://pkg.go.dev/database/sql) driver
|
||||
embeds a build of SQLite into your application.
|
||||
- Package [`github.com/ncruces/go-sqlite3/vfs`](https://pkg.go.dev/github.com/ncruces/go-sqlite3/vfs)
|
||||
wraps the [C SQLite VFS API](https://www.sqlite.org/vfs.html) and provides a pure Go implementation.
|
||||
- Package [`github.com/ncruces/go-sqlite3/gormlite`](https://pkg.go.dev/github.com/ncruces/go-sqlite3/gormlite)
|
||||
provides a [GORM](https://gorm.io) driver.
|
||||
|
||||
### Caveats
|
||||
|
||||
|
||||
6
go.mod
6
go.mod
@@ -8,6 +8,12 @@ require (
|
||||
github.com/tetratelabs/wazero v1.2.0
|
||||
golang.org/x/sync v0.2.0
|
||||
golang.org/x/sys v0.8.0
|
||||
gorm.io/gorm v1.25.1
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/jinzhu/inflection v1.0.0 // indirect
|
||||
github.com/jinzhu/now v1.1.5 // indirect
|
||||
)
|
||||
|
||||
retract v0.4.0 // tagged from the wrong branch
|
||||
|
||||
6
go.sum
6
go.sum
@@ -1,3 +1,7 @@
|
||||
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
|
||||
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
|
||||
github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ=
|
||||
github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
|
||||
github.com/ncruces/julianday v0.1.5 h1:hDJ9ejiMp3DHsoZ5KW4c1lwfMjbARS7u/gbYcd0FBZk=
|
||||
github.com/ncruces/julianday v0.1.5/go.mod h1:Dusn2KvZrrovOMJuOt0TNXL6tB7U2E8kvza5fFc9G7g=
|
||||
github.com/psanford/httpreadat v0.1.0 h1:VleW1HS2zO7/4c7c7zNl33fO6oYACSagjJIyMIwZLUE=
|
||||
@@ -8,3 +12,5 @@ golang.org/x/sync v0.2.0 h1:PUR+T4wwASmuSTYdKjYHI5TD22Wy5ogLU5qZCOLxBrI=
|
||||
golang.org/x/sync v0.2.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sys v0.8.0 h1:EBmGv8NaZBZTWvrbjNoL6HVt+IVy3QDQpJs7VRIw3tU=
|
||||
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
gorm.io/gorm v1.25.1 h1:nsSALe5Pr+cM3V1qwwQ7rOkw+6UeLrX5O4v3llhHa64=
|
||||
gorm.io/gorm v1.25.1/go.mod h1:L4uxeKpfBml98NYqVqwAdmV1a2nBtAec/cf3fpucW/k=
|
||||
|
||||
22
gormlite/LICENSE
Normal file
22
gormlite/LICENSE
Normal file
@@ -0,0 +1,22 @@
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2023 Nuno Cruces
|
||||
Copyright (c) 2023 Jinzhu <wosmvp@gmail.com>
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
24
gormlite/README.md
Normal file
24
gormlite/README.md
Normal file
@@ -0,0 +1,24 @@
|
||||
# GORM SQLite Driver
|
||||
|
||||
## Usage
|
||||
|
||||
```go
|
||||
import (
|
||||
_ "github.com/ncruces/go-sqlite3/embed"
|
||||
"github.com/ncruces/go-sqlite3/gormlite"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
db, err := gorm.Open(gormlite.Open("gorm.db"), &gorm.Config{})
|
||||
```
|
||||
|
||||
Checkout [https://gorm.io](https://gorm.io) for details.
|
||||
|
||||
### Foreign-key constraint activation
|
||||
|
||||
Foreign-key constraint is disabled by default in SQLite. To activate it, use connection URL parameter:
|
||||
```go
|
||||
db, err := gorm.Open(gormlite.Open(
|
||||
"file:gorm.db?_pragma=busy_timeout(10000)&_pragma=locking_mode(normal)&_pragma=foreign_keys(1)"),
|
||||
&gorm.Config{})
|
||||
```
|
||||
231
gormlite/ddlmod.go
Normal file
231
gormlite/ddlmod.go
Normal file
@@ -0,0 +1,231 @@
|
||||
package gormlite
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"gorm.io/gorm/migrator"
|
||||
)
|
||||
|
||||
var (
|
||||
sqliteSeparator = "`|\"|'|\t"
|
||||
indexRegexp = regexp.MustCompile(fmt.Sprintf(`(?is)CREATE(?: UNIQUE)? INDEX [%v]?[\w\d-]+[%v]? ON (.*)$`, sqliteSeparator, sqliteSeparator))
|
||||
tableRegexp = regexp.MustCompile(fmt.Sprintf(`(?is)(CREATE TABLE [%v]?[\w\d-]+[%v]?)(?:\s*\((.*)\))?`, sqliteSeparator, sqliteSeparator))
|
||||
separatorRegexp = regexp.MustCompile(fmt.Sprintf("[%v]", sqliteSeparator))
|
||||
columnsRegexp = regexp.MustCompile(fmt.Sprintf(`[(,][%v]?(\w+)[%v]?`, sqliteSeparator, sqliteSeparator))
|
||||
columnRegexp = regexp.MustCompile(fmt.Sprintf(`^[%v]?([\w\d]+)[%v]?\s+([\w\(\)\d]+)(.*)$`, sqliteSeparator, sqliteSeparator))
|
||||
defaultValueRegexp = regexp.MustCompile(`(?i) DEFAULT \(?(.+)?\)?( |COLLATE|GENERATED|$)`)
|
||||
regRealDataType = regexp.MustCompile(`[^\d](\d+)[^\d]?`)
|
||||
)
|
||||
|
||||
func getAllColumns(s string) []string {
|
||||
allMatches := columnsRegexp.FindAllStringSubmatch(s, -1)
|
||||
columns := make([]string, 0, len(allMatches))
|
||||
for _, matches := range allMatches {
|
||||
if len(matches) > 1 {
|
||||
columns = append(columns, matches[1])
|
||||
}
|
||||
}
|
||||
return columns
|
||||
}
|
||||
|
||||
type ddl struct {
|
||||
head string
|
||||
fields []string
|
||||
columns []migrator.ColumnType
|
||||
}
|
||||
|
||||
func parseDDL(strs ...string) (*ddl, error) {
|
||||
var result ddl
|
||||
for _, str := range strs {
|
||||
if sections := tableRegexp.FindStringSubmatch(str); len(sections) > 0 {
|
||||
var (
|
||||
ddlBody = sections[2]
|
||||
ddlBodyRunes = []rune(ddlBody)
|
||||
bracketLevel int
|
||||
quote rune
|
||||
buf string
|
||||
)
|
||||
ddlBodyRunesLen := len(ddlBodyRunes)
|
||||
|
||||
result.head = sections[1]
|
||||
|
||||
for idx := 0; idx < ddlBodyRunesLen; idx++ {
|
||||
var (
|
||||
next rune = 0
|
||||
c = ddlBodyRunes[idx]
|
||||
)
|
||||
if idx+1 < ddlBodyRunesLen {
|
||||
next = ddlBodyRunes[idx+1]
|
||||
}
|
||||
|
||||
if sc := string(c); separatorRegexp.MatchString(sc) {
|
||||
if c == next {
|
||||
buf += sc // Skip escaped quote
|
||||
idx++
|
||||
} else if quote > 0 {
|
||||
quote = 0
|
||||
} else {
|
||||
quote = c
|
||||
}
|
||||
} else if quote == 0 {
|
||||
if c == '(' {
|
||||
bracketLevel++
|
||||
} else if c == ')' {
|
||||
bracketLevel--
|
||||
} else if bracketLevel == 0 {
|
||||
if c == ',' {
|
||||
result.fields = append(result.fields, strings.TrimSpace(buf))
|
||||
buf = ""
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if bracketLevel < 0 {
|
||||
return nil, errors.New("invalid DDL, unbalanced brackets")
|
||||
}
|
||||
|
||||
buf += string(c)
|
||||
}
|
||||
|
||||
if bracketLevel != 0 {
|
||||
return nil, errors.New("invalid DDL, unbalanced brackets")
|
||||
}
|
||||
|
||||
if buf != "" {
|
||||
result.fields = append(result.fields, strings.TrimSpace(buf))
|
||||
}
|
||||
|
||||
for _, f := range result.fields {
|
||||
fUpper := strings.ToUpper(f)
|
||||
if strings.HasPrefix(fUpper, "CHECK") ||
|
||||
strings.HasPrefix(fUpper, "CONSTRAINT") {
|
||||
continue
|
||||
}
|
||||
|
||||
if strings.HasPrefix(fUpper, "PRIMARY KEY") {
|
||||
for _, name := range getAllColumns(f) {
|
||||
for idx, column := range result.columns {
|
||||
if column.NameValue.String == name {
|
||||
column.PrimaryKeyValue = sql.NullBool{Bool: true, Valid: true}
|
||||
result.columns[idx] = column
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
} else if matches := columnRegexp.FindStringSubmatch(f); len(matches) > 0 {
|
||||
columnType := migrator.ColumnType{
|
||||
NameValue: sql.NullString{String: matches[1], Valid: true},
|
||||
DataTypeValue: sql.NullString{String: matches[2], Valid: true},
|
||||
ColumnTypeValue: sql.NullString{String: matches[2], Valid: true},
|
||||
PrimaryKeyValue: sql.NullBool{Valid: true},
|
||||
UniqueValue: sql.NullBool{Valid: true},
|
||||
NullableValue: sql.NullBool{Valid: true},
|
||||
DefaultValueValue: sql.NullString{Valid: false},
|
||||
}
|
||||
|
||||
matchUpper := strings.ToUpper(matches[3])
|
||||
if strings.Contains(matchUpper, " NOT NULL") {
|
||||
columnType.NullableValue = sql.NullBool{Bool: false, Valid: true}
|
||||
} else if strings.Contains(matchUpper, " NULL") {
|
||||
columnType.NullableValue = sql.NullBool{Bool: true, Valid: true}
|
||||
}
|
||||
if strings.Contains(matchUpper, " UNIQUE") {
|
||||
columnType.UniqueValue = sql.NullBool{Bool: true, Valid: true}
|
||||
}
|
||||
if strings.Contains(matchUpper, " PRIMARY") {
|
||||
columnType.PrimaryKeyValue = sql.NullBool{Bool: true, Valid: true}
|
||||
}
|
||||
if defaultMatches := defaultValueRegexp.FindStringSubmatch(matches[3]); len(defaultMatches) > 1 {
|
||||
if strings.ToLower(defaultMatches[1]) != "null" {
|
||||
columnType.DefaultValueValue = sql.NullString{String: strings.Trim(defaultMatches[1], `"`), Valid: true}
|
||||
}
|
||||
}
|
||||
|
||||
// data type length
|
||||
matches := regRealDataType.FindAllStringSubmatch(columnType.DataTypeValue.String, -1)
|
||||
if len(matches) == 1 && len(matches[0]) == 2 {
|
||||
size, _ := strconv.Atoi(matches[0][1])
|
||||
columnType.LengthValue = sql.NullInt64{Valid: true, Int64: int64(size)}
|
||||
columnType.DataTypeValue.String = strings.TrimSuffix(columnType.DataTypeValue.String, matches[0][0])
|
||||
}
|
||||
|
||||
result.columns = append(result.columns, columnType)
|
||||
}
|
||||
}
|
||||
} else if matches := indexRegexp.FindStringSubmatch(str); len(matches) > 0 {
|
||||
for _, column := range getAllColumns(matches[1]) {
|
||||
for idx, c := range result.columns {
|
||||
if c.NameValue.String == column {
|
||||
c.UniqueValue = sql.NullBool{Bool: true, Valid: true}
|
||||
result.columns[idx] = c
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
return nil, errors.New("invalid DDL")
|
||||
}
|
||||
}
|
||||
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
func (d *ddl) compile() string {
|
||||
if len(d.fields) == 0 {
|
||||
return d.head
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%s (%s)", d.head, strings.Join(d.fields, ","))
|
||||
}
|
||||
|
||||
func (d *ddl) addConstraint(name string, sql string) {
|
||||
reg := regexp.MustCompile("^CONSTRAINT [\"`]?" + regexp.QuoteMeta(name) + "[\"` ]")
|
||||
|
||||
for i := 0; i < len(d.fields); i++ {
|
||||
if reg.MatchString(d.fields[i]) {
|
||||
d.fields[i] = sql
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
d.fields = append(d.fields, sql)
|
||||
}
|
||||
|
||||
func (d *ddl) removeConstraint(name string) bool {
|
||||
reg := regexp.MustCompile("^CONSTRAINT [\"`]?" + regexp.QuoteMeta(name) + "[\"` ]")
|
||||
|
||||
for i := 0; i < len(d.fields); i++ {
|
||||
if reg.MatchString(d.fields[i]) {
|
||||
d.fields = append(d.fields[:i], d.fields[i+1:]...)
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (d *ddl) getColumns() []string {
|
||||
res := []string{}
|
||||
|
||||
for _, f := range d.fields {
|
||||
fUpper := strings.ToUpper(f)
|
||||
if strings.HasPrefix(fUpper, "PRIMARY KEY") ||
|
||||
strings.HasPrefix(fUpper, "CHECK") ||
|
||||
strings.HasPrefix(fUpper, "CONSTRAINT") ||
|
||||
strings.Contains(fUpper, "GENERATED ALWAYS AS") {
|
||||
continue
|
||||
}
|
||||
|
||||
reg := regexp.MustCompile("^[\"`']?([\\w\\d]+)[\"`']?")
|
||||
match := reg.FindStringSubmatch(f)
|
||||
|
||||
if match != nil {
|
||||
res = append(res, "`"+match[1]+"`")
|
||||
}
|
||||
}
|
||||
return res
|
||||
}
|
||||
334
gormlite/ddlmod_test.go
Normal file
334
gormlite/ddlmod_test.go
Normal file
@@ -0,0 +1,334 @@
|
||||
package gormlite
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"testing"
|
||||
|
||||
"gorm.io/gorm/migrator"
|
||||
"gorm.io/gorm/utils/tests"
|
||||
)
|
||||
|
||||
func TestParseDDL(t *testing.T) {
|
||||
params := []struct {
|
||||
name string
|
||||
sql []string
|
||||
nFields int
|
||||
columns []migrator.ColumnType
|
||||
}{
|
||||
{"with_fk", []string{
|
||||
"CREATE TABLE `notes` (`id` integer NOT NULL,`text` varchar(500) DEFAULT \"hello\",`age` integer DEFAULT 18,`user_id` integer,PRIMARY KEY (`id`),CONSTRAINT `fk_users_notes` FOREIGN KEY (`user_id`) REFERENCES `users`(`id`))",
|
||||
"CREATE UNIQUE INDEX `idx_profiles_refer` ON `profiles`(`text`)",
|
||||
}, 6, []migrator.ColumnType{
|
||||
{NameValue: sql.NullString{String: "id", Valid: true}, DataTypeValue: sql.NullString{String: "integer", Valid: true}, ColumnTypeValue: sql.NullString{String: "integer", Valid: true}, PrimaryKeyValue: sql.NullBool{Bool: true, Valid: true}, NullableValue: sql.NullBool{Valid: true}, UniqueValue: sql.NullBool{Valid: true}, DefaultValueValue: sql.NullString{Valid: false}},
|
||||
{NameValue: sql.NullString{String: "text", Valid: true}, DataTypeValue: sql.NullString{String: "varchar", Valid: true}, LengthValue: sql.NullInt64{Int64: 500, Valid: true}, ColumnTypeValue: sql.NullString{String: "varchar(500)", Valid: true}, DefaultValueValue: sql.NullString{String: "hello", Valid: true}, NullableValue: sql.NullBool{Valid: true}, UniqueValue: sql.NullBool{Bool: true, Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}},
|
||||
{NameValue: sql.NullString{String: "age", Valid: true}, DataTypeValue: sql.NullString{String: "integer", Valid: true}, ColumnTypeValue: sql.NullString{String: "integer", Valid: true}, DefaultValueValue: sql.NullString{String: "18", Valid: true}, NullableValue: sql.NullBool{Valid: true}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}},
|
||||
{NameValue: sql.NullString{String: "user_id", Valid: true}, DataTypeValue: sql.NullString{String: "integer", Valid: true}, ColumnTypeValue: sql.NullString{String: "integer", Valid: true}, DefaultValueValue: sql.NullString{Valid: false}, NullableValue: sql.NullBool{Valid: true}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}},
|
||||
},
|
||||
},
|
||||
{"with_check", []string{"CREATE TABLE Persons (ID int NOT NULL,LastName varchar(255) NOT NULL,FirstName varchar(255),Age int,CHECK (Age>=18),CHECK (FirstName<>'John'))"}, 6, []migrator.ColumnType{
|
||||
{NameValue: sql.NullString{String: "ID", Valid: true}, DataTypeValue: sql.NullString{String: "int", Valid: true}, ColumnTypeValue: sql.NullString{String: "int", Valid: true}, NullableValue: sql.NullBool{Valid: true}, DefaultValueValue: sql.NullString{Valid: false}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}},
|
||||
{NameValue: sql.NullString{String: "LastName", Valid: true}, DataTypeValue: sql.NullString{String: "varchar", Valid: true}, LengthValue: sql.NullInt64{Int64: 255, Valid: true}, ColumnTypeValue: sql.NullString{String: "varchar(255)", Valid: true}, NullableValue: sql.NullBool{Bool: false, Valid: true}, DefaultValueValue: sql.NullString{Valid: false}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}},
|
||||
{NameValue: sql.NullString{String: "FirstName", Valid: true}, DataTypeValue: sql.NullString{String: "varchar", Valid: true}, LengthValue: sql.NullInt64{Int64: 255, Valid: true}, ColumnTypeValue: sql.NullString{String: "varchar(255)", Valid: true}, DefaultValueValue: sql.NullString{Valid: false}, NullableValue: sql.NullBool{Valid: true}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}},
|
||||
{NameValue: sql.NullString{String: "Age", Valid: true}, DataTypeValue: sql.NullString{String: "int", Valid: true}, ColumnTypeValue: sql.NullString{String: "int", Valid: true}, DefaultValueValue: sql.NullString{Valid: false}, NullableValue: sql.NullBool{Valid: true}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}},
|
||||
}},
|
||||
{"lowercase", []string{"create table test (ID int NOT NULL)"}, 1, []migrator.ColumnType{
|
||||
{NameValue: sql.NullString{String: "ID", Valid: true}, DataTypeValue: sql.NullString{String: "int", Valid: true}, ColumnTypeValue: sql.NullString{String: "int", Valid: true}, NullableValue: sql.NullBool{Bool: false, Valid: true}, DefaultValueValue: sql.NullString{Valid: false}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}},
|
||||
},
|
||||
},
|
||||
{"no brackets", []string{"create table test"}, 0, nil},
|
||||
{"with_special_characters", []string{
|
||||
"CREATE TABLE `test` (`text` varchar(10) DEFAULT \"测试, \")",
|
||||
}, 1, []migrator.ColumnType{
|
||||
{NameValue: sql.NullString{String: "text", Valid: true}, DataTypeValue: sql.NullString{String: "varchar", Valid: true}, LengthValue: sql.NullInt64{Int64: 10, Valid: true}, ColumnTypeValue: sql.NullString{String: "varchar(10)", Valid: true}, DefaultValueValue: sql.NullString{String: "测试, ", Valid: true}, NullableValue: sql.NullBool{Valid: true}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}},
|
||||
},
|
||||
},
|
||||
{
|
||||
"table_name_with_dash",
|
||||
[]string{
|
||||
"CREATE TABLE `test-a` (`id` int NOT NULL)",
|
||||
"CREATE UNIQUE INDEX `idx_test-a_id` ON `test-a`(`id`)",
|
||||
},
|
||||
1,
|
||||
[]migrator.ColumnType{
|
||||
{
|
||||
NameValue: sql.NullString{String: "id", Valid: true},
|
||||
DataTypeValue: sql.NullString{String: "int", Valid: true},
|
||||
ColumnTypeValue: sql.NullString{String: "int", Valid: true},
|
||||
NullableValue: sql.NullBool{Bool: false, Valid: true},
|
||||
DefaultValueValue: sql.NullString{Valid: false},
|
||||
UniqueValue: sql.NullBool{Bool: true, Valid: true},
|
||||
PrimaryKeyValue: sql.NullBool{Valid: true},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"unique index",
|
||||
[]string{
|
||||
"CREATE TABLE `test-b` (`field` integer NOT NULL)",
|
||||
"CREATE UNIQUE INDEX `idx_uq` ON `test-b`(`field`) WHERE field = 0",
|
||||
},
|
||||
1,
|
||||
[]migrator.ColumnType{
|
||||
{
|
||||
NameValue: sql.NullString{String: "field", Valid: true},
|
||||
DataTypeValue: sql.NullString{String: "integer", Valid: true},
|
||||
ColumnTypeValue: sql.NullString{String: "integer", Valid: true},
|
||||
PrimaryKeyValue: sql.NullBool{Bool: false, Valid: true},
|
||||
UniqueValue: sql.NullBool{Bool: true, Valid: true},
|
||||
NullableValue: sql.NullBool{Bool: false, Valid: true},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, p := range params {
|
||||
t.Run(p.name, func(t *testing.T) {
|
||||
ddl, err := parseDDL(p.sql...)
|
||||
|
||||
if err != nil {
|
||||
panic(err.Error())
|
||||
}
|
||||
|
||||
tests.AssertEqual(t, p.sql[0], ddl.compile())
|
||||
if len(ddl.fields) != p.nFields {
|
||||
t.Fatalf("fields length doesn't match: expect: %v, got %v", p.nFields, len(ddl.fields))
|
||||
}
|
||||
tests.AssertEqual(t, ddl.columns, p.columns)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseDDL_Whitespaces(t *testing.T) {
|
||||
testColumns := []migrator.ColumnType{
|
||||
{
|
||||
NameValue: sql.NullString{String: "id", Valid: true},
|
||||
DataTypeValue: sql.NullString{String: "integer", Valid: true},
|
||||
ColumnTypeValue: sql.NullString{String: "integer", Valid: true},
|
||||
NullableValue: sql.NullBool{Bool: false, Valid: true},
|
||||
DefaultValueValue: sql.NullString{Valid: false},
|
||||
UniqueValue: sql.NullBool{Bool: true, Valid: true},
|
||||
PrimaryKeyValue: sql.NullBool{Bool: true, Valid: true},
|
||||
},
|
||||
{
|
||||
NameValue: sql.NullString{String: "dark_mode", Valid: true},
|
||||
DataTypeValue: sql.NullString{String: "numeric", Valid: true},
|
||||
ColumnTypeValue: sql.NullString{String: "numeric", Valid: true},
|
||||
NullableValue: sql.NullBool{Valid: true},
|
||||
DefaultValueValue: sql.NullString{String: "true", Valid: true},
|
||||
UniqueValue: sql.NullBool{Bool: false, Valid: true},
|
||||
PrimaryKeyValue: sql.NullBool{Bool: false, Valid: true},
|
||||
},
|
||||
}
|
||||
|
||||
params := []struct {
|
||||
name string
|
||||
sql []string
|
||||
nFields int
|
||||
columns []migrator.ColumnType
|
||||
}{
|
||||
{
|
||||
"with_newline",
|
||||
[]string{"CREATE TABLE `users`\n(\nid integer primary key unique,\ndark_mode numeric DEFAULT true)"},
|
||||
2,
|
||||
testColumns,
|
||||
},
|
||||
{
|
||||
"with_newline_2",
|
||||
[]string{"CREATE TABLE `users` (\n\nid integer primary key unique,\ndark_mode numeric DEFAULT true)"},
|
||||
2,
|
||||
testColumns,
|
||||
},
|
||||
{
|
||||
"with_missing_space",
|
||||
[]string{"CREATE TABLE `users`(id integer primary key unique, dark_mode numeric DEFAULT true)"},
|
||||
2,
|
||||
testColumns,
|
||||
},
|
||||
{
|
||||
"with_many_spaces",
|
||||
[]string{"CREATE TABLE `users` (id integer primary key unique, dark_mode numeric DEFAULT true)"},
|
||||
2,
|
||||
testColumns,
|
||||
},
|
||||
}
|
||||
for _, p := range params {
|
||||
t.Run(p.name, func(t *testing.T) {
|
||||
ddl, err := parseDDL(p.sql...)
|
||||
|
||||
if err != nil {
|
||||
panic(err.Error())
|
||||
}
|
||||
|
||||
if len(ddl.fields) != p.nFields {
|
||||
t.Fatalf("fields length doesn't match: expect: %v, got %v", p.nFields, len(ddl.fields))
|
||||
}
|
||||
tests.AssertEqual(t, ddl.columns, p.columns)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseDDL_error(t *testing.T) {
|
||||
params := []struct {
|
||||
name string
|
||||
sql string
|
||||
}{
|
||||
{"invalid_cmd", "CREATE TABLE"},
|
||||
{"unbalanced_brackets", "CREATE TABLE test (ID int NOT NULL,Name varchar(255)"},
|
||||
{"unbalanced_brackets2", "CREATE TABLE test (ID int NOT NULL,Name varchar(255)))"},
|
||||
}
|
||||
|
||||
for _, p := range params {
|
||||
t.Run(p.name, func(t *testing.T) {
|
||||
_, err := parseDDL(p.sql)
|
||||
if err == nil {
|
||||
t.Fail()
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAddConstraint(t *testing.T) {
|
||||
params := []struct {
|
||||
name string
|
||||
fields []string
|
||||
cName string
|
||||
sql string
|
||||
expect []string
|
||||
}{
|
||||
{
|
||||
name: "add_new",
|
||||
fields: []string{"`id` integer NOT NULL"},
|
||||
cName: "fk_users_notes",
|
||||
sql: "CONSTRAINT `fk_users_notes` FOREIGN KEY (`user_id`) REFERENCES `users`(`id`))",
|
||||
expect: []string{"`id` integer NOT NULL", "CONSTRAINT `fk_users_notes` FOREIGN KEY (`user_id`) REFERENCES `users`(`id`))"},
|
||||
},
|
||||
{
|
||||
name: "update",
|
||||
fields: []string{"`id` integer NOT NULL", "CONSTRAINT `fk_users_notes` FOREIGN KEY (`user_id`) REFERENCES `users`(`id`))"},
|
||||
cName: "fk_users_notes",
|
||||
sql: "CONSTRAINT `fk_users_notes` FOREIGN KEY (`user_id`) REFERENCES `users`(`id`)) ON UPDATE CASCADE ON DELETE CASCADE",
|
||||
expect: []string{"`id` integer NOT NULL", "CONSTRAINT `fk_users_notes` FOREIGN KEY (`user_id`) REFERENCES `users`(`id`)) ON UPDATE CASCADE ON DELETE CASCADE"},
|
||||
},
|
||||
{
|
||||
name: "add_check",
|
||||
fields: []string{"`id` integer NOT NULL"},
|
||||
cName: "name_checker",
|
||||
sql: "CONSTRAINT `name_checker` CHECK (`name` <> 'jinzhu')",
|
||||
expect: []string{"`id` integer NOT NULL", "CONSTRAINT `name_checker` CHECK (`name` <> 'jinzhu')"},
|
||||
},
|
||||
{
|
||||
name: "update_check",
|
||||
fields: []string{"`id` integer NOT NULL", "CONSTRAINT `name_checker` CHECK (`name` <> 'thetadev')"},
|
||||
cName: "name_checker",
|
||||
sql: "CONSTRAINT `name_checker` CHECK (`name` <> 'jinzhu')",
|
||||
expect: []string{"`id` integer NOT NULL", "CONSTRAINT `name_checker` CHECK (`name` <> 'jinzhu')"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, p := range params {
|
||||
t.Run(p.name, func(t *testing.T) {
|
||||
testDDL := ddl{fields: p.fields}
|
||||
|
||||
testDDL.addConstraint(p.cName, p.sql)
|
||||
tests.AssertEqual(t, p.expect, testDDL.fields)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRemoveConstraint(t *testing.T) {
|
||||
params := []struct {
|
||||
name string
|
||||
fields []string
|
||||
cName string
|
||||
success bool
|
||||
expect []string
|
||||
}{
|
||||
{
|
||||
name: "fk",
|
||||
fields: []string{"`id` integer NOT NULL", "CONSTRAINT `fk_users_notes` FOREIGN KEY (`user_id`) REFERENCES `users`(`id`))"},
|
||||
cName: "fk_users_notes",
|
||||
success: true,
|
||||
expect: []string{"`id` integer NOT NULL"},
|
||||
},
|
||||
{
|
||||
name: "check",
|
||||
fields: []string{"CONSTRAINT `name_checker` CHECK (`name` <> 'thetadev')", "`id` integer NOT NULL"},
|
||||
cName: "name_checker",
|
||||
success: true,
|
||||
expect: []string{"`id` integer NOT NULL"},
|
||||
},
|
||||
{
|
||||
name: "none",
|
||||
fields: []string{"CONSTRAINT `name_checker` CHECK (`name` <> 'thetadev')", "`id` integer NOT NULL"},
|
||||
cName: "nothing",
|
||||
success: false,
|
||||
expect: []string{"CONSTRAINT `name_checker` CHECK (`name` <> 'thetadev')", "`id` integer NOT NULL"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, p := range params {
|
||||
t.Run(p.name, func(t *testing.T) {
|
||||
testDDL := ddl{fields: p.fields}
|
||||
|
||||
success := testDDL.removeConstraint(p.cName)
|
||||
|
||||
tests.AssertEqual(t, p.success, success)
|
||||
tests.AssertEqual(t, p.expect, testDDL.fields)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetColumns(t *testing.T) {
|
||||
params := []struct {
|
||||
name string
|
||||
ddl string
|
||||
columns []string
|
||||
}{
|
||||
{
|
||||
name: "with_fk",
|
||||
ddl: "CREATE TABLE `notes` (`id` integer NOT NULL,`text` varchar(500),`user_id` integer,PRIMARY KEY (`id`),CONSTRAINT `fk_users_notes` FOREIGN KEY (`user_id`) REFERENCES `users`(`id`))",
|
||||
columns: []string{"`id`", "`text`", "`user_id`"},
|
||||
},
|
||||
{
|
||||
name: "with_check",
|
||||
ddl: "CREATE TABLE Persons (ID int NOT NULL,LastName varchar(255) NOT NULL,FirstName varchar(255),Age int,CHECK (Age>=18),CHECK (FirstName!='John'))",
|
||||
columns: []string{"`ID`", "`LastName`", "`FirstName`", "`Age`"},
|
||||
},
|
||||
{
|
||||
name: "with_escaped_quote",
|
||||
ddl: "CREATE TABLE Persons (ID int NOT NULL,LastName varchar(255) NOT NULL DEFAULT \"\",FirstName varchar(255))",
|
||||
columns: []string{"`ID`", "`LastName`", "`FirstName`"},
|
||||
},
|
||||
{
|
||||
name: "with_generated_column",
|
||||
ddl: "CREATE TABLE Persons (ID int NOT NULL,LastName varchar(255) NOT NULL,FirstName varchar(255),FullName varchar(255) GENERATED ALWAYS AS (FirstName || ' ' || LastName))",
|
||||
columns: []string{"`ID`", "`LastName`", "`FirstName`"},
|
||||
},
|
||||
{
|
||||
name: "with_new_line",
|
||||
ddl: `CREATE TABLE "tb_sys_role_menu__temp" (
|
||||
"id" integer PRIMARY KEY AUTOINCREMENT,
|
||||
"created_at" datetime NOT NULL,
|
||||
"updated_at" datetime NOT NULL,
|
||||
"created_by" integer NOT NULL DEFAULT 0,
|
||||
"updated_by" integer NOT NULL DEFAULT 0,
|
||||
"role_id" integer NOT NULL,
|
||||
"menu_id" bigint NOT NULL
|
||||
)`,
|
||||
columns: []string{"`id`", "`created_at`", "`updated_at`", "`created_by`", "`updated_by`", "`role_id`", "`menu_id`"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, p := range params {
|
||||
t.Run(p.name, func(t *testing.T) {
|
||||
testDDL, err := parseDDL(p.ddl)
|
||||
if err != nil {
|
||||
panic(err.Error())
|
||||
}
|
||||
|
||||
cols := testDDL.getColumns()
|
||||
|
||||
tests.AssertEqual(t, p.columns, cols)
|
||||
})
|
||||
}
|
||||
}
|
||||
15
gormlite/error_translator.go
Normal file
15
gormlite/error_translator.go
Normal file
@@ -0,0 +1,15 @@
|
||||
package gormlite
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
||||
"github.com/ncruces/go-sqlite3"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
func (dialector Dialector) Translate(err error) error {
|
||||
if errors.Is(err, sqlite3.CONSTRAINT_UNIQUE) {
|
||||
return gorm.ErrDuplicatedKey
|
||||
}
|
||||
return err
|
||||
}
|
||||
7
gormlite/errors.go
Normal file
7
gormlite/errors.go
Normal file
@@ -0,0 +1,7 @@
|
||||
package gormlite
|
||||
|
||||
import "errors"
|
||||
|
||||
var (
|
||||
ErrConstraintsNotImplemented = errors.New("constraints not implemented on sqlite, consider using DisableForeignKeyConstraintWhenMigrating, more details https://github.com/go-gorm/gorm/wiki/GORM-V2-Release-Note-Draft#all-new-migrator")
|
||||
)
|
||||
428
gormlite/migrator.go
Normal file
428
gormlite/migrator.go
Normal file
@@ -0,0 +1,428 @@
|
||||
package gormlite
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
"gorm.io/gorm/migrator"
|
||||
"gorm.io/gorm/schema"
|
||||
)
|
||||
|
||||
type Migrator struct {
|
||||
migrator.Migrator
|
||||
}
|
||||
|
||||
func (m *Migrator) RunWithoutForeignKey(fc func() error) error {
|
||||
var enabled int
|
||||
m.DB.Raw("PRAGMA foreign_keys").Scan(&enabled)
|
||||
if enabled == 1 {
|
||||
m.DB.Exec("PRAGMA foreign_keys = OFF")
|
||||
defer m.DB.Exec("PRAGMA foreign_keys = ON")
|
||||
}
|
||||
|
||||
return fc()
|
||||
}
|
||||
|
||||
func (m Migrator) HasTable(value interface{}) bool {
|
||||
var count int
|
||||
m.Migrator.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||
return m.DB.Raw("SELECT count(*) FROM sqlite_master WHERE type='table' AND name=?", stmt.Table).Row().Scan(&count)
|
||||
})
|
||||
return count > 0
|
||||
}
|
||||
|
||||
func (m Migrator) DropTable(values ...interface{}) error {
|
||||
return m.RunWithoutForeignKey(func() error {
|
||||
values = m.ReorderModels(values, false)
|
||||
tx := m.DB.Session(&gorm.Session{})
|
||||
|
||||
for i := len(values) - 1; i >= 0; i-- {
|
||||
if err := m.RunWithValue(values[i], func(stmt *gorm.Statement) error {
|
||||
return tx.Exec("DROP TABLE IF EXISTS ?", clause.Table{Name: stmt.Table}).Error
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func (m Migrator) GetTables() (tableList []string, err error) {
|
||||
return tableList, m.DB.Raw("SELECT name FROM sqlite_master where type=?", "table").Scan(&tableList).Error
|
||||
}
|
||||
|
||||
func (m Migrator) HasColumn(value interface{}, name string) bool {
|
||||
var count int
|
||||
m.Migrator.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||
if stmt.Schema != nil {
|
||||
if field := stmt.Schema.LookUpField(name); field != nil {
|
||||
name = field.DBName
|
||||
}
|
||||
}
|
||||
|
||||
if name != "" {
|
||||
m.DB.Raw(
|
||||
"SELECT count(*) FROM sqlite_master WHERE type = ? AND tbl_name = ? AND (sql LIKE ? OR sql LIKE ? OR sql LIKE ? OR sql LIKE ? OR sql LIKE ?)",
|
||||
"table", stmt.Table, `%"`+name+`" %`, `%`+name+` %`, "%`"+name+"`%", "%["+name+"]%", "%\t"+name+"\t%",
|
||||
).Row().Scan(&count)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
return count > 0
|
||||
}
|
||||
|
||||
func (m Migrator) AlterColumn(value interface{}, name string) error {
|
||||
return m.RunWithoutForeignKey(func() error {
|
||||
return m.recreateTable(value, nil, func(rawDDL string, stmt *gorm.Statement) (sql string, sqlArgs []interface{}, err error) {
|
||||
if field := stmt.Schema.LookUpField(name); field != nil {
|
||||
// lookup field from table definition, ddl might looks like `'name' int,` or `'name' int)`
|
||||
reg, err := regexp.Compile("(`|'|\"| )" + field.DBName + "(`|'|\"| ) .*?(,|\\)\\s*$)")
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
|
||||
createSQL := reg.ReplaceAllString(rawDDL, fmt.Sprintf("`%v` ?$3", field.DBName))
|
||||
|
||||
if createSQL == rawDDL {
|
||||
return "", nil, fmt.Errorf("failed to look up field %v from DDL %v", field.DBName, rawDDL)
|
||||
}
|
||||
|
||||
return createSQL, []interface{}{m.FullDataTypeOf(field)}, nil
|
||||
}
|
||||
return "", nil, fmt.Errorf("failed to alter field with name %v", name)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// ColumnTypes return columnTypes []gorm.ColumnType and execErr error
|
||||
func (m Migrator) ColumnTypes(value interface{}) ([]gorm.ColumnType, error) {
|
||||
columnTypes := make([]gorm.ColumnType, 0)
|
||||
execErr := m.RunWithValue(value, func(stmt *gorm.Statement) (err error) {
|
||||
var (
|
||||
sqls []string
|
||||
sqlDDL *ddl
|
||||
)
|
||||
|
||||
if err := m.DB.Raw("SELECT sql FROM sqlite_master WHERE type IN ? AND tbl_name = ? AND sql IS NOT NULL order by type = ? desc", []string{"table", "index"}, stmt.Table, "table").Scan(&sqls).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if sqlDDL, err = parseDDL(sqls...); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
rows, err := m.DB.Session(&gorm.Session{}).Table(stmt.Table).Limit(1).Rows()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() {
|
||||
err = rows.Close()
|
||||
}()
|
||||
|
||||
var rawColumnTypes []*sql.ColumnType
|
||||
rawColumnTypes, err = rows.ColumnTypes()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, c := range rawColumnTypes {
|
||||
columnType := migrator.ColumnType{SQLColumnType: c}
|
||||
for _, column := range sqlDDL.columns {
|
||||
if column.NameValue.String == c.Name() {
|
||||
column.SQLColumnType = c
|
||||
columnType = column
|
||||
break
|
||||
}
|
||||
}
|
||||
columnTypes = append(columnTypes, columnType)
|
||||
}
|
||||
|
||||
return err
|
||||
})
|
||||
|
||||
return columnTypes, execErr
|
||||
}
|
||||
|
||||
func (m Migrator) DropColumn(value interface{}, name string) error {
|
||||
return m.recreateTable(value, nil, func(rawDDL string, stmt *gorm.Statement) (sql string, sqlArgs []interface{}, err error) {
|
||||
if field := stmt.Schema.LookUpField(name); field != nil {
|
||||
name = field.DBName
|
||||
}
|
||||
|
||||
reg, err := regexp.Compile("(`|'|\"| |\\[)" + name + "(`|'|\"| |\\]) .*?,")
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
|
||||
createSQL := reg.ReplaceAllString(rawDDL, "")
|
||||
|
||||
return createSQL, nil, nil
|
||||
})
|
||||
}
|
||||
|
||||
func (m Migrator) CreateConstraint(value interface{}, name string) error {
|
||||
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||
constraint, chk, table := m.GuessConstraintAndTable(stmt, name)
|
||||
|
||||
return m.recreateTable(value, &table,
|
||||
func(rawDDL string, stmt *gorm.Statement) (sql string, sqlArgs []interface{}, err error) {
|
||||
var (
|
||||
constraintName string
|
||||
constraintSql string
|
||||
constraintValues []interface{}
|
||||
)
|
||||
|
||||
if constraint != nil {
|
||||
constraintName = constraint.Name
|
||||
constraintSql, constraintValues = buildConstraint(constraint)
|
||||
} else if chk != nil {
|
||||
constraintName = chk.Name
|
||||
constraintSql = "CONSTRAINT ? CHECK (?)"
|
||||
constraintValues = []interface{}{clause.Column{Name: chk.Name}, clause.Expr{SQL: chk.Constraint}}
|
||||
} else {
|
||||
return "", nil, nil
|
||||
}
|
||||
|
||||
createDDL, err := parseDDL(rawDDL)
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
createDDL.addConstraint(constraintName, constraintSql)
|
||||
createSQL := createDDL.compile()
|
||||
|
||||
return createSQL, constraintValues, nil
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func (m Migrator) DropConstraint(value interface{}, name string) error {
|
||||
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||
constraint, chk, table := m.GuessConstraintAndTable(stmt, name)
|
||||
if constraint != nil {
|
||||
name = constraint.Name
|
||||
} else if chk != nil {
|
||||
name = chk.Name
|
||||
}
|
||||
|
||||
return m.recreateTable(value, &table,
|
||||
func(rawDDL string, stmt *gorm.Statement) (sql string, sqlArgs []interface{}, err error) {
|
||||
createDDL, err := parseDDL(rawDDL)
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
createDDL.removeConstraint(name)
|
||||
createSQL := createDDL.compile()
|
||||
|
||||
return createSQL, nil, nil
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func (m Migrator) HasConstraint(value interface{}, name string) bool {
|
||||
var count int64
|
||||
m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||
constraint, chk, table := m.GuessConstraintAndTable(stmt, name)
|
||||
if constraint != nil {
|
||||
name = constraint.Name
|
||||
} else if chk != nil {
|
||||
name = chk.Name
|
||||
}
|
||||
|
||||
m.DB.Raw(
|
||||
"SELECT count(*) FROM sqlite_master WHERE type = ? AND tbl_name = ? AND (sql LIKE ? OR sql LIKE ? OR sql LIKE ? OR sql LIKE ? OR sql LIKE ?)",
|
||||
"table", table, `%CONSTRAINT "`+name+`" %`, `%CONSTRAINT `+name+` %`, "%CONSTRAINT `"+name+"`%", "%CONSTRAINT ["+name+"]%", "%CONSTRAINT \t"+name+"\t%",
|
||||
).Row().Scan(&count)
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
return count > 0
|
||||
}
|
||||
|
||||
func (m Migrator) CurrentDatabase() (name string) {
|
||||
var null interface{}
|
||||
m.DB.Raw("PRAGMA database_list").Row().Scan(&null, &name, &null)
|
||||
return
|
||||
}
|
||||
|
||||
func (m Migrator) BuildIndexOptions(opts []schema.IndexOption, stmt *gorm.Statement) (results []interface{}) {
|
||||
for _, opt := range opts {
|
||||
str := stmt.Quote(opt.DBName)
|
||||
if opt.Expression != "" {
|
||||
str = opt.Expression
|
||||
}
|
||||
|
||||
if opt.Collate != "" {
|
||||
str += " COLLATE " + opt.Collate
|
||||
}
|
||||
|
||||
if opt.Sort != "" {
|
||||
str += " " + opt.Sort
|
||||
}
|
||||
results = append(results, clause.Expr{SQL: str})
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (m Migrator) CreateIndex(value interface{}, name string) error {
|
||||
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||
if stmt.Schema != nil {
|
||||
if idx := stmt.Schema.LookIndex(name); idx != nil {
|
||||
opts := m.BuildIndexOptions(idx.Fields, stmt)
|
||||
values := []interface{}{clause.Column{Name: idx.Name}, clause.Table{Name: stmt.Table}, opts}
|
||||
|
||||
createIndexSQL := "CREATE "
|
||||
if idx.Class != "" {
|
||||
createIndexSQL += idx.Class + " "
|
||||
}
|
||||
createIndexSQL += "INDEX ?"
|
||||
|
||||
if idx.Type != "" {
|
||||
createIndexSQL += " USING " + idx.Type
|
||||
}
|
||||
createIndexSQL += " ON ??"
|
||||
|
||||
if idx.Where != "" {
|
||||
createIndexSQL += " WHERE " + idx.Where
|
||||
}
|
||||
|
||||
return m.DB.Exec(createIndexSQL, values...).Error
|
||||
}
|
||||
}
|
||||
return fmt.Errorf("failed to create index with name %v", name)
|
||||
})
|
||||
}
|
||||
|
||||
func (m Migrator) HasIndex(value interface{}, name string) bool {
|
||||
var count int
|
||||
m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||
if stmt.Schema != nil {
|
||||
if idx := stmt.Schema.LookIndex(name); idx != nil {
|
||||
name = idx.Name
|
||||
}
|
||||
}
|
||||
|
||||
if name != "" {
|
||||
m.DB.Raw(
|
||||
"SELECT count(*) FROM sqlite_master WHERE type = ? AND tbl_name = ? AND name = ?", "index", stmt.Table, name,
|
||||
).Row().Scan(&count)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
return count > 0
|
||||
}
|
||||
|
||||
func (m Migrator) RenameIndex(value interface{}, oldName, newName string) error {
|
||||
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||
var sql string
|
||||
m.DB.Raw("SELECT sql FROM sqlite_master WHERE type = ? AND tbl_name = ? AND name = ?", "index", stmt.Table, oldName).Row().Scan(&sql)
|
||||
if sql != "" {
|
||||
return m.DB.Exec(strings.Replace(sql, oldName, newName, 1)).Error
|
||||
}
|
||||
return fmt.Errorf("failed to find index with name %v", oldName)
|
||||
})
|
||||
}
|
||||
|
||||
func (m Migrator) DropIndex(value interface{}, name string) error {
|
||||
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||
if stmt.Schema != nil {
|
||||
if idx := stmt.Schema.LookIndex(name); idx != nil {
|
||||
name = idx.Name
|
||||
}
|
||||
}
|
||||
|
||||
return m.DB.Exec("DROP INDEX ?", clause.Column{Name: name}).Error
|
||||
})
|
||||
}
|
||||
|
||||
func buildConstraint(constraint *schema.Constraint) (sql string, results []interface{}) {
|
||||
sql = "CONSTRAINT ? FOREIGN KEY ? REFERENCES ??"
|
||||
if constraint.OnDelete != "" {
|
||||
sql += " ON DELETE " + constraint.OnDelete
|
||||
}
|
||||
|
||||
if constraint.OnUpdate != "" {
|
||||
sql += " ON UPDATE " + constraint.OnUpdate
|
||||
}
|
||||
|
||||
var foreignKeys, references []interface{}
|
||||
for _, field := range constraint.ForeignKeys {
|
||||
foreignKeys = append(foreignKeys, clause.Column{Name: field.DBName})
|
||||
}
|
||||
|
||||
for _, field := range constraint.References {
|
||||
references = append(references, clause.Column{Name: field.DBName})
|
||||
}
|
||||
results = append(results, clause.Table{Name: constraint.Name}, foreignKeys, clause.Table{Name: constraint.ReferenceSchema.Table}, references)
|
||||
return
|
||||
}
|
||||
|
||||
func (m Migrator) getRawDDL(table string) (string, error) {
|
||||
var createSQL string
|
||||
m.DB.Raw("SELECT sql FROM sqlite_master WHERE type = ? AND tbl_name = ? AND name = ?", "table", table, table).Row().Scan(&createSQL)
|
||||
|
||||
if m.DB.Error != nil {
|
||||
return "", m.DB.Error
|
||||
}
|
||||
return createSQL, nil
|
||||
}
|
||||
|
||||
func (m Migrator) recreateTable(value interface{}, tablePtr *string,
|
||||
getCreateSQL func(rawDDL string, stmt *gorm.Statement) (sql string, sqlArgs []interface{}, err error)) error {
|
||||
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
|
||||
table := stmt.Table
|
||||
if tablePtr != nil {
|
||||
table = *tablePtr
|
||||
}
|
||||
|
||||
rawDDL, err := m.getRawDDL(table)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
newTableName := table + "__temp"
|
||||
|
||||
createSQL, sqlArgs, err := getCreateSQL(rawDDL, stmt)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if createSQL == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
tableReg, err := regexp.Compile("\\s*('|`|\")?\\b" + table + "\\b('|`|\")?\\s*")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
createSQL = tableReg.ReplaceAllString(createSQL, fmt.Sprintf(" `%v` ", newTableName))
|
||||
|
||||
createDDL, err := parseDDL(createSQL)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
columns := createDDL.getColumns()
|
||||
|
||||
return m.DB.Transaction(func(tx *gorm.DB) error {
|
||||
if err := tx.Exec(createSQL, sqlArgs...).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
queries := []string{
|
||||
fmt.Sprintf("INSERT INTO `%v`(%v) SELECT %v FROM `%v`", newTableName, strings.Join(columns, ","), strings.Join(columns, ","), table),
|
||||
fmt.Sprintf("DROP TABLE `%v`", table),
|
||||
fmt.Sprintf("ALTER TABLE `%v` RENAME TO `%v`", newTableName, table),
|
||||
}
|
||||
for _, query := range queries {
|
||||
if err := tx.Exec(query).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
})
|
||||
}
|
||||
219
gormlite/sqlite.go
Normal file
219
gormlite/sqlite.go
Normal file
@@ -0,0 +1,219 @@
|
||||
// Package gormlite provides a GORM driver for SQLite.
|
||||
package gormlite
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/callbacks"
|
||||
"gorm.io/gorm/clause"
|
||||
"gorm.io/gorm/logger"
|
||||
"gorm.io/gorm/migrator"
|
||||
"gorm.io/gorm/schema"
|
||||
|
||||
_ "github.com/ncruces/go-sqlite3/driver"
|
||||
)
|
||||
|
||||
type Dialector struct {
|
||||
DSN string
|
||||
Conn gorm.ConnPool
|
||||
}
|
||||
|
||||
func Open(dsn string) gorm.Dialector {
|
||||
return &Dialector{DSN: dsn}
|
||||
}
|
||||
|
||||
func (dialector Dialector) Name() string {
|
||||
return "sqlite"
|
||||
}
|
||||
|
||||
func (dialector Dialector) Initialize(db *gorm.DB) (err error) {
|
||||
if dialector.Conn != nil {
|
||||
db.ConnPool = dialector.Conn
|
||||
} else {
|
||||
conn, err := sql.Open("sqlite3", dialector.DSN)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
db.ConnPool = conn
|
||||
}
|
||||
|
||||
var version string
|
||||
if err := db.ConnPool.QueryRowContext(context.Background(), "select sqlite_version()").Scan(&version); err != nil {
|
||||
return err
|
||||
}
|
||||
// https://www.sqlite.org/releaselog/3_35_0.html
|
||||
if compareVersion(version, "3.35.0") >= 0 {
|
||||
callbacks.RegisterDefaultCallbacks(db, &callbacks.Config{
|
||||
CreateClauses: []string{"INSERT", "VALUES", "ON CONFLICT", "RETURNING"},
|
||||
UpdateClauses: []string{"UPDATE", "SET", "WHERE", "RETURNING"},
|
||||
DeleteClauses: []string{"DELETE", "FROM", "WHERE", "RETURNING"},
|
||||
LastInsertIDReversed: true,
|
||||
})
|
||||
} else {
|
||||
callbacks.RegisterDefaultCallbacks(db, &callbacks.Config{
|
||||
LastInsertIDReversed: true,
|
||||
})
|
||||
}
|
||||
|
||||
for k, v := range dialector.ClauseBuilders() {
|
||||
db.ClauseBuilders[k] = v
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (dialector Dialector) ClauseBuilders() map[string]clause.ClauseBuilder {
|
||||
return map[string]clause.ClauseBuilder{
|
||||
"INSERT": func(c clause.Clause, builder clause.Builder) {
|
||||
if insert, ok := c.Expression.(clause.Insert); ok {
|
||||
if stmt, ok := builder.(*gorm.Statement); ok {
|
||||
stmt.WriteString("INSERT ")
|
||||
if insert.Modifier != "" {
|
||||
stmt.WriteString(insert.Modifier)
|
||||
stmt.WriteByte(' ')
|
||||
}
|
||||
|
||||
stmt.WriteString("INTO ")
|
||||
if insert.Table.Name == "" {
|
||||
stmt.WriteQuoted(stmt.Table)
|
||||
} else {
|
||||
stmt.WriteQuoted(insert.Table)
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
c.Build(builder)
|
||||
},
|
||||
"LIMIT": func(c clause.Clause, builder clause.Builder) {
|
||||
if limit, ok := c.Expression.(clause.Limit); ok {
|
||||
var lmt = -1
|
||||
if limit.Limit != nil && *limit.Limit >= 0 {
|
||||
lmt = *limit.Limit
|
||||
}
|
||||
if lmt >= 0 || limit.Offset > 0 {
|
||||
builder.WriteString("LIMIT ")
|
||||
builder.WriteString(strconv.Itoa(lmt))
|
||||
}
|
||||
if limit.Offset > 0 {
|
||||
builder.WriteString(" OFFSET ")
|
||||
builder.WriteString(strconv.Itoa(limit.Offset))
|
||||
}
|
||||
}
|
||||
},
|
||||
"FOR": func(c clause.Clause, builder clause.Builder) {
|
||||
if _, ok := c.Expression.(clause.Locking); ok {
|
||||
// SQLite3 does not support row-level locking.
|
||||
return
|
||||
}
|
||||
c.Build(builder)
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (dialector Dialector) DefaultValueOf(field *schema.Field) clause.Expression {
|
||||
if field.AutoIncrement {
|
||||
return clause.Expr{SQL: "NULL"}
|
||||
}
|
||||
|
||||
// doesn't work, will raise error
|
||||
return clause.Expr{SQL: "DEFAULT"}
|
||||
}
|
||||
|
||||
func (dialector Dialector) Migrator(db *gorm.DB) gorm.Migrator {
|
||||
return Migrator{migrator.Migrator{Config: migrator.Config{
|
||||
DB: db,
|
||||
Dialector: dialector,
|
||||
CreateIndexAfterCreateTable: true,
|
||||
}}}
|
||||
}
|
||||
|
||||
func (dialector Dialector) BindVarTo(writer clause.Writer, stmt *gorm.Statement, v interface{}) {
|
||||
writer.WriteByte('?')
|
||||
}
|
||||
|
||||
func (dialector Dialector) QuoteTo(writer clause.Writer, str string) {
|
||||
writer.WriteByte('`')
|
||||
if strings.Contains(str, ".") {
|
||||
for idx, str := range strings.Split(str, ".") {
|
||||
if idx > 0 {
|
||||
writer.WriteString(".`")
|
||||
}
|
||||
writer.WriteString(str)
|
||||
writer.WriteByte('`')
|
||||
}
|
||||
} else {
|
||||
writer.WriteString(str)
|
||||
writer.WriteByte('`')
|
||||
}
|
||||
}
|
||||
|
||||
func (dialector Dialector) Explain(sql string, vars ...interface{}) string {
|
||||
return logger.ExplainSQL(sql, nil, `"`, vars...)
|
||||
}
|
||||
|
||||
func (dialector Dialector) DataTypeOf(field *schema.Field) string {
|
||||
switch field.DataType {
|
||||
case schema.Bool:
|
||||
return "numeric"
|
||||
case schema.Int, schema.Uint:
|
||||
if field.AutoIncrement && !field.PrimaryKey {
|
||||
// https://www.sqlite.org/autoinc.html
|
||||
return "integer PRIMARY KEY AUTOINCREMENT"
|
||||
} else {
|
||||
return "integer"
|
||||
}
|
||||
case schema.Float:
|
||||
return "real"
|
||||
case schema.String:
|
||||
return "text"
|
||||
case schema.Time:
|
||||
// Distinguish between schema.Time and tag time
|
||||
if val, ok := field.TagSettings["TYPE"]; ok {
|
||||
return val
|
||||
} else {
|
||||
return "datetime"
|
||||
}
|
||||
case schema.Bytes:
|
||||
return "blob"
|
||||
}
|
||||
|
||||
return string(field.DataType)
|
||||
}
|
||||
|
||||
func (dialectopr Dialector) SavePoint(tx *gorm.DB, name string) error {
|
||||
tx.Exec("SAVEPOINT " + name)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (dialectopr Dialector) RollbackTo(tx *gorm.DB, name string) error {
|
||||
tx.Exec("ROLLBACK TO SAVEPOINT " + name)
|
||||
return nil
|
||||
}
|
||||
|
||||
func compareVersion(version1, version2 string) int {
|
||||
n, m := len(version1), len(version2)
|
||||
i, j := 0, 0
|
||||
for i < n || j < m {
|
||||
x := 0
|
||||
for ; i < n && version1[i] != '.'; i++ {
|
||||
x = x*10 + int(version1[i]-'0')
|
||||
}
|
||||
i++
|
||||
y := 0
|
||||
for ; j < m && version2[j] != '.'; j++ {
|
||||
y = y*10 + int(version2[j]-'0')
|
||||
}
|
||||
j++
|
||||
if x > y {
|
||||
return 1
|
||||
}
|
||||
if x < y {
|
||||
return -1
|
||||
}
|
||||
}
|
||||
return 0
|
||||
}
|
||||
64
gormlite/sqlite_test.go
Normal file
64
gormlite/sqlite_test.go
Normal file
@@ -0,0 +1,64 @@
|
||||
package gormlite
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"gorm.io/gorm"
|
||||
|
||||
_ "github.com/ncruces/go-sqlite3/embed"
|
||||
)
|
||||
|
||||
func TestDialector(t *testing.T) {
|
||||
// This is the DSN of the in-memory SQLite database for these tests.
|
||||
const InMemoryDSN = "file:testdatabase?mode=memory&cache=shared"
|
||||
|
||||
rows := []struct {
|
||||
description string
|
||||
dialector *Dialector
|
||||
openSuccess bool
|
||||
query string
|
||||
querySuccess bool
|
||||
}{
|
||||
{
|
||||
description: "Default driver",
|
||||
dialector: &Dialector{
|
||||
DSN: InMemoryDSN,
|
||||
},
|
||||
openSuccess: true,
|
||||
query: "SELECT 1",
|
||||
querySuccess: true,
|
||||
},
|
||||
}
|
||||
for rowIndex, row := range rows {
|
||||
t.Run(fmt.Sprintf("%d/%s", rowIndex, row.description), func(t *testing.T) {
|
||||
db, err := gorm.Open(row.dialector, &gorm.Config{})
|
||||
if !row.openSuccess {
|
||||
if err == nil {
|
||||
t.Errorf("Expected Open to fail.")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("Expected Open to succeed; got error: %v", err)
|
||||
}
|
||||
if db == nil {
|
||||
t.Errorf("Expected db to be non-nil.")
|
||||
}
|
||||
if row.query != "" {
|
||||
err = db.Exec(row.query).Error
|
||||
if !row.querySuccess {
|
||||
if err == nil {
|
||||
t.Errorf("Expected query to fail.")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("Expected query to succeed; got error: %v", err)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user