Платформа ЦРНП "Мирокод" для разработки проектов
https://git.mirocod.ru
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
110 lines
2.3 KiB
110 lines
2.3 KiB
package testfixtures |
|
|
|
import ( |
|
"database/sql" |
|
"fmt" |
|
) |
|
|
|
// SQLServer is the helper for SQL Server for this package. |
|
// SQL Server >= 2008 is required. |
|
type SQLServer struct { |
|
baseHelper |
|
|
|
tables []string |
|
} |
|
|
|
func (h *SQLServer) init(db *sql.DB) error { |
|
var err error |
|
|
|
h.tables, err = h.getTables(db) |
|
if err != nil { |
|
return err |
|
} |
|
|
|
return nil |
|
} |
|
|
|
func (*SQLServer) paramType() int { |
|
return paramTypeQuestion |
|
} |
|
|
|
func (*SQLServer) quoteKeyword(str string) string { |
|
return fmt.Sprintf("[%s]", str) |
|
} |
|
|
|
func (*SQLServer) databaseName(db *sql.DB) (dbname string) { |
|
db.QueryRow("SELECT DB_NAME()").Scan(&dbname) |
|
return |
|
} |
|
|
|
func (*SQLServer) getTables(db *sql.DB) ([]string, error) { |
|
rows, err := db.Query("SELECT table_name FROM information_schema.tables") |
|
if err != nil { |
|
return nil, err |
|
} |
|
|
|
tables := make([]string, 0) |
|
defer rows.Close() |
|
for rows.Next() { |
|
var table string |
|
rows.Scan(&table) |
|
tables = append(tables, table) |
|
} |
|
return tables, nil |
|
} |
|
|
|
func (*SQLServer) tableHasIdentityColumn(tx *sql.Tx, tableName string) bool { |
|
sql := ` |
|
SELECT COUNT(*) |
|
FROM SYS.IDENTITY_COLUMNS |
|
WHERE OBJECT_NAME(OBJECT_ID) = ? |
|
` |
|
var count int |
|
tx.QueryRow(sql, tableName).Scan(&count) |
|
return count > 0 |
|
|
|
} |
|
|
|
func (h *SQLServer) whileInsertOnTable(tx *sql.Tx, tableName string, fn func() error) error { |
|
if h.tableHasIdentityColumn(tx, tableName) { |
|
defer tx.Exec(fmt.Sprintf("SET IDENTITY_INSERT %s OFF", h.quoteKeyword(tableName))) |
|
_, err := tx.Exec(fmt.Sprintf("SET IDENTITY_INSERT %s ON", h.quoteKeyword(tableName))) |
|
if err != nil { |
|
return err |
|
} |
|
} |
|
return fn() |
|
} |
|
|
|
func (h *SQLServer) disableReferentialIntegrity(db *sql.DB, loadFn loadFunction) error { |
|
// ensure the triggers are re-enable after all |
|
defer func() { |
|
sql := "" |
|
for _, table := range h.tables { |
|
sql += fmt.Sprintf("ALTER TABLE %s WITH CHECK CHECK CONSTRAINT ALL;", h.quoteKeyword(table)) |
|
} |
|
if _, err := db.Exec(sql); err != nil { |
|
fmt.Printf("Error on re-enabling constraints: %v\n", err) |
|
} |
|
}() |
|
|
|
sql := "" |
|
for _, table := range h.tables { |
|
sql += fmt.Sprintf("ALTER TABLE %s NOCHECK CONSTRAINT ALL;", h.quoteKeyword(table)) |
|
} |
|
if _, err := db.Exec(sql); err != nil { |
|
return err |
|
} |
|
|
|
tx, err := db.Begin() |
|
if err != nil { |
|
return err |
|
} |
|
|
|
if err = loadFn(tx); err != nil { |
|
tx.Rollback() |
|
return err |
|
} |
|
|
|
return tx.Commit() |
|
}
|
|
|