Skip to content

Commit

Permalink
Merge pull request #50 from cloud-barista/feature/update-mysql
Browse files Browse the repository at this point in the history
update: create DB SQL  for target Provider DB Service
  • Loading branch information
heedaeshin authored Sep 9, 2024
2 parents 8c8ae21 + e439351 commit 39a4990
Show file tree
Hide file tree
Showing 5 changed files with 223 additions and 2,629 deletions.
81 changes: 71 additions & 10 deletions pkg/rdbms/mysql/mysql.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,30 @@ import (

// mysqlDBMS struct
type MysqlDBMS struct {
provider models.Provider
db *sql.DB
ctx context.Context
provider models.Provider
db *sql.DB
tartgetProvider models.Provider
ctx context.Context
}

type MysqlDBOption func(*MysqlDBMS)

func (d *MysqlDBMS) GetProvdier() models.Provider {
return d.provider
}

func (d *MysqlDBMS) SetProvdier(provider models.Provider) {
d.provider = provider
}

func (d *MysqlDBMS) GetTargetProvdier() models.Provider {
return d.tartgetProvider
}

func (d *MysqlDBMS) SetTargetProvdier(provider models.Provider) {
d.tartgetProvider = provider
}

func New(provider models.Provider, sqlDB *sql.DB, opts ...MysqlDBOption) *MysqlDBMS {
dms := &MysqlDBMS{
provider: provider,
Expand Down Expand Up @@ -143,6 +160,12 @@ func (d *MysqlDBMS) ShowCreateDBSql(dbName string, dbCreateSql *string) error {
*dbCreateSql = addCollateIfMissing(*dbCreateSql)
*dbCreateSql = EnsureCharsetAndCollate(*dbCreateSql, extractCharacterSet(*dbCreateSql), extractCollation(*dbCreateSql))

// If the target provider is NCP, modify the SQL to use NCP's specific procedure
if d.tartgetProvider == models.NCP {
dbName, charSet, collate := extractDatabaseInfo(*dbCreateSql)
*dbCreateSql = fmt.Sprintf("CALL sys.ncp_create_db('%s', '%s', '%s');", dbName, charSet, collate)
}

return nil
}

Expand All @@ -154,6 +177,8 @@ func (d *MysqlDBMS) ShowCreateTableSql(dbName, tableName string, tableCreateSql
if err := d.db.QueryRow(fmt.Sprintf("SHOW CREATE TABLE %s;", tableName)).Scan(&tableName, tableCreateSql); err != nil {
return err
}
*tableCreateSql = removeSequenceOption(*tableCreateSql)
*tableCreateSql = adjustColumnsToTimestamp(*tableCreateSql)
*tableCreateSql = ReplaceCharsetAndCollate(*tableCreateSql)
return nil
}
Expand Down Expand Up @@ -182,10 +207,10 @@ func (d *MysqlDBMS) GetInsert(dbName, tableName string, insertSql *[]string) err
}
defer selRows.Close()

data := []map[string]string{}
data := []map[string]sql.NullString{}

for selRows.Next() {
values := make([]string, len(columns))
values := make([]sql.NullString, len(columns))
valuePtrs := make([]interface{}, len(columns))
for i := range columns {
valuePtrs[i] = &values[i]
Expand All @@ -196,22 +221,30 @@ func (d *MysqlDBMS) GetInsert(dbName, tableName string, insertSql *[]string) err
return err
}

entry := make(map[string]string)
entry := make(map[string]sql.NullString)
for i, column := range columns {
val := values[i]
entry[column] = val
entry[column] = values[i]
}

data = append(data, entry)
}

for _, entry := range data {
values := []string{}
escapedColumns := []string{}
for _, column := range columns {
values = append(values, fmt.Sprintf("'%v'", entry[column]))
escapedColumn := escapeColumnName(column)
escapedColumns = append(escapedColumns, escapedColumn)
val := entry[column]
if val.Valid {
escapedValue := ReplaceEscapeString(val.String)
values = append(values, fmt.Sprintf("'%v'", escapedValue))
} else {
values = append(values, "NULL")
}
}

insertStatement := fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s);", tableName, strings.Join(columns, ", "), strings.Join(values, ", "))
insertStatement := fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s);", tableName, strings.Join(escapedColumns, ", "), strings.Join(values, ", "))
*insertSql = append(*insertSql, insertStatement)
}

Expand Down Expand Up @@ -247,6 +280,24 @@ func ReplaceCharsetAndCollate(sql string) string {
return sql
}

func ReplaceEscapeString(input string) string {
return strings.ReplaceAll(input, "'", "''")
}

func adjustColumnsToTimestamp(sql string) string {
// Use a regular expression to find all columns that use DEFAULT current_timestamp()
re := regexp.MustCompile("`[^`]+`\\s+[^,]+DEFAULT\\s+current_timestamp\\(\\)")

// Replace these columns with TIMESTAMP DEFAULT current_timestamp()
modifiedSQL := re.ReplaceAllStringFunc(sql, func(match string) string {
// Retain the column name and change the rest of the definition to TIMESTAMP
columnName := strings.Split(match, " ")[0] // The first element is the column name
return fmt.Sprintf("%s TIMESTAMP DEFAULT current_timestamp()", columnName)
})

return modifiedSQL
}

// Extract database information
func extractDatabaseInfo(sql string) (string, string, string) {
dbName := extractDatabaseName(sql)
Expand Down Expand Up @@ -284,3 +335,13 @@ func extractCollation(sql string) string {
}
return ""
}

// remove Sequence
func removeSequenceOption(sql string) string {
return strings.Replace(sql, " SEQUENCE=1", "", -1)
}

// escape Reserve Word
func escapeColumnName(columnName string) string {
return fmt.Sprintf("`%s`", columnName)
}
6 changes: 6 additions & 0 deletions service/rdbc/rdbc.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"fmt"
"strings"

"github.com/cloud-barista/mc-data-manager/models"
"github.com/sirupsen/logrus"
)

Expand All @@ -33,6 +34,10 @@ const (
//
// Configure the interface to make it easier for other DBs to apply in the future
type RDBMS interface {
GetProvdier() models.Provider
SetProvdier(provider models.Provider)
GetTargetProvdier() models.Provider
SetTargetProvdier(provider models.Provider)
Exec(query string) error
ListDB(dst *[]string) error
DeleteDB(dbName string) error
Expand Down Expand Up @@ -131,6 +136,7 @@ func (rdb *RDBController) Copy(dst *RDBController) error {

for _, db := range dbList {
sql = ""
rdb.client.SetTargetProvdier(dst.client.GetProvdier())
if err := rdb.Get(db, &sql); err != nil {
rdb.logWrite("Error", "Get error", err)
return err
Expand Down
Loading

0 comments on commit 39a4990

Please sign in to comment.