From f07e82e3611d8f92f86158443c118da6d67861af Mon Sep 17 00:00:00 2001 From: Nuno Cruces Date: Tue, 6 Jun 2023 12:37:54 +0100 Subject: [PATCH] GORM driver. --- README.md | 2 + go.mod | 6 + go.sum | 6 + gormlite/LICENSE | 22 ++ gormlite/README.md | 24 ++ gormlite/ddlmod.go | 231 +++++++++++++++++++ gormlite/ddlmod_test.go | 334 +++++++++++++++++++++++++++ gormlite/error_translator.go | 15 ++ gormlite/errors.go | 7 + gormlite/migrator.go | 428 +++++++++++++++++++++++++++++++++++ gormlite/sqlite.go | 219 ++++++++++++++++++ gormlite/sqlite_test.go | 64 ++++++ 12 files changed, 1358 insertions(+) create mode 100644 gormlite/LICENSE create mode 100644 gormlite/README.md create mode 100644 gormlite/ddlmod.go create mode 100644 gormlite/ddlmod_test.go create mode 100644 gormlite/error_translator.go create mode 100644 gormlite/errors.go create mode 100644 gormlite/migrator.go create mode 100644 gormlite/sqlite.go create mode 100644 gormlite/sqlite_test.go diff --git a/README.md b/README.md index 64d36ac..8a06a28 100644 --- a/README.md +++ b/README.md @@ -17,6 +17,8 @@ provides a [`database/sql`](https://pkg.go.dev/database/sql) driver embeds a build of SQLite into your application. - Package [`github.com/ncruces/go-sqlite3/vfs`](https://pkg.go.dev/github.com/ncruces/go-sqlite3/vfs) wraps the [C SQLite VFS API](https://www.sqlite.org/vfs.html) and provides a pure Go implementation. +- Package [`github.com/ncruces/go-sqlite3/gormlite`](https://pkg.go.dev/github.com/ncruces/go-sqlite3/gormlite) +provides a [GORM](https://gorm.io) driver. ### Caveats diff --git a/go.mod b/go.mod index 0371334..50d12f7 100644 --- a/go.mod +++ b/go.mod @@ -8,6 +8,12 @@ require ( github.com/tetratelabs/wazero v1.2.0 golang.org/x/sync v0.2.0 golang.org/x/sys v0.8.0 + gorm.io/gorm v1.25.1 +) + +require ( + github.com/jinzhu/inflection v1.0.0 // indirect + github.com/jinzhu/now v1.1.5 // indirect ) retract v0.4.0 // tagged from the wrong branch diff --git a/go.sum b/go.sum index 5799fae..78018fa 100644 --- a/go.sum +++ b/go.sum @@ -1,3 +1,7 @@ +github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= +github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= +github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ= +github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= github.com/ncruces/julianday v0.1.5 h1:hDJ9ejiMp3DHsoZ5KW4c1lwfMjbARS7u/gbYcd0FBZk= github.com/ncruces/julianday v0.1.5/go.mod h1:Dusn2KvZrrovOMJuOt0TNXL6tB7U2E8kvza5fFc9G7g= github.com/psanford/httpreadat v0.1.0 h1:VleW1HS2zO7/4c7c7zNl33fO6oYACSagjJIyMIwZLUE= @@ -8,3 +12,5 @@ golang.org/x/sync v0.2.0 h1:PUR+T4wwASmuSTYdKjYHI5TD22Wy5ogLU5qZCOLxBrI= golang.org/x/sync v0.2.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.8.0 h1:EBmGv8NaZBZTWvrbjNoL6HVt+IVy3QDQpJs7VRIw3tU= golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +gorm.io/gorm v1.25.1 h1:nsSALe5Pr+cM3V1qwwQ7rOkw+6UeLrX5O4v3llhHa64= +gorm.io/gorm v1.25.1/go.mod h1:L4uxeKpfBml98NYqVqwAdmV1a2nBtAec/cf3fpucW/k= diff --git a/gormlite/LICENSE b/gormlite/LICENSE new file mode 100644 index 0000000..558c4ff --- /dev/null +++ b/gormlite/LICENSE @@ -0,0 +1,22 @@ +MIT License + +Copyright (c) 2023 Nuno Cruces +Copyright (c) 2023 Jinzhu + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/gormlite/README.md b/gormlite/README.md new file mode 100644 index 0000000..94c01b7 --- /dev/null +++ b/gormlite/README.md @@ -0,0 +1,24 @@ +# GORM SQLite Driver + +## Usage + +```go +import ( + _ "github.com/ncruces/go-sqlite3/embed" + "github.com/ncruces/go-sqlite3/gormlite" + "gorm.io/gorm" +) + +db, err := gorm.Open(gormlite.Open("gorm.db"), &gorm.Config{}) +``` + +Checkout [https://gorm.io](https://gorm.io) for details. + +### Foreign-key constraint activation + +Foreign-key constraint is disabled by default in SQLite. To activate it, use connection URL parameter: +```go +db, err := gorm.Open(gormlite.Open( + "file:gorm.db?_pragma=busy_timeout(10000)&_pragma=locking_mode(normal)&_pragma=foreign_keys(1)"), + &gorm.Config{}) +``` \ No newline at end of file diff --git a/gormlite/ddlmod.go b/gormlite/ddlmod.go new file mode 100644 index 0000000..3e35b45 --- /dev/null +++ b/gormlite/ddlmod.go @@ -0,0 +1,231 @@ +package gormlite + +import ( + "database/sql" + "errors" + "fmt" + "regexp" + "strconv" + "strings" + + "gorm.io/gorm/migrator" +) + +var ( + sqliteSeparator = "`|\"|'|\t" + indexRegexp = regexp.MustCompile(fmt.Sprintf(`(?is)CREATE(?: UNIQUE)? INDEX [%v]?[\w\d-]+[%v]? ON (.*)$`, sqliteSeparator, sqliteSeparator)) + tableRegexp = regexp.MustCompile(fmt.Sprintf(`(?is)(CREATE TABLE [%v]?[\w\d-]+[%v]?)(?:\s*\((.*)\))?`, sqliteSeparator, sqliteSeparator)) + separatorRegexp = regexp.MustCompile(fmt.Sprintf("[%v]", sqliteSeparator)) + columnsRegexp = regexp.MustCompile(fmt.Sprintf(`[(,][%v]?(\w+)[%v]?`, sqliteSeparator, sqliteSeparator)) + columnRegexp = regexp.MustCompile(fmt.Sprintf(`^[%v]?([\w\d]+)[%v]?\s+([\w\(\)\d]+)(.*)$`, sqliteSeparator, sqliteSeparator)) + defaultValueRegexp = regexp.MustCompile(`(?i) DEFAULT \(?(.+)?\)?( |COLLATE|GENERATED|$)`) + regRealDataType = regexp.MustCompile(`[^\d](\d+)[^\d]?`) +) + +func getAllColumns(s string) []string { + allMatches := columnsRegexp.FindAllStringSubmatch(s, -1) + columns := make([]string, 0, len(allMatches)) + for _, matches := range allMatches { + if len(matches) > 1 { + columns = append(columns, matches[1]) + } + } + return columns +} + +type ddl struct { + head string + fields []string + columns []migrator.ColumnType +} + +func parseDDL(strs ...string) (*ddl, error) { + var result ddl + for _, str := range strs { + if sections := tableRegexp.FindStringSubmatch(str); len(sections) > 0 { + var ( + ddlBody = sections[2] + ddlBodyRunes = []rune(ddlBody) + bracketLevel int + quote rune + buf string + ) + ddlBodyRunesLen := len(ddlBodyRunes) + + result.head = sections[1] + + for idx := 0; idx < ddlBodyRunesLen; idx++ { + var ( + next rune = 0 + c = ddlBodyRunes[idx] + ) + if idx+1 < ddlBodyRunesLen { + next = ddlBodyRunes[idx+1] + } + + if sc := string(c); separatorRegexp.MatchString(sc) { + if c == next { + buf += sc // Skip escaped quote + idx++ + } else if quote > 0 { + quote = 0 + } else { + quote = c + } + } else if quote == 0 { + if c == '(' { + bracketLevel++ + } else if c == ')' { + bracketLevel-- + } else if bracketLevel == 0 { + if c == ',' { + result.fields = append(result.fields, strings.TrimSpace(buf)) + buf = "" + continue + } + } + } + + if bracketLevel < 0 { + return nil, errors.New("invalid DDL, unbalanced brackets") + } + + buf += string(c) + } + + if bracketLevel != 0 { + return nil, errors.New("invalid DDL, unbalanced brackets") + } + + if buf != "" { + result.fields = append(result.fields, strings.TrimSpace(buf)) + } + + for _, f := range result.fields { + fUpper := strings.ToUpper(f) + if strings.HasPrefix(fUpper, "CHECK") || + strings.HasPrefix(fUpper, "CONSTRAINT") { + continue + } + + if strings.HasPrefix(fUpper, "PRIMARY KEY") { + for _, name := range getAllColumns(f) { + for idx, column := range result.columns { + if column.NameValue.String == name { + column.PrimaryKeyValue = sql.NullBool{Bool: true, Valid: true} + result.columns[idx] = column + break + } + } + } + } else if matches := columnRegexp.FindStringSubmatch(f); len(matches) > 0 { + columnType := migrator.ColumnType{ + NameValue: sql.NullString{String: matches[1], Valid: true}, + DataTypeValue: sql.NullString{String: matches[2], Valid: true}, + ColumnTypeValue: sql.NullString{String: matches[2], Valid: true}, + PrimaryKeyValue: sql.NullBool{Valid: true}, + UniqueValue: sql.NullBool{Valid: true}, + NullableValue: sql.NullBool{Valid: true}, + DefaultValueValue: sql.NullString{Valid: false}, + } + + matchUpper := strings.ToUpper(matches[3]) + if strings.Contains(matchUpper, " NOT NULL") { + columnType.NullableValue = sql.NullBool{Bool: false, Valid: true} + } else if strings.Contains(matchUpper, " NULL") { + columnType.NullableValue = sql.NullBool{Bool: true, Valid: true} + } + if strings.Contains(matchUpper, " UNIQUE") { + columnType.UniqueValue = sql.NullBool{Bool: true, Valid: true} + } + if strings.Contains(matchUpper, " PRIMARY") { + columnType.PrimaryKeyValue = sql.NullBool{Bool: true, Valid: true} + } + if defaultMatches := defaultValueRegexp.FindStringSubmatch(matches[3]); len(defaultMatches) > 1 { + if strings.ToLower(defaultMatches[1]) != "null" { + columnType.DefaultValueValue = sql.NullString{String: strings.Trim(defaultMatches[1], `"`), Valid: true} + } + } + + // data type length + matches := regRealDataType.FindAllStringSubmatch(columnType.DataTypeValue.String, -1) + if len(matches) == 1 && len(matches[0]) == 2 { + size, _ := strconv.Atoi(matches[0][1]) + columnType.LengthValue = sql.NullInt64{Valid: true, Int64: int64(size)} + columnType.DataTypeValue.String = strings.TrimSuffix(columnType.DataTypeValue.String, matches[0][0]) + } + + result.columns = append(result.columns, columnType) + } + } + } else if matches := indexRegexp.FindStringSubmatch(str); len(matches) > 0 { + for _, column := range getAllColumns(matches[1]) { + for idx, c := range result.columns { + if c.NameValue.String == column { + c.UniqueValue = sql.NullBool{Bool: true, Valid: true} + result.columns[idx] = c + } + } + } + } else { + return nil, errors.New("invalid DDL") + } + } + + return &result, nil +} + +func (d *ddl) compile() string { + if len(d.fields) == 0 { + return d.head + } + + return fmt.Sprintf("%s (%s)", d.head, strings.Join(d.fields, ",")) +} + +func (d *ddl) addConstraint(name string, sql string) { + reg := regexp.MustCompile("^CONSTRAINT [\"`]?" + regexp.QuoteMeta(name) + "[\"` ]") + + for i := 0; i < len(d.fields); i++ { + if reg.MatchString(d.fields[i]) { + d.fields[i] = sql + return + } + } + + d.fields = append(d.fields, sql) +} + +func (d *ddl) removeConstraint(name string) bool { + reg := regexp.MustCompile("^CONSTRAINT [\"`]?" + regexp.QuoteMeta(name) + "[\"` ]") + + for i := 0; i < len(d.fields); i++ { + if reg.MatchString(d.fields[i]) { + d.fields = append(d.fields[:i], d.fields[i+1:]...) + return true + } + } + return false +} + +func (d *ddl) getColumns() []string { + res := []string{} + + for _, f := range d.fields { + fUpper := strings.ToUpper(f) + if strings.HasPrefix(fUpper, "PRIMARY KEY") || + strings.HasPrefix(fUpper, "CHECK") || + strings.HasPrefix(fUpper, "CONSTRAINT") || + strings.Contains(fUpper, "GENERATED ALWAYS AS") { + continue + } + + reg := regexp.MustCompile("^[\"`']?([\\w\\d]+)[\"`']?") + match := reg.FindStringSubmatch(f) + + if match != nil { + res = append(res, "`"+match[1]+"`") + } + } + return res +} diff --git a/gormlite/ddlmod_test.go b/gormlite/ddlmod_test.go new file mode 100644 index 0000000..62b3763 --- /dev/null +++ b/gormlite/ddlmod_test.go @@ -0,0 +1,334 @@ +package gormlite + +import ( + "database/sql" + "testing" + + "gorm.io/gorm/migrator" + "gorm.io/gorm/utils/tests" +) + +func TestParseDDL(t *testing.T) { + params := []struct { + name string + sql []string + nFields int + columns []migrator.ColumnType + }{ + {"with_fk", []string{ + "CREATE TABLE `notes` (`id` integer NOT NULL,`text` varchar(500) DEFAULT \"hello\",`age` integer DEFAULT 18,`user_id` integer,PRIMARY KEY (`id`),CONSTRAINT `fk_users_notes` FOREIGN KEY (`user_id`) REFERENCES `users`(`id`))", + "CREATE UNIQUE INDEX `idx_profiles_refer` ON `profiles`(`text`)", + }, 6, []migrator.ColumnType{ + {NameValue: sql.NullString{String: "id", Valid: true}, DataTypeValue: sql.NullString{String: "integer", Valid: true}, ColumnTypeValue: sql.NullString{String: "integer", Valid: true}, PrimaryKeyValue: sql.NullBool{Bool: true, Valid: true}, NullableValue: sql.NullBool{Valid: true}, UniqueValue: sql.NullBool{Valid: true}, DefaultValueValue: sql.NullString{Valid: false}}, + {NameValue: sql.NullString{String: "text", Valid: true}, DataTypeValue: sql.NullString{String: "varchar", Valid: true}, LengthValue: sql.NullInt64{Int64: 500, Valid: true}, ColumnTypeValue: sql.NullString{String: "varchar(500)", Valid: true}, DefaultValueValue: sql.NullString{String: "hello", Valid: true}, NullableValue: sql.NullBool{Valid: true}, UniqueValue: sql.NullBool{Bool: true, Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}}, + {NameValue: sql.NullString{String: "age", Valid: true}, DataTypeValue: sql.NullString{String: "integer", Valid: true}, ColumnTypeValue: sql.NullString{String: "integer", Valid: true}, DefaultValueValue: sql.NullString{String: "18", Valid: true}, NullableValue: sql.NullBool{Valid: true}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}}, + {NameValue: sql.NullString{String: "user_id", Valid: true}, DataTypeValue: sql.NullString{String: "integer", Valid: true}, ColumnTypeValue: sql.NullString{String: "integer", Valid: true}, DefaultValueValue: sql.NullString{Valid: false}, NullableValue: sql.NullBool{Valid: true}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}}, + }, + }, + {"with_check", []string{"CREATE TABLE Persons (ID int NOT NULL,LastName varchar(255) NOT NULL,FirstName varchar(255),Age int,CHECK (Age>=18),CHECK (FirstName<>'John'))"}, 6, []migrator.ColumnType{ + {NameValue: sql.NullString{String: "ID", Valid: true}, DataTypeValue: sql.NullString{String: "int", Valid: true}, ColumnTypeValue: sql.NullString{String: "int", Valid: true}, NullableValue: sql.NullBool{Valid: true}, DefaultValueValue: sql.NullString{Valid: false}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}}, + {NameValue: sql.NullString{String: "LastName", Valid: true}, DataTypeValue: sql.NullString{String: "varchar", Valid: true}, LengthValue: sql.NullInt64{Int64: 255, Valid: true}, ColumnTypeValue: sql.NullString{String: "varchar(255)", Valid: true}, NullableValue: sql.NullBool{Bool: false, Valid: true}, DefaultValueValue: sql.NullString{Valid: false}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}}, + {NameValue: sql.NullString{String: "FirstName", Valid: true}, DataTypeValue: sql.NullString{String: "varchar", Valid: true}, LengthValue: sql.NullInt64{Int64: 255, Valid: true}, ColumnTypeValue: sql.NullString{String: "varchar(255)", Valid: true}, DefaultValueValue: sql.NullString{Valid: false}, NullableValue: sql.NullBool{Valid: true}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}}, + {NameValue: sql.NullString{String: "Age", Valid: true}, DataTypeValue: sql.NullString{String: "int", Valid: true}, ColumnTypeValue: sql.NullString{String: "int", Valid: true}, DefaultValueValue: sql.NullString{Valid: false}, NullableValue: sql.NullBool{Valid: true}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}}, + }}, + {"lowercase", []string{"create table test (ID int NOT NULL)"}, 1, []migrator.ColumnType{ + {NameValue: sql.NullString{String: "ID", Valid: true}, DataTypeValue: sql.NullString{String: "int", Valid: true}, ColumnTypeValue: sql.NullString{String: "int", Valid: true}, NullableValue: sql.NullBool{Bool: false, Valid: true}, DefaultValueValue: sql.NullString{Valid: false}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}}, + }, + }, + {"no brackets", []string{"create table test"}, 0, nil}, + {"with_special_characters", []string{ + "CREATE TABLE `test` (`text` varchar(10) DEFAULT \"测试, \")", + }, 1, []migrator.ColumnType{ + {NameValue: sql.NullString{String: "text", Valid: true}, DataTypeValue: sql.NullString{String: "varchar", Valid: true}, LengthValue: sql.NullInt64{Int64: 10, Valid: true}, ColumnTypeValue: sql.NullString{String: "varchar(10)", Valid: true}, DefaultValueValue: sql.NullString{String: "测试, ", Valid: true}, NullableValue: sql.NullBool{Valid: true}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}}, + }, + }, + { + "table_name_with_dash", + []string{ + "CREATE TABLE `test-a` (`id` int NOT NULL)", + "CREATE UNIQUE INDEX `idx_test-a_id` ON `test-a`(`id`)", + }, + 1, + []migrator.ColumnType{ + { + NameValue: sql.NullString{String: "id", Valid: true}, + DataTypeValue: sql.NullString{String: "int", Valid: true}, + ColumnTypeValue: sql.NullString{String: "int", Valid: true}, + NullableValue: sql.NullBool{Bool: false, Valid: true}, + DefaultValueValue: sql.NullString{Valid: false}, + UniqueValue: sql.NullBool{Bool: true, Valid: true}, + PrimaryKeyValue: sql.NullBool{Valid: true}, + }, + }, + }, + { + "unique index", + []string{ + "CREATE TABLE `test-b` (`field` integer NOT NULL)", + "CREATE UNIQUE INDEX `idx_uq` ON `test-b`(`field`) WHERE field = 0", + }, + 1, + []migrator.ColumnType{ + { + NameValue: sql.NullString{String: "field", Valid: true}, + DataTypeValue: sql.NullString{String: "integer", Valid: true}, + ColumnTypeValue: sql.NullString{String: "integer", Valid: true}, + PrimaryKeyValue: sql.NullBool{Bool: false, Valid: true}, + UniqueValue: sql.NullBool{Bool: true, Valid: true}, + NullableValue: sql.NullBool{Bool: false, Valid: true}, + }, + }, + }, + } + + for _, p := range params { + t.Run(p.name, func(t *testing.T) { + ddl, err := parseDDL(p.sql...) + + if err != nil { + panic(err.Error()) + } + + tests.AssertEqual(t, p.sql[0], ddl.compile()) + if len(ddl.fields) != p.nFields { + t.Fatalf("fields length doesn't match: expect: %v, got %v", p.nFields, len(ddl.fields)) + } + tests.AssertEqual(t, ddl.columns, p.columns) + }) + } +} + +func TestParseDDL_Whitespaces(t *testing.T) { + testColumns := []migrator.ColumnType{ + { + NameValue: sql.NullString{String: "id", Valid: true}, + DataTypeValue: sql.NullString{String: "integer", Valid: true}, + ColumnTypeValue: sql.NullString{String: "integer", Valid: true}, + NullableValue: sql.NullBool{Bool: false, Valid: true}, + DefaultValueValue: sql.NullString{Valid: false}, + UniqueValue: sql.NullBool{Bool: true, Valid: true}, + PrimaryKeyValue: sql.NullBool{Bool: true, Valid: true}, + }, + { + NameValue: sql.NullString{String: "dark_mode", Valid: true}, + DataTypeValue: sql.NullString{String: "numeric", Valid: true}, + ColumnTypeValue: sql.NullString{String: "numeric", Valid: true}, + NullableValue: sql.NullBool{Valid: true}, + DefaultValueValue: sql.NullString{String: "true", Valid: true}, + UniqueValue: sql.NullBool{Bool: false, Valid: true}, + PrimaryKeyValue: sql.NullBool{Bool: false, Valid: true}, + }, + } + + params := []struct { + name string + sql []string + nFields int + columns []migrator.ColumnType + }{ + { + "with_newline", + []string{"CREATE TABLE `users`\n(\nid integer primary key unique,\ndark_mode numeric DEFAULT true)"}, + 2, + testColumns, + }, + { + "with_newline_2", + []string{"CREATE TABLE `users` (\n\nid integer primary key unique,\ndark_mode numeric DEFAULT true)"}, + 2, + testColumns, + }, + { + "with_missing_space", + []string{"CREATE TABLE `users`(id integer primary key unique, dark_mode numeric DEFAULT true)"}, + 2, + testColumns, + }, + { + "with_many_spaces", + []string{"CREATE TABLE `users` (id integer primary key unique, dark_mode numeric DEFAULT true)"}, + 2, + testColumns, + }, + } + for _, p := range params { + t.Run(p.name, func(t *testing.T) { + ddl, err := parseDDL(p.sql...) + + if err != nil { + panic(err.Error()) + } + + if len(ddl.fields) != p.nFields { + t.Fatalf("fields length doesn't match: expect: %v, got %v", p.nFields, len(ddl.fields)) + } + tests.AssertEqual(t, ddl.columns, p.columns) + }) + } +} + +func TestParseDDL_error(t *testing.T) { + params := []struct { + name string + sql string + }{ + {"invalid_cmd", "CREATE TABLE"}, + {"unbalanced_brackets", "CREATE TABLE test (ID int NOT NULL,Name varchar(255)"}, + {"unbalanced_brackets2", "CREATE TABLE test (ID int NOT NULL,Name varchar(255)))"}, + } + + for _, p := range params { + t.Run(p.name, func(t *testing.T) { + _, err := parseDDL(p.sql) + if err == nil { + t.Fail() + } + }) + } +} + +func TestAddConstraint(t *testing.T) { + params := []struct { + name string + fields []string + cName string + sql string + expect []string + }{ + { + name: "add_new", + fields: []string{"`id` integer NOT NULL"}, + cName: "fk_users_notes", + sql: "CONSTRAINT `fk_users_notes` FOREIGN KEY (`user_id`) REFERENCES `users`(`id`))", + expect: []string{"`id` integer NOT NULL", "CONSTRAINT `fk_users_notes` FOREIGN KEY (`user_id`) REFERENCES `users`(`id`))"}, + }, + { + name: "update", + fields: []string{"`id` integer NOT NULL", "CONSTRAINT `fk_users_notes` FOREIGN KEY (`user_id`) REFERENCES `users`(`id`))"}, + cName: "fk_users_notes", + sql: "CONSTRAINT `fk_users_notes` FOREIGN KEY (`user_id`) REFERENCES `users`(`id`)) ON UPDATE CASCADE ON DELETE CASCADE", + expect: []string{"`id` integer NOT NULL", "CONSTRAINT `fk_users_notes` FOREIGN KEY (`user_id`) REFERENCES `users`(`id`)) ON UPDATE CASCADE ON DELETE CASCADE"}, + }, + { + name: "add_check", + fields: []string{"`id` integer NOT NULL"}, + cName: "name_checker", + sql: "CONSTRAINT `name_checker` CHECK (`name` <> 'jinzhu')", + expect: []string{"`id` integer NOT NULL", "CONSTRAINT `name_checker` CHECK (`name` <> 'jinzhu')"}, + }, + { + name: "update_check", + fields: []string{"`id` integer NOT NULL", "CONSTRAINT `name_checker` CHECK (`name` <> 'thetadev')"}, + cName: "name_checker", + sql: "CONSTRAINT `name_checker` CHECK (`name` <> 'jinzhu')", + expect: []string{"`id` integer NOT NULL", "CONSTRAINT `name_checker` CHECK (`name` <> 'jinzhu')"}, + }, + } + + for _, p := range params { + t.Run(p.name, func(t *testing.T) { + testDDL := ddl{fields: p.fields} + + testDDL.addConstraint(p.cName, p.sql) + tests.AssertEqual(t, p.expect, testDDL.fields) + }) + } +} + +func TestRemoveConstraint(t *testing.T) { + params := []struct { + name string + fields []string + cName string + success bool + expect []string + }{ + { + name: "fk", + fields: []string{"`id` integer NOT NULL", "CONSTRAINT `fk_users_notes` FOREIGN KEY (`user_id`) REFERENCES `users`(`id`))"}, + cName: "fk_users_notes", + success: true, + expect: []string{"`id` integer NOT NULL"}, + }, + { + name: "check", + fields: []string{"CONSTRAINT `name_checker` CHECK (`name` <> 'thetadev')", "`id` integer NOT NULL"}, + cName: "name_checker", + success: true, + expect: []string{"`id` integer NOT NULL"}, + }, + { + name: "none", + fields: []string{"CONSTRAINT `name_checker` CHECK (`name` <> 'thetadev')", "`id` integer NOT NULL"}, + cName: "nothing", + success: false, + expect: []string{"CONSTRAINT `name_checker` CHECK (`name` <> 'thetadev')", "`id` integer NOT NULL"}, + }, + } + + for _, p := range params { + t.Run(p.name, func(t *testing.T) { + testDDL := ddl{fields: p.fields} + + success := testDDL.removeConstraint(p.cName) + + tests.AssertEqual(t, p.success, success) + tests.AssertEqual(t, p.expect, testDDL.fields) + }) + } +} + +func TestGetColumns(t *testing.T) { + params := []struct { + name string + ddl string + columns []string + }{ + { + name: "with_fk", + ddl: "CREATE TABLE `notes` (`id` integer NOT NULL,`text` varchar(500),`user_id` integer,PRIMARY KEY (`id`),CONSTRAINT `fk_users_notes` FOREIGN KEY (`user_id`) REFERENCES `users`(`id`))", + columns: []string{"`id`", "`text`", "`user_id`"}, + }, + { + name: "with_check", + ddl: "CREATE TABLE Persons (ID int NOT NULL,LastName varchar(255) NOT NULL,FirstName varchar(255),Age int,CHECK (Age>=18),CHECK (FirstName!='John'))", + columns: []string{"`ID`", "`LastName`", "`FirstName`", "`Age`"}, + }, + { + name: "with_escaped_quote", + ddl: "CREATE TABLE Persons (ID int NOT NULL,LastName varchar(255) NOT NULL DEFAULT \"\",FirstName varchar(255))", + columns: []string{"`ID`", "`LastName`", "`FirstName`"}, + }, + { + name: "with_generated_column", + ddl: "CREATE TABLE Persons (ID int NOT NULL,LastName varchar(255) NOT NULL,FirstName varchar(255),FullName varchar(255) GENERATED ALWAYS AS (FirstName || ' ' || LastName))", + columns: []string{"`ID`", "`LastName`", "`FirstName`"}, + }, + { + name: "with_new_line", + ddl: `CREATE TABLE "tb_sys_role_menu__temp" ( + "id" integer PRIMARY KEY AUTOINCREMENT, + "created_at" datetime NOT NULL, + "updated_at" datetime NOT NULL, + "created_by" integer NOT NULL DEFAULT 0, + "updated_by" integer NOT NULL DEFAULT 0, + "role_id" integer NOT NULL, + "menu_id" bigint NOT NULL +)`, + columns: []string{"`id`", "`created_at`", "`updated_at`", "`created_by`", "`updated_by`", "`role_id`", "`menu_id`"}, + }, + } + + for _, p := range params { + t.Run(p.name, func(t *testing.T) { + testDDL, err := parseDDL(p.ddl) + if err != nil { + panic(err.Error()) + } + + cols := testDDL.getColumns() + + tests.AssertEqual(t, p.columns, cols) + }) + } +} diff --git a/gormlite/error_translator.go b/gormlite/error_translator.go new file mode 100644 index 0000000..46681eb --- /dev/null +++ b/gormlite/error_translator.go @@ -0,0 +1,15 @@ +package gormlite + +import ( + "errors" + + "github.com/ncruces/go-sqlite3" + "gorm.io/gorm" +) + +func (dialector Dialector) Translate(err error) error { + if errors.Is(err, sqlite3.CONSTRAINT_UNIQUE) { + return gorm.ErrDuplicatedKey + } + return err +} diff --git a/gormlite/errors.go b/gormlite/errors.go new file mode 100644 index 0000000..2032658 --- /dev/null +++ b/gormlite/errors.go @@ -0,0 +1,7 @@ +package gormlite + +import "errors" + +var ( + ErrConstraintsNotImplemented = errors.New("constraints not implemented on sqlite, consider using DisableForeignKeyConstraintWhenMigrating, more details https://github.com/go-gorm/gorm/wiki/GORM-V2-Release-Note-Draft#all-new-migrator") +) diff --git a/gormlite/migrator.go b/gormlite/migrator.go new file mode 100644 index 0000000..705ceb7 --- /dev/null +++ b/gormlite/migrator.go @@ -0,0 +1,428 @@ +package gormlite + +import ( + "database/sql" + "fmt" + "regexp" + "strings" + + "gorm.io/gorm" + "gorm.io/gorm/clause" + "gorm.io/gorm/migrator" + "gorm.io/gorm/schema" +) + +type Migrator struct { + migrator.Migrator +} + +func (m *Migrator) RunWithoutForeignKey(fc func() error) error { + var enabled int + m.DB.Raw("PRAGMA foreign_keys").Scan(&enabled) + if enabled == 1 { + m.DB.Exec("PRAGMA foreign_keys = OFF") + defer m.DB.Exec("PRAGMA foreign_keys = ON") + } + + return fc() +} + +func (m Migrator) HasTable(value interface{}) bool { + var count int + m.Migrator.RunWithValue(value, func(stmt *gorm.Statement) error { + return m.DB.Raw("SELECT count(*) FROM sqlite_master WHERE type='table' AND name=?", stmt.Table).Row().Scan(&count) + }) + return count > 0 +} + +func (m Migrator) DropTable(values ...interface{}) error { + return m.RunWithoutForeignKey(func() error { + values = m.ReorderModels(values, false) + tx := m.DB.Session(&gorm.Session{}) + + for i := len(values) - 1; i >= 0; i-- { + if err := m.RunWithValue(values[i], func(stmt *gorm.Statement) error { + return tx.Exec("DROP TABLE IF EXISTS ?", clause.Table{Name: stmt.Table}).Error + }); err != nil { + return err + } + } + + return nil + }) +} + +func (m Migrator) GetTables() (tableList []string, err error) { + return tableList, m.DB.Raw("SELECT name FROM sqlite_master where type=?", "table").Scan(&tableList).Error +} + +func (m Migrator) HasColumn(value interface{}, name string) bool { + var count int + m.Migrator.RunWithValue(value, func(stmt *gorm.Statement) error { + if stmt.Schema != nil { + if field := stmt.Schema.LookUpField(name); field != nil { + name = field.DBName + } + } + + if name != "" { + m.DB.Raw( + "SELECT count(*) FROM sqlite_master WHERE type = ? AND tbl_name = ? AND (sql LIKE ? OR sql LIKE ? OR sql LIKE ? OR sql LIKE ? OR sql LIKE ?)", + "table", stmt.Table, `%"`+name+`" %`, `%`+name+` %`, "%`"+name+"`%", "%["+name+"]%", "%\t"+name+"\t%", + ).Row().Scan(&count) + } + return nil + }) + return count > 0 +} + +func (m Migrator) AlterColumn(value interface{}, name string) error { + return m.RunWithoutForeignKey(func() error { + return m.recreateTable(value, nil, func(rawDDL string, stmt *gorm.Statement) (sql string, sqlArgs []interface{}, err error) { + if field := stmt.Schema.LookUpField(name); field != nil { + // lookup field from table definition, ddl might looks like `'name' int,` or `'name' int)` + reg, err := regexp.Compile("(`|'|\"| )" + field.DBName + "(`|'|\"| ) .*?(,|\\)\\s*$)") + if err != nil { + return "", nil, err + } + + createSQL := reg.ReplaceAllString(rawDDL, fmt.Sprintf("`%v` ?$3", field.DBName)) + + if createSQL == rawDDL { + return "", nil, fmt.Errorf("failed to look up field %v from DDL %v", field.DBName, rawDDL) + } + + return createSQL, []interface{}{m.FullDataTypeOf(field)}, nil + } + return "", nil, fmt.Errorf("failed to alter field with name %v", name) + }) + }) +} + +// ColumnTypes return columnTypes []gorm.ColumnType and execErr error +func (m Migrator) ColumnTypes(value interface{}) ([]gorm.ColumnType, error) { + columnTypes := make([]gorm.ColumnType, 0) + execErr := m.RunWithValue(value, func(stmt *gorm.Statement) (err error) { + var ( + sqls []string + sqlDDL *ddl + ) + + if err := m.DB.Raw("SELECT sql FROM sqlite_master WHERE type IN ? AND tbl_name = ? AND sql IS NOT NULL order by type = ? desc", []string{"table", "index"}, stmt.Table, "table").Scan(&sqls).Error; err != nil { + return err + } + + if sqlDDL, err = parseDDL(sqls...); err != nil { + return err + } + + rows, err := m.DB.Session(&gorm.Session{}).Table(stmt.Table).Limit(1).Rows() + if err != nil { + return err + } + defer func() { + err = rows.Close() + }() + + var rawColumnTypes []*sql.ColumnType + rawColumnTypes, err = rows.ColumnTypes() + if err != nil { + return err + } + + for _, c := range rawColumnTypes { + columnType := migrator.ColumnType{SQLColumnType: c} + for _, column := range sqlDDL.columns { + if column.NameValue.String == c.Name() { + column.SQLColumnType = c + columnType = column + break + } + } + columnTypes = append(columnTypes, columnType) + } + + return err + }) + + return columnTypes, execErr +} + +func (m Migrator) DropColumn(value interface{}, name string) error { + return m.recreateTable(value, nil, func(rawDDL string, stmt *gorm.Statement) (sql string, sqlArgs []interface{}, err error) { + if field := stmt.Schema.LookUpField(name); field != nil { + name = field.DBName + } + + reg, err := regexp.Compile("(`|'|\"| |\\[)" + name + "(`|'|\"| |\\]) .*?,") + if err != nil { + return "", nil, err + } + + createSQL := reg.ReplaceAllString(rawDDL, "") + + return createSQL, nil, nil + }) +} + +func (m Migrator) CreateConstraint(value interface{}, name string) error { + return m.RunWithValue(value, func(stmt *gorm.Statement) error { + constraint, chk, table := m.GuessConstraintAndTable(stmt, name) + + return m.recreateTable(value, &table, + func(rawDDL string, stmt *gorm.Statement) (sql string, sqlArgs []interface{}, err error) { + var ( + constraintName string + constraintSql string + constraintValues []interface{} + ) + + if constraint != nil { + constraintName = constraint.Name + constraintSql, constraintValues = buildConstraint(constraint) + } else if chk != nil { + constraintName = chk.Name + constraintSql = "CONSTRAINT ? CHECK (?)" + constraintValues = []interface{}{clause.Column{Name: chk.Name}, clause.Expr{SQL: chk.Constraint}} + } else { + return "", nil, nil + } + + createDDL, err := parseDDL(rawDDL) + if err != nil { + return "", nil, err + } + createDDL.addConstraint(constraintName, constraintSql) + createSQL := createDDL.compile() + + return createSQL, constraintValues, nil + }) + }) +} + +func (m Migrator) DropConstraint(value interface{}, name string) error { + return m.RunWithValue(value, func(stmt *gorm.Statement) error { + constraint, chk, table := m.GuessConstraintAndTable(stmt, name) + if constraint != nil { + name = constraint.Name + } else if chk != nil { + name = chk.Name + } + + return m.recreateTable(value, &table, + func(rawDDL string, stmt *gorm.Statement) (sql string, sqlArgs []interface{}, err error) { + createDDL, err := parseDDL(rawDDL) + if err != nil { + return "", nil, err + } + createDDL.removeConstraint(name) + createSQL := createDDL.compile() + + return createSQL, nil, nil + }) + }) +} + +func (m Migrator) HasConstraint(value interface{}, name string) bool { + var count int64 + m.RunWithValue(value, func(stmt *gorm.Statement) error { + constraint, chk, table := m.GuessConstraintAndTable(stmt, name) + if constraint != nil { + name = constraint.Name + } else if chk != nil { + name = chk.Name + } + + m.DB.Raw( + "SELECT count(*) FROM sqlite_master WHERE type = ? AND tbl_name = ? AND (sql LIKE ? OR sql LIKE ? OR sql LIKE ? OR sql LIKE ? OR sql LIKE ?)", + "table", table, `%CONSTRAINT "`+name+`" %`, `%CONSTRAINT `+name+` %`, "%CONSTRAINT `"+name+"`%", "%CONSTRAINT ["+name+"]%", "%CONSTRAINT \t"+name+"\t%", + ).Row().Scan(&count) + + return nil + }) + + return count > 0 +} + +func (m Migrator) CurrentDatabase() (name string) { + var null interface{} + m.DB.Raw("PRAGMA database_list").Row().Scan(&null, &name, &null) + return +} + +func (m Migrator) BuildIndexOptions(opts []schema.IndexOption, stmt *gorm.Statement) (results []interface{}) { + for _, opt := range opts { + str := stmt.Quote(opt.DBName) + if opt.Expression != "" { + str = opt.Expression + } + + if opt.Collate != "" { + str += " COLLATE " + opt.Collate + } + + if opt.Sort != "" { + str += " " + opt.Sort + } + results = append(results, clause.Expr{SQL: str}) + } + return +} + +func (m Migrator) CreateIndex(value interface{}, name string) error { + return m.RunWithValue(value, func(stmt *gorm.Statement) error { + if stmt.Schema != nil { + if idx := stmt.Schema.LookIndex(name); idx != nil { + opts := m.BuildIndexOptions(idx.Fields, stmt) + values := []interface{}{clause.Column{Name: idx.Name}, clause.Table{Name: stmt.Table}, opts} + + createIndexSQL := "CREATE " + if idx.Class != "" { + createIndexSQL += idx.Class + " " + } + createIndexSQL += "INDEX ?" + + if idx.Type != "" { + createIndexSQL += " USING " + idx.Type + } + createIndexSQL += " ON ??" + + if idx.Where != "" { + createIndexSQL += " WHERE " + idx.Where + } + + return m.DB.Exec(createIndexSQL, values...).Error + } + } + return fmt.Errorf("failed to create index with name %v", name) + }) +} + +func (m Migrator) HasIndex(value interface{}, name string) bool { + var count int + m.RunWithValue(value, func(stmt *gorm.Statement) error { + if stmt.Schema != nil { + if idx := stmt.Schema.LookIndex(name); idx != nil { + name = idx.Name + } + } + + if name != "" { + m.DB.Raw( + "SELECT count(*) FROM sqlite_master WHERE type = ? AND tbl_name = ? AND name = ?", "index", stmt.Table, name, + ).Row().Scan(&count) + } + return nil + }) + return count > 0 +} + +func (m Migrator) RenameIndex(value interface{}, oldName, newName string) error { + return m.RunWithValue(value, func(stmt *gorm.Statement) error { + var sql string + m.DB.Raw("SELECT sql FROM sqlite_master WHERE type = ? AND tbl_name = ? AND name = ?", "index", stmt.Table, oldName).Row().Scan(&sql) + if sql != "" { + return m.DB.Exec(strings.Replace(sql, oldName, newName, 1)).Error + } + return fmt.Errorf("failed to find index with name %v", oldName) + }) +} + +func (m Migrator) DropIndex(value interface{}, name string) error { + return m.RunWithValue(value, func(stmt *gorm.Statement) error { + if stmt.Schema != nil { + if idx := stmt.Schema.LookIndex(name); idx != nil { + name = idx.Name + } + } + + return m.DB.Exec("DROP INDEX ?", clause.Column{Name: name}).Error + }) +} + +func buildConstraint(constraint *schema.Constraint) (sql string, results []interface{}) { + sql = "CONSTRAINT ? FOREIGN KEY ? REFERENCES ??" + if constraint.OnDelete != "" { + sql += " ON DELETE " + constraint.OnDelete + } + + if constraint.OnUpdate != "" { + sql += " ON UPDATE " + constraint.OnUpdate + } + + var foreignKeys, references []interface{} + for _, field := range constraint.ForeignKeys { + foreignKeys = append(foreignKeys, clause.Column{Name: field.DBName}) + } + + for _, field := range constraint.References { + references = append(references, clause.Column{Name: field.DBName}) + } + results = append(results, clause.Table{Name: constraint.Name}, foreignKeys, clause.Table{Name: constraint.ReferenceSchema.Table}, references) + return +} + +func (m Migrator) getRawDDL(table string) (string, error) { + var createSQL string + m.DB.Raw("SELECT sql FROM sqlite_master WHERE type = ? AND tbl_name = ? AND name = ?", "table", table, table).Row().Scan(&createSQL) + + if m.DB.Error != nil { + return "", m.DB.Error + } + return createSQL, nil +} + +func (m Migrator) recreateTable(value interface{}, tablePtr *string, + getCreateSQL func(rawDDL string, stmt *gorm.Statement) (sql string, sqlArgs []interface{}, err error)) error { + return m.RunWithValue(value, func(stmt *gorm.Statement) error { + table := stmt.Table + if tablePtr != nil { + table = *tablePtr + } + + rawDDL, err := m.getRawDDL(table) + if err != nil { + return err + } + + newTableName := table + "__temp" + + createSQL, sqlArgs, err := getCreateSQL(rawDDL, stmt) + if err != nil { + return err + } + if createSQL == "" { + return nil + } + + tableReg, err := regexp.Compile("\\s*('|`|\")?\\b" + table + "\\b('|`|\")?\\s*") + if err != nil { + return err + } + createSQL = tableReg.ReplaceAllString(createSQL, fmt.Sprintf(" `%v` ", newTableName)) + + createDDL, err := parseDDL(createSQL) + if err != nil { + return err + } + columns := createDDL.getColumns() + + return m.DB.Transaction(func(tx *gorm.DB) error { + if err := tx.Exec(createSQL, sqlArgs...).Error; err != nil { + return err + } + + queries := []string{ + fmt.Sprintf("INSERT INTO `%v`(%v) SELECT %v FROM `%v`", newTableName, strings.Join(columns, ","), strings.Join(columns, ","), table), + fmt.Sprintf("DROP TABLE `%v`", table), + fmt.Sprintf("ALTER TABLE `%v` RENAME TO `%v`", newTableName, table), + } + for _, query := range queries { + if err := tx.Exec(query).Error; err != nil { + return err + } + } + return nil + }) + }) +} diff --git a/gormlite/sqlite.go b/gormlite/sqlite.go new file mode 100644 index 0000000..837d57c --- /dev/null +++ b/gormlite/sqlite.go @@ -0,0 +1,219 @@ +// Package gormlite provides a GORM driver for SQLite. +package gormlite + +import ( + "context" + "database/sql" + "strconv" + "strings" + + "gorm.io/gorm" + "gorm.io/gorm/callbacks" + "gorm.io/gorm/clause" + "gorm.io/gorm/logger" + "gorm.io/gorm/migrator" + "gorm.io/gorm/schema" + + _ "github.com/ncruces/go-sqlite3/driver" +) + +type Dialector struct { + DSN string + Conn gorm.ConnPool +} + +func Open(dsn string) gorm.Dialector { + return &Dialector{DSN: dsn} +} + +func (dialector Dialector) Name() string { + return "sqlite" +} + +func (dialector Dialector) Initialize(db *gorm.DB) (err error) { + if dialector.Conn != nil { + db.ConnPool = dialector.Conn + } else { + conn, err := sql.Open("sqlite3", dialector.DSN) + if err != nil { + return err + } + db.ConnPool = conn + } + + var version string + if err := db.ConnPool.QueryRowContext(context.Background(), "select sqlite_version()").Scan(&version); err != nil { + return err + } + // https://www.sqlite.org/releaselog/3_35_0.html + if compareVersion(version, "3.35.0") >= 0 { + callbacks.RegisterDefaultCallbacks(db, &callbacks.Config{ + CreateClauses: []string{"INSERT", "VALUES", "ON CONFLICT", "RETURNING"}, + UpdateClauses: []string{"UPDATE", "SET", "WHERE", "RETURNING"}, + DeleteClauses: []string{"DELETE", "FROM", "WHERE", "RETURNING"}, + LastInsertIDReversed: true, + }) + } else { + callbacks.RegisterDefaultCallbacks(db, &callbacks.Config{ + LastInsertIDReversed: true, + }) + } + + for k, v := range dialector.ClauseBuilders() { + db.ClauseBuilders[k] = v + } + return +} + +func (dialector Dialector) ClauseBuilders() map[string]clause.ClauseBuilder { + return map[string]clause.ClauseBuilder{ + "INSERT": func(c clause.Clause, builder clause.Builder) { + if insert, ok := c.Expression.(clause.Insert); ok { + if stmt, ok := builder.(*gorm.Statement); ok { + stmt.WriteString("INSERT ") + if insert.Modifier != "" { + stmt.WriteString(insert.Modifier) + stmt.WriteByte(' ') + } + + stmt.WriteString("INTO ") + if insert.Table.Name == "" { + stmt.WriteQuoted(stmt.Table) + } else { + stmt.WriteQuoted(insert.Table) + } + return + } + } + + c.Build(builder) + }, + "LIMIT": func(c clause.Clause, builder clause.Builder) { + if limit, ok := c.Expression.(clause.Limit); ok { + var lmt = -1 + if limit.Limit != nil && *limit.Limit >= 0 { + lmt = *limit.Limit + } + if lmt >= 0 || limit.Offset > 0 { + builder.WriteString("LIMIT ") + builder.WriteString(strconv.Itoa(lmt)) + } + if limit.Offset > 0 { + builder.WriteString(" OFFSET ") + builder.WriteString(strconv.Itoa(limit.Offset)) + } + } + }, + "FOR": func(c clause.Clause, builder clause.Builder) { + if _, ok := c.Expression.(clause.Locking); ok { + // SQLite3 does not support row-level locking. + return + } + c.Build(builder) + }, + } +} + +func (dialector Dialector) DefaultValueOf(field *schema.Field) clause.Expression { + if field.AutoIncrement { + return clause.Expr{SQL: "NULL"} + } + + // doesn't work, will raise error + return clause.Expr{SQL: "DEFAULT"} +} + +func (dialector Dialector) Migrator(db *gorm.DB) gorm.Migrator { + return Migrator{migrator.Migrator{Config: migrator.Config{ + DB: db, + Dialector: dialector, + CreateIndexAfterCreateTable: true, + }}} +} + +func (dialector Dialector) BindVarTo(writer clause.Writer, stmt *gorm.Statement, v interface{}) { + writer.WriteByte('?') +} + +func (dialector Dialector) QuoteTo(writer clause.Writer, str string) { + writer.WriteByte('`') + if strings.Contains(str, ".") { + for idx, str := range strings.Split(str, ".") { + if idx > 0 { + writer.WriteString(".`") + } + writer.WriteString(str) + writer.WriteByte('`') + } + } else { + writer.WriteString(str) + writer.WriteByte('`') + } +} + +func (dialector Dialector) Explain(sql string, vars ...interface{}) string { + return logger.ExplainSQL(sql, nil, `"`, vars...) +} + +func (dialector Dialector) DataTypeOf(field *schema.Field) string { + switch field.DataType { + case schema.Bool: + return "numeric" + case schema.Int, schema.Uint: + if field.AutoIncrement && !field.PrimaryKey { + // https://www.sqlite.org/autoinc.html + return "integer PRIMARY KEY AUTOINCREMENT" + } else { + return "integer" + } + case schema.Float: + return "real" + case schema.String: + return "text" + case schema.Time: + // Distinguish between schema.Time and tag time + if val, ok := field.TagSettings["TYPE"]; ok { + return val + } else { + return "datetime" + } + case schema.Bytes: + return "blob" + } + + return string(field.DataType) +} + +func (dialectopr Dialector) SavePoint(tx *gorm.DB, name string) error { + tx.Exec("SAVEPOINT " + name) + return nil +} + +func (dialectopr Dialector) RollbackTo(tx *gorm.DB, name string) error { + tx.Exec("ROLLBACK TO SAVEPOINT " + name) + return nil +} + +func compareVersion(version1, version2 string) int { + n, m := len(version1), len(version2) + i, j := 0, 0 + for i < n || j < m { + x := 0 + for ; i < n && version1[i] != '.'; i++ { + x = x*10 + int(version1[i]-'0') + } + i++ + y := 0 + for ; j < m && version2[j] != '.'; j++ { + y = y*10 + int(version2[j]-'0') + } + j++ + if x > y { + return 1 + } + if x < y { + return -1 + } + } + return 0 +} diff --git a/gormlite/sqlite_test.go b/gormlite/sqlite_test.go new file mode 100644 index 0000000..dc917c7 --- /dev/null +++ b/gormlite/sqlite_test.go @@ -0,0 +1,64 @@ +package gormlite + +import ( + "fmt" + "testing" + + "gorm.io/gorm" + + _ "github.com/ncruces/go-sqlite3/embed" +) + +func TestDialector(t *testing.T) { + // This is the DSN of the in-memory SQLite database for these tests. + const InMemoryDSN = "file:testdatabase?mode=memory&cache=shared" + + rows := []struct { + description string + dialector *Dialector + openSuccess bool + query string + querySuccess bool + }{ + { + description: "Default driver", + dialector: &Dialector{ + DSN: InMemoryDSN, + }, + openSuccess: true, + query: "SELECT 1", + querySuccess: true, + }, + } + for rowIndex, row := range rows { + t.Run(fmt.Sprintf("%d/%s", rowIndex, row.description), func(t *testing.T) { + db, err := gorm.Open(row.dialector, &gorm.Config{}) + if !row.openSuccess { + if err == nil { + t.Errorf("Expected Open to fail.") + } + return + } + + if err != nil { + t.Errorf("Expected Open to succeed; got error: %v", err) + } + if db == nil { + t.Errorf("Expected db to be non-nil.") + } + if row.query != "" { + err = db.Exec(row.query).Error + if !row.querySuccess { + if err == nil { + t.Errorf("Expected query to fail.") + } + return + } + + if err != nil { + t.Errorf("Expected query to succeed; got error: %v", err) + } + } + }) + } +}