diff --git a/go.work.sum b/go.work.sum index 314b2b3..792a87e 100644 --- a/go.work.sum +++ b/go.work.sum @@ -1,4 +1,4 @@ -github.com/ncruces/go-sqlite3 v0.9.0/go.mod h1:IyRoNwT0Z+mNRXIVeP2DgWPNl78Kmc/B+pO9i6GNgRg= +github.com/ncruces/go-sqlite3 v0.9.1/go.mod h1:jFoUbaCDNUS1KN5ZgFxN7bgcWoWfO0EOKeik9QAHZ08= golang.org/x/mod v0.8.0 h1:LUYupSeNrTNCGzR/hVBk2NHZO4hXcVaW1k4Qx7rjPx8= golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/tools v0.6.0 h1:BOw41kyTf3PuCW1pVQf8+Cyg8pMlkYB1oo9iJ6D/lKM= diff --git a/gormlite/ddlmod.go b/gormlite/ddlmod.go index 69cd179..e024b7a 100644 --- a/gormlite/ddlmod.go +++ b/gormlite/ddlmod.go @@ -125,7 +125,7 @@ func parseDDL(strs ...string) (*ddl, error) { ColumnTypeValue: sql.NullString{String: matches[2], Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}, UniqueValue: sql.NullBool{Valid: true}, - NullableValue: sql.NullBool{Valid: true}, + NullableValue: sql.NullBool{Bool: true, Valid: true}, DefaultValueValue: sql.NullString{Valid: false}, } @@ -175,6 +175,18 @@ func parseDDL(strs ...string) (*ddl, error) { return &result, nil } +func (d *ddl) clone() *ddl { + copied := new(ddl) + *copied = *d + + copied.fields = make([]string, len(d.fields)) + copy(copied.fields, d.fields) + copied.columns = make([]migrator.ColumnType, len(d.columns)) + copy(copied.columns, d.columns) + + return copied +} + func (d *ddl) compile() string { if len(d.fields) == 0 { return d.head @@ -183,6 +195,21 @@ func (d *ddl) compile() string { return fmt.Sprintf("%s (%s)", d.head, strings.Join(d.fields, ",")) } +func (d *ddl) renameTable(dst, src string) error { + tableReg, err := regexp.Compile("\\s*('|`|\")?\\b" + regexp.QuoteMeta(src) + "\\b('|`|\")?\\s*") + if err != nil { + return err + } + + replaced := tableReg.ReplaceAllString(d.head, fmt.Sprintf(" `%s` ", dst)) + if replaced == d.head { + return fmt.Errorf("failed to look up tablename `%s` from DDL head '%s'", src, d.head) + } + + d.head = replaced + return nil +} + func (d *ddl) addConstraint(name string, sql string) { reg := regexp.MustCompile("^CONSTRAINT [\"`]?" + regexp.QuoteMeta(name) + "[\"` ]") @@ -208,6 +235,17 @@ func (d *ddl) removeConstraint(name string) bool { return false } +func (d *ddl) hasConstraint(name string) bool { + reg := regexp.MustCompile("^CONSTRAINT [\"`]?" + regexp.QuoteMeta(name) + "[\"` ]") + + for _, f := range d.fields { + if reg.MatchString(f) { + return true + } + } + return false +} + func (d *ddl) getColumns() []string { res := []string{} @@ -229,3 +267,30 @@ func (d *ddl) getColumns() []string { } return res } + +func (d *ddl) alterColumn(name, sql string) bool { + reg := regexp.MustCompile("^(`|'|\"| )" + regexp.QuoteMeta(name) + "(`|'|\"| ) .*?$") + + for i := 0; i < len(d.fields); i++ { + if reg.MatchString(d.fields[i]) { + d.fields[i] = sql + return false + } + } + + d.fields = append(d.fields, sql) + return true +} + +func (d *ddl) removeColumn(name string) bool { + reg := regexp.MustCompile("^(`|'|\"| )" + regexp.QuoteMeta(name) + "(`|'|\"| ) .*?$") + + for i := 0; i < len(d.fields); i++ { + if reg.MatchString(d.fields[i]) { + d.fields = append(d.fields[:i], d.fields[i+1:]...) + return true + } + } + + return false +} diff --git a/gormlite/ddlmod_test.go b/gormlite/ddlmod_test.go index 059adc1..d5eb993 100644 --- a/gormlite/ddlmod_test.go +++ b/gormlite/ddlmod_test.go @@ -20,16 +20,16 @@ func TestParseDDL(t *testing.T) { "CREATE UNIQUE INDEX `idx_profiles_refer` ON `profiles`(`text`)", }, 6, []migrator.ColumnType{ {NameValue: sql.NullString{String: "id", Valid: true}, DataTypeValue: sql.NullString{String: "integer", Valid: true}, ColumnTypeValue: sql.NullString{String: "integer", Valid: true}, PrimaryKeyValue: sql.NullBool{Bool: true, Valid: true}, NullableValue: sql.NullBool{Valid: true}, UniqueValue: sql.NullBool{Valid: true}, DefaultValueValue: sql.NullString{Valid: false}}, - {NameValue: sql.NullString{String: "text", Valid: true}, DataTypeValue: sql.NullString{String: "varchar", Valid: true}, LengthValue: sql.NullInt64{Int64: 500, Valid: true}, ColumnTypeValue: sql.NullString{String: "varchar(500)", Valid: true}, DefaultValueValue: sql.NullString{String: "hello", Valid: true}, NullableValue: sql.NullBool{Valid: true}, UniqueValue: sql.NullBool{Bool: true, Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}}, - {NameValue: sql.NullString{String: "age", Valid: true}, DataTypeValue: sql.NullString{String: "integer", Valid: true}, ColumnTypeValue: sql.NullString{String: "integer", Valid: true}, DefaultValueValue: sql.NullString{String: "18", Valid: true}, NullableValue: sql.NullBool{Valid: true}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}}, - {NameValue: sql.NullString{String: "user_id", Valid: true}, DataTypeValue: sql.NullString{String: "integer", Valid: true}, ColumnTypeValue: sql.NullString{String: "integer", Valid: true}, DefaultValueValue: sql.NullString{Valid: false}, NullableValue: sql.NullBool{Valid: true}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}}, + {NameValue: sql.NullString{String: "text", Valid: true}, DataTypeValue: sql.NullString{String: "varchar", Valid: true}, LengthValue: sql.NullInt64{Int64: 500, Valid: true}, ColumnTypeValue: sql.NullString{String: "varchar(500)", Valid: true}, DefaultValueValue: sql.NullString{String: "hello", Valid: true}, NullableValue: sql.NullBool{Bool: true, Valid: true}, UniqueValue: sql.NullBool{Bool: true, Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}}, + {NameValue: sql.NullString{String: "age", Valid: true}, DataTypeValue: sql.NullString{String: "integer", Valid: true}, ColumnTypeValue: sql.NullString{String: "integer", Valid: true}, DefaultValueValue: sql.NullString{String: "18", Valid: true}, NullableValue: sql.NullBool{Bool: true, Valid: true}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}}, + {NameValue: sql.NullString{String: "user_id", Valid: true}, DataTypeValue: sql.NullString{String: "integer", Valid: true}, ColumnTypeValue: sql.NullString{String: "integer", Valid: true}, DefaultValueValue: sql.NullString{Valid: false}, NullableValue: sql.NullBool{Bool: true, Valid: true}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}}, }, }, {"with_check", []string{"CREATE TABLE Persons (ID int NOT NULL,LastName varchar(255) NOT NULL,FirstName varchar(255),Age int,CHECK (Age>=18),CHECK (FirstName<>'John'))"}, 6, []migrator.ColumnType{ {NameValue: sql.NullString{String: "ID", Valid: true}, DataTypeValue: sql.NullString{String: "int", Valid: true}, ColumnTypeValue: sql.NullString{String: "int", Valid: true}, NullableValue: sql.NullBool{Valid: true}, DefaultValueValue: sql.NullString{Valid: false}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}}, {NameValue: sql.NullString{String: "LastName", Valid: true}, DataTypeValue: sql.NullString{String: "varchar", Valid: true}, LengthValue: sql.NullInt64{Int64: 255, Valid: true}, ColumnTypeValue: sql.NullString{String: "varchar(255)", Valid: true}, NullableValue: sql.NullBool{Bool: false, Valid: true}, DefaultValueValue: sql.NullString{Valid: false}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}}, - {NameValue: sql.NullString{String: "FirstName", Valid: true}, DataTypeValue: sql.NullString{String: "varchar", Valid: true}, LengthValue: sql.NullInt64{Int64: 255, Valid: true}, ColumnTypeValue: sql.NullString{String: "varchar(255)", Valid: true}, DefaultValueValue: sql.NullString{Valid: false}, NullableValue: sql.NullBool{Valid: true}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}}, - {NameValue: sql.NullString{String: "Age", Valid: true}, DataTypeValue: sql.NullString{String: "int", Valid: true}, ColumnTypeValue: sql.NullString{String: "int", Valid: true}, DefaultValueValue: sql.NullString{Valid: false}, NullableValue: sql.NullBool{Valid: true}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}}, + {NameValue: sql.NullString{String: "FirstName", Valid: true}, DataTypeValue: sql.NullString{String: "varchar", Valid: true}, LengthValue: sql.NullInt64{Int64: 255, Valid: true}, ColumnTypeValue: sql.NullString{String: "varchar(255)", Valid: true}, DefaultValueValue: sql.NullString{Valid: false}, NullableValue: sql.NullBool{Bool: true, Valid: true}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}}, + {NameValue: sql.NullString{String: "Age", Valid: true}, DataTypeValue: sql.NullString{String: "int", Valid: true}, ColumnTypeValue: sql.NullString{String: "int", Valid: true}, DefaultValueValue: sql.NullString{Valid: false}, NullableValue: sql.NullBool{Bool: true, Valid: true}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}}, }}, {"lowercase", []string{"create table test (ID int NOT NULL)"}, 1, []migrator.ColumnType{ {NameValue: sql.NullString{String: "ID", Valid: true}, DataTypeValue: sql.NullString{String: "int", Valid: true}, ColumnTypeValue: sql.NullString{String: "int", Valid: true}, NullableValue: sql.NullBool{Bool: false, Valid: true}, DefaultValueValue: sql.NullString{Valid: false}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}}, @@ -39,7 +39,7 @@ func TestParseDDL(t *testing.T) { {"with_special_characters", []string{ "CREATE TABLE `test` (`text` varchar(10) DEFAULT \"测试, \")", }, 1, []migrator.ColumnType{ - {NameValue: sql.NullString{String: "text", Valid: true}, DataTypeValue: sql.NullString{String: "varchar", Valid: true}, LengthValue: sql.NullInt64{Int64: 10, Valid: true}, ColumnTypeValue: sql.NullString{String: "varchar(10)", Valid: true}, DefaultValueValue: sql.NullString{String: "测试, ", Valid: true}, NullableValue: sql.NullBool{Valid: true}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}}, + {NameValue: sql.NullString{String: "text", Valid: true}, DataTypeValue: sql.NullString{String: "varchar", Valid: true}, LengthValue: sql.NullInt64{Int64: 10, Valid: true}, ColumnTypeValue: sql.NullString{String: "varchar(10)", Valid: true}, DefaultValueValue: sql.NullString{String: "测试, ", Valid: true}, NullableValue: sql.NullBool{Bool: true, Valid: true}, UniqueValue: sql.NullBool{Valid: true}, PrimaryKeyValue: sql.NullBool{Valid: true}}, }, }, { @@ -122,7 +122,7 @@ func TestParseDDL_Whitespaces(t *testing.T) { NameValue: sql.NullString{String: "id", Valid: true}, DataTypeValue: sql.NullString{String: "integer", Valid: true}, ColumnTypeValue: sql.NullString{String: "integer", Valid: true}, - NullableValue: sql.NullBool{Bool: false, Valid: true}, + NullableValue: sql.NullBool{Bool: true, Valid: true}, DefaultValueValue: sql.NullString{Valid: false}, UniqueValue: sql.NullBool{Bool: true, Valid: true}, PrimaryKeyValue: sql.NullBool{Bool: true, Valid: true}, @@ -131,7 +131,7 @@ func TestParseDDL_Whitespaces(t *testing.T) { NameValue: sql.NullString{String: "dark_mode", Valid: true}, DataTypeValue: sql.NullString{String: "numeric", Valid: true}, ColumnTypeValue: sql.NullString{String: "numeric", Valid: true}, - NullableValue: sql.NullBool{Valid: true}, + NullableValue: sql.NullBool{Bool: true, Valid: true}, DefaultValueValue: sql.NullString{String: "true", Valid: true}, UniqueValue: sql.NullBool{Bool: false, Valid: true}, PrimaryKeyValue: sql.NullBool{Bool: false, Valid: true}, diff --git a/gormlite/error_translator.go b/gormlite/error_translator.go index ac7a2d2..b076707 100644 --- a/gormlite/error_translator.go +++ b/gormlite/error_translator.go @@ -7,7 +7,7 @@ import ( "gorm.io/gorm" ) -func (dialector Dialector) Translate(err error) error { +func (_Dialector) Translate(err error) error { switch { case errors.Is(err, sqlite3.CONSTRAINT_UNIQUE), diff --git a/gormlite/go.mod b/gormlite/go.mod index 6f81e28..22743fd 100644 --- a/gormlite/go.mod +++ b/gormlite/go.mod @@ -3,8 +3,8 @@ module github.com/ncruces/go-sqlite3/gormlite go 1.21 require ( - github.com/ncruces/go-sqlite3 v0.9.0 - gorm.io/gorm v1.25.4 + github.com/ncruces/go-sqlite3 v0.9.1 + gorm.io/gorm v1.25.5 ) require ( @@ -12,5 +12,5 @@ require ( github.com/jinzhu/now v1.1.5 // indirect github.com/ncruces/julianday v0.1.5 // indirect github.com/tetratelabs/wazero v1.5.0 // indirect - golang.org/x/sys v0.12.0 // indirect + golang.org/x/sys v0.13.0 // indirect ) diff --git a/gormlite/go.sum b/gormlite/go.sum index 0e25637..e5fff6e 100644 --- a/gormlite/go.sum +++ b/gormlite/go.sum @@ -2,15 +2,15 @@ github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ= github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= -github.com/ncruces/go-sqlite3 v0.9.0 h1:tl5eEmGEyzZH2ur8sDgPJTdzV4CRnKpsFngoP1QRjD8= -github.com/ncruces/go-sqlite3 v0.9.0/go.mod h1:IyRoNwT0Z+mNRXIVeP2DgWPNl78Kmc/B+pO9i6GNgRg= +github.com/ncruces/go-sqlite3 v0.9.1 h1:kV7Zy+ZNyHMfMyZeWc1Yyq+wtgYZDZdp2qAA/wfeMWo= +github.com/ncruces/go-sqlite3 v0.9.1/go.mod h1:jFoUbaCDNUS1KN5ZgFxN7bgcWoWfO0EOKeik9QAHZ08= github.com/ncruces/julianday v0.1.5 h1:hDJ9ejiMp3DHsoZ5KW4c1lwfMjbARS7u/gbYcd0FBZk= github.com/ncruces/julianday v0.1.5/go.mod h1:Dusn2KvZrrovOMJuOt0TNXL6tB7U2E8kvza5fFc9G7g= github.com/tetratelabs/wazero v1.5.0 h1:Yz3fZHivfDiZFUXnWMPUoiW7s8tC1sjdBtlJn08qYa0= github.com/tetratelabs/wazero v1.5.0/go.mod h1:0U0G41+ochRKoPKCJlh0jMg1CHkyfK8kDqiirMmKY8A= -golang.org/x/sys v0.12.0 h1:CM0HF96J0hcLAwsHPJZjfdNzs0gftsLfgKt57wWHJ0o= -golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.13.0 h1:Af8nKPmuFypiUBjVoU9V20FiaFXOcuZI21p0ycVYYGE= +golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/text v0.13.0 h1:ablQoSUd0tRdKxZewP80B+BaqeKJuVhuRxj/dkrun3k= golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= -gorm.io/gorm v1.25.4 h1:iyNd8fNAe8W9dvtlgeRI5zSVZPsq3OpcTu37cYcpCmw= -gorm.io/gorm v1.25.4/go.mod h1:L4uxeKpfBml98NYqVqwAdmV1a2nBtAec/cf3fpucW/k= +gorm.io/gorm v1.25.5 h1:zR9lOiiYf09VNh5Q1gphfyia1JpiClIWG9hQaxB/mls= +gorm.io/gorm v1.25.5/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8= diff --git a/gormlite/migrator.go b/gormlite/migrator.go index 9e2c4c9..95801ab 100644 --- a/gormlite/migrator.go +++ b/gormlite/migrator.go @@ -3,7 +3,6 @@ package gormlite import ( "database/sql" "fmt" - "regexp" "strings" "gorm.io/gorm" @@ -12,11 +11,11 @@ import ( "gorm.io/gorm/schema" ) -type Migrator struct { +type _Migrator struct { migrator.Migrator } -func (m *Migrator) RunWithoutForeignKey(fc func() error) error { +func (m *_Migrator) RunWithoutForeignKey(fc func() error) error { var enabled int m.DB.Raw("PRAGMA foreign_keys").Scan(&enabled) if enabled == 1 { @@ -27,7 +26,7 @@ func (m *Migrator) RunWithoutForeignKey(fc func() error) error { return fc() } -func (m Migrator) HasTable(value interface{}) bool { +func (m _Migrator) HasTable(value interface{}) bool { var count int m.Migrator.RunWithValue(value, func(stmt *gorm.Statement) error { return m.DB.Raw("SELECT count(*) FROM sqlite_master WHERE type='table' AND name=?", stmt.Table).Row().Scan(&count) @@ -35,7 +34,7 @@ func (m Migrator) HasTable(value interface{}) bool { return count > 0 } -func (m Migrator) DropTable(values ...interface{}) error { +func (m _Migrator) DropTable(values ...interface{}) error { return m.RunWithoutForeignKey(func() error { values = m.ReorderModels(values, false) tx := m.DB.Session(&gorm.Session{}) @@ -52,11 +51,11 @@ func (m Migrator) DropTable(values ...interface{}) error { }) } -func (m Migrator) GetTables() (tableList []string, err error) { +func (m _Migrator) GetTables() (tableList []string, err error) { return tableList, m.DB.Raw("SELECT name FROM sqlite_master where type=?", "table").Scan(&tableList).Error } -func (m Migrator) HasColumn(value interface{}, name string) bool { +func (m _Migrator) HasColumn(value interface{}, name string) bool { var count int m.Migrator.RunWithValue(value, func(stmt *gorm.Statement) error { if stmt.Schema != nil { @@ -76,31 +75,24 @@ func (m Migrator) HasColumn(value interface{}, name string) bool { return count > 0 } -func (m Migrator) AlterColumn(value interface{}, name string) error { +func (m _Migrator) AlterColumn(value interface{}, name string) error { return m.RunWithoutForeignKey(func() error { - return m.recreateTable(value, nil, func(rawDDL string, stmt *gorm.Statement) (sql string, sqlArgs []interface{}, err error) { + return m.recreateTable(value, nil, func(ddl *ddl, stmt *gorm.Statement) (*ddl, []interface{}, error) { if field := stmt.Schema.LookUpField(name); field != nil { - // lookup field from table definition, ddl might looks like `'name' int,` or `'name' int)` - reg, err := regexp.Compile("(`|'|\"| )" + field.DBName + "(`|'|\"| ) .*?(,|\\)\\s*$)") - if err != nil { - return "", nil, err + if ddl.alterColumn(field.DBName, fmt.Sprintf("`%s` ?", field.DBName)) { + return nil, nil, fmt.Errorf("field `%s` not found in origin ddl, ddl= '%s'", name, ddl.compile()) } - createSQL := reg.ReplaceAllString(rawDDL, fmt.Sprintf("`%v` ?$3", field.DBName)) - - if createSQL == rawDDL { - return "", nil, fmt.Errorf("failed to look up field %v from DDL %v", field.DBName, rawDDL) - } - - return createSQL, []interface{}{m.FullDataTypeOf(field)}, nil + return ddl, []interface{}{m.FullDataTypeOf(field)}, nil } - return "", nil, fmt.Errorf("failed to alter field with name %v", name) + + return nil, nil, fmt.Errorf("failed to alter field with name `%s`", name) }) }) } // ColumnTypes return columnTypes []gorm.ColumnType and execErr error -func (m Migrator) ColumnTypes(value interface{}) ([]gorm.ColumnType, error) { +func (m _Migrator) ColumnTypes(value interface{}) ([]gorm.ColumnType, error) { columnTypes := make([]gorm.ColumnType, 0) execErr := m.RunWithValue(value, func(stmt *gorm.Statement) (err error) { var ( @@ -148,29 +140,23 @@ func (m Migrator) ColumnTypes(value interface{}) ([]gorm.ColumnType, error) { return columnTypes, execErr } -func (m Migrator) DropColumn(value interface{}, name string) error { - return m.recreateTable(value, nil, func(rawDDL string, stmt *gorm.Statement) (sql string, sqlArgs []interface{}, err error) { +func (m _Migrator) DropColumn(value interface{}, name string) error { + return m.recreateTable(value, nil, func(ddl *ddl, stmt *gorm.Statement) (*ddl, []interface{}, error) { if field := stmt.Schema.LookUpField(name); field != nil { name = field.DBName } - reg, err := regexp.Compile("(`|'|\"| |\\[)" + name + "(`|'|\"| |\\]) .*?,") - if err != nil { - return "", nil, err - } - - createSQL := reg.ReplaceAllString(rawDDL, "") - - return createSQL, nil, nil + ddl.removeColumn(name) + return ddl, nil, nil }) } -func (m Migrator) CreateConstraint(value interface{}, name string) error { +func (m _Migrator) CreateConstraint(value interface{}, name string) error { return m.RunWithValue(value, func(stmt *gorm.Statement) error { constraint, chk, table := m.GuessConstraintAndTable(stmt, name) return m.recreateTable(value, &table, - func(rawDDL string, stmt *gorm.Statement) (sql string, sqlArgs []interface{}, err error) { + func(ddl *ddl, stmt *gorm.Statement) (*ddl, []interface{}, error) { var ( constraintName string constraintSql string @@ -185,22 +171,16 @@ func (m Migrator) CreateConstraint(value interface{}, name string) error { constraintSql = "CONSTRAINT ? CHECK (?)" constraintValues = []interface{}{clause.Column{Name: chk.Name}, clause.Expr{SQL: chk.Constraint}} } else { - return "", nil, nil + return nil, nil, nil } - createDDL, err := parseDDL(rawDDL) - if err != nil { - return "", nil, err - } - createDDL.addConstraint(constraintName, constraintSql) - createSQL := createDDL.compile() - - return createSQL, constraintValues, nil + ddl.addConstraint(constraintName, constraintSql) + return ddl, constraintValues, nil }) }) } -func (m Migrator) DropConstraint(value interface{}, name string) error { +func (m _Migrator) DropConstraint(value interface{}, name string) error { return m.RunWithValue(value, func(stmt *gorm.Statement) error { constraint, chk, table := m.GuessConstraintAndTable(stmt, name) if constraint != nil { @@ -210,20 +190,14 @@ func (m Migrator) DropConstraint(value interface{}, name string) error { } return m.recreateTable(value, &table, - func(rawDDL string, stmt *gorm.Statement) (sql string, sqlArgs []interface{}, err error) { - createDDL, err := parseDDL(rawDDL) - if err != nil { - return "", nil, err - } - createDDL.removeConstraint(name) - createSQL := createDDL.compile() - - return createSQL, nil, nil + func(ddl *ddl, stmt *gorm.Statement) (*ddl, []interface{}, error) { + ddl.removeConstraint(name) + return ddl, nil, nil }) }) } -func (m Migrator) HasConstraint(value interface{}, name string) bool { +func (m _Migrator) HasConstraint(value interface{}, name string) bool { var count int64 m.RunWithValue(value, func(stmt *gorm.Statement) error { constraint, chk, table := m.GuessConstraintAndTable(stmt, name) @@ -244,13 +218,13 @@ func (m Migrator) HasConstraint(value interface{}, name string) bool { return count > 0 } -func (m Migrator) CurrentDatabase() (name string) { +func (m _Migrator) CurrentDatabase() (name string) { var null interface{} m.DB.Raw("PRAGMA database_list").Row().Scan(&null, &name, &null) return } -func (m Migrator) BuildIndexOptions(opts []schema.IndexOption, stmt *gorm.Statement) (results []interface{}) { +func (m _Migrator) BuildIndexOptions(opts []schema.IndexOption, stmt *gorm.Statement) (results []interface{}) { for _, opt := range opts { str := stmt.Quote(opt.DBName) if opt.Expression != "" { @@ -269,7 +243,7 @@ func (m Migrator) BuildIndexOptions(opts []schema.IndexOption, stmt *gorm.Statem return } -func (m Migrator) CreateIndex(value interface{}, name string) error { +func (m _Migrator) CreateIndex(value interface{}, name string) error { return m.RunWithValue(value, func(stmt *gorm.Statement) error { if stmt.Schema != nil { if idx := stmt.Schema.LookIndex(name); idx != nil { @@ -298,7 +272,7 @@ func (m Migrator) CreateIndex(value interface{}, name string) error { }) } -func (m Migrator) HasIndex(value interface{}, name string) bool { +func (m _Migrator) HasIndex(value interface{}, name string) bool { var count int m.RunWithValue(value, func(stmt *gorm.Statement) error { if stmt.Schema != nil { @@ -317,7 +291,7 @@ func (m Migrator) HasIndex(value interface{}, name string) bool { return count > 0 } -func (m Migrator) RenameIndex(value interface{}, oldName, newName string) error { +func (m _Migrator) RenameIndex(value interface{}, oldName, newName string) error { return m.RunWithValue(value, func(stmt *gorm.Statement) error { var sql string m.DB.Raw("SELECT sql FROM sqlite_master WHERE type = ? AND tbl_name = ? AND name = ?", "index", stmt.Table, oldName).Row().Scan(&sql) @@ -331,7 +305,7 @@ func (m Migrator) RenameIndex(value interface{}, oldName, newName string) error }) } -func (m Migrator) DropIndex(value interface{}, name string) error { +func (m _Migrator) DropIndex(value interface{}, name string) error { return m.RunWithValue(value, func(stmt *gorm.Statement) error { if stmt.Schema != nil { if idx := stmt.Schema.LookIndex(name); idx != nil { @@ -365,7 +339,7 @@ func buildConstraint(constraint *schema.Constraint) (sql string, results []inter return } -func (m Migrator) getRawDDL(table string) (string, error) { +func (m _Migrator) getRawDDL(table string) (string, error) { var createSQL string m.DB.Raw("SELECT sql FROM sqlite_master WHERE type = ? AND tbl_name = ? AND name = ?", "table", table, table).Row().Scan(&createSQL) @@ -375,8 +349,10 @@ func (m Migrator) getRawDDL(table string) (string, error) { return createSQL, nil } -func (m Migrator) recreateTable(value interface{}, tablePtr *string, - getCreateSQL func(rawDDL string, stmt *gorm.Statement) (sql string, sqlArgs []interface{}, err error)) error { +func (m _Migrator) recreateTable( + value interface{}, tablePtr *string, + getCreateSQL func(ddl *ddl, stmt *gorm.Statement) (sql *ddl, sqlArgs []interface{}, err error), +) error { return m.RunWithValue(value, func(stmt *gorm.Statement) error { table := stmt.Table if tablePtr != nil { @@ -388,27 +364,26 @@ func (m Migrator) recreateTable(value interface{}, tablePtr *string, return err } - newTableName := table + "__temp" - - createSQL, sqlArgs, err := getCreateSQL(rawDDL, stmt) + originDDL, err := parseDDL(rawDDL) if err != nil { return err } - if createSQL == "" { + + createDDL, sqlArgs, err := getCreateSQL(originDDL.clone(), stmt) + if err != nil { + return err + } + if createDDL == nil { return nil } - tableReg, err := regexp.Compile("\\s*('|`|\")?\\b" + table + "\\b('|`|\")?\\s*") - if err != nil { + newTableName := table + "__temp" + if err := createDDL.renameTable(newTableName, table); err != nil { return err } - createSQL = tableReg.ReplaceAllString(createSQL, fmt.Sprintf(" `%v` ", newTableName)) - createDDL, err := parseDDL(createSQL) - if err != nil { - return err - } columns := createDDL.getColumns() + createSQL := createDDL.compile() return m.DB.Transaction(func(tx *gorm.DB) error { if err := tx.Exec(createSQL, sqlArgs...).Error; err != nil { diff --git a/gormlite/sqlite.go b/gormlite/sqlite.go index 76433d4..db1eb3c 100644 --- a/gormlite/sqlite.go +++ b/gormlite/sqlite.go @@ -3,6 +3,7 @@ package gormlite import ( "context" + "database/sql" "strconv" "gorm.io/gorm" @@ -15,20 +16,26 @@ import ( "github.com/ncruces/go-sqlite3/driver" ) -type Dialector struct { +// Open opens a GORM dialector from a data source name. +func Open(dsn string) gorm.Dialector { + return &_Dialector{DSN: dsn} +} + +// Open opens a GORM dialector from a database handle. +func OpenDB(db *sql.DB) gorm.Dialector { + return &_Dialector{Conn: db} +} + +type _Dialector struct { DSN string Conn gorm.ConnPool } -func Open(dsn string) gorm.Dialector { - return &Dialector{DSN: dsn} -} - -func (dialector Dialector) Name() string { +func (dialector _Dialector) Name() string { return "sqlite" } -func (dialector Dialector) Initialize(db *gorm.DB) (err error) { +func (dialector _Dialector) Initialize(db *gorm.DB) (err error) { if dialector.Conn != nil { db.ConnPool = dialector.Conn } else { @@ -47,7 +54,7 @@ func (dialector Dialector) Initialize(db *gorm.DB) (err error) { if compareVersion(version, "3.35.0") >= 0 { callbacks.RegisterDefaultCallbacks(db, &callbacks.Config{ CreateClauses: []string{"INSERT", "VALUES", "ON CONFLICT", "RETURNING"}, - UpdateClauses: []string{"UPDATE", "SET", "WHERE", "RETURNING"}, + UpdateClauses: []string{"UPDATE", "SET", "FROM", "WHERE", "RETURNING"}, DeleteClauses: []string{"DELETE", "FROM", "WHERE", "RETURNING"}, LastInsertIDReversed: true, }) @@ -63,7 +70,7 @@ func (dialector Dialector) Initialize(db *gorm.DB) (err error) { return } -func (dialector Dialector) ClauseBuilders() map[string]clause.ClauseBuilder { +func (dialector _Dialector) ClauseBuilders() map[string]clause.ClauseBuilder { return map[string]clause.ClauseBuilder{ "INSERT": func(c clause.Clause, builder clause.Builder) { if insert, ok := c.Expression.(clause.Insert); ok { @@ -112,7 +119,7 @@ func (dialector Dialector) ClauseBuilders() map[string]clause.ClauseBuilder { } } -func (dialector Dialector) DefaultValueOf(field *schema.Field) clause.Expression { +func (dialector _Dialector) DefaultValueOf(field *schema.Field) clause.Expression { if field.AutoIncrement { return clause.Expr{SQL: "NULL"} } @@ -121,19 +128,19 @@ func (dialector Dialector) DefaultValueOf(field *schema.Field) clause.Expression return clause.Expr{SQL: "DEFAULT"} } -func (dialector Dialector) Migrator(db *gorm.DB) gorm.Migrator { - return Migrator{migrator.Migrator{Config: migrator.Config{ +func (dialector _Dialector) Migrator(db *gorm.DB) gorm.Migrator { + return _Migrator{migrator.Migrator{Config: migrator.Config{ DB: db, Dialector: dialector, CreateIndexAfterCreateTable: true, }}} } -func (dialector Dialector) BindVarTo(writer clause.Writer, stmt *gorm.Statement, v interface{}) { +func (dialector _Dialector) BindVarTo(writer clause.Writer, stmt *gorm.Statement, v interface{}) { writer.WriteByte('?') } -func (dialector Dialector) QuoteTo(writer clause.Writer, str string) { +func (dialector _Dialector) QuoteTo(writer clause.Writer, str string) { var ( underQuoted, selfQuoted bool continuousBacktick int8 @@ -181,16 +188,17 @@ func (dialector Dialector) QuoteTo(writer clause.Writer, str string) { writer.WriteString("`") } -func (dialector Dialector) Explain(sql string, vars ...interface{}) string { +func (dialector _Dialector) Explain(sql string, vars ...interface{}) string { return logger.ExplainSQL(sql, nil, `"`, vars...) } -func (dialector Dialector) DataTypeOf(field *schema.Field) string { +func (dialector _Dialector) DataTypeOf(field *schema.Field) string { switch field.DataType { case schema.Bool: return "numeric" case schema.Int, schema.Uint: - if field.AutoIncrement && !field.PrimaryKey { + if field.AutoIncrement { + // doesn't check `PrimaryKey`, to keep backward compatibility // https://www.sqlite.org/autoinc.html return "integer PRIMARY KEY AUTOINCREMENT" } else { @@ -214,12 +222,12 @@ func (dialector Dialector) DataTypeOf(field *schema.Field) string { return string(field.DataType) } -func (dialectopr Dialector) SavePoint(tx *gorm.DB, name string) error { +func (dialectopr _Dialector) SavePoint(tx *gorm.DB, name string) error { tx.Exec("SAVEPOINT " + name) return nil } -func (dialectopr Dialector) RollbackTo(tx *gorm.DB, name string) error { +func (dialectopr _Dialector) RollbackTo(tx *gorm.DB, name string) error { tx.Exec("ROLLBACK TO SAVEPOINT " + name) return nil } diff --git a/gormlite/sqlite_test.go b/gormlite/sqlite_test.go index f2ab58a..9e724ed 100644 --- a/gormlite/sqlite_test.go +++ b/gormlite/sqlite_test.go @@ -17,7 +17,7 @@ func TestDialector(t *testing.T) { const InMemoryDSN = "file:testdatabase?mode=memory&cache=shared" // Custom connection with a custom function called "my_custom_function". - conn, err := driver.Open(InMemoryDSN, func(ctx context.Context, conn *sqlite3.Conn) error { + db, err := driver.Open(InMemoryDSN, func(ctx context.Context, conn *sqlite3.Conn) error { return conn.CreateFunction("my_custom_function", 0, sqlite3.DETERMINISTIC, func(ctx sqlite3.Context, arg ...sqlite3.Value) { ctx.ResultText("my-result") @@ -29,43 +29,35 @@ func TestDialector(t *testing.T) { rows := []struct { description string - dialector *Dialector + dialector gorm.Dialector openSuccess bool query string querySuccess bool }{ { - description: "Default driver", - dialector: &Dialector{ - DSN: InMemoryDSN, - }, + description: "Default driver", + dialector: Open(InMemoryDSN), openSuccess: true, query: "SELECT 1", querySuccess: true, }, { - description: "Custom function", - dialector: &Dialector{ - DSN: InMemoryDSN, - }, + description: "Custom function", + dialector: Open(InMemoryDSN), openSuccess: true, query: "SELECT my_custom_function()", querySuccess: false, }, { - description: "Custom connection", - dialector: &Dialector{ - Conn: conn, - }, + description: "Custom connection", + dialector: OpenDB(db), openSuccess: true, query: "SELECT 1", querySuccess: true, }, { - description: "Custom connection, custom function", - dialector: &Dialector{ - Conn: conn, - }, + description: "Custom connection, custom function", + dialector: OpenDB(db), openSuccess: true, query: "SELECT my_custom_function()", querySuccess: true,