diff --git a/go/base/context.go b/go/base/context.go index 3cd2ce611..dc7d17b2f 100644 --- a/go/base/context.go +++ b/go/base/context.go @@ -104,6 +104,10 @@ type MigrationContext struct { AzureMySQL bool AttemptInstantDDL bool + // MaxAuthFailures is the maximum number of authentication failures before aborting + // This prevents retry storms that can trigger firewall rules + MaxAuthFailures int + // SkipPortValidation allows skipping the port validation in `ValidateConnection` // This is useful when connecting to a MySQL instance where the external port // may not match the internal port. @@ -360,6 +364,16 @@ func (this *MigrationContext) GetOldTableName() string { return getSafeTableName(tableName, "del") } +// GetGhostDatabaseName returns the database name for ghost/changelog tables +// If GhostDatabaseName is set (for separate schema), use it +// Otherwise, use the same database as the original table +func (this *MigrationContext) GetGhostDatabaseName() string { + if this.GhostDatabaseName != "" { + return this.GhostDatabaseName + } + return this.DatabaseName +} + // GetChangelogTableName generates the name of changelog table, based on original table name // or a given table name. func (this *MigrationContext) GetChangelogTableName() string { diff --git a/go/binlog/gomysql_reader.go b/go/binlog/gomysql_reader.go index d42ba1f30..e2705f05c 100644 --- a/go/binlog/gomysql_reader.go +++ b/go/binlog/gomysql_reader.go @@ -7,11 +7,13 @@ package binlog import ( "fmt" + "strings" "sync" "github.com/github/gh-ost/go/base" "github.com/github/gh-ost/go/mysql" "github.com/github/gh-ost/go/sql" + "github.com/pkg/errors" "time" @@ -28,6 +30,7 @@ type GoMySQLReader struct { currentCoordinates mysql.BinlogCoordinates currentCoordinatesMutex *sync.Mutex LastAppliedRowsEventHint mysql.BinlogCoordinates + authFailureCount int } func NewGoMySQLReader(migrationContext *base.MigrationContext) *GoMySQLReader { @@ -52,6 +55,36 @@ func NewGoMySQLReader(migrationContext *base.MigrationContext) *GoMySQLReader { } } +// handleAuthError processes authentication errors and applies circuit breaker logic +func (this *GoMySQLReader) handleAuthError(err error, context string) error { + if err == nil { + // Success case - reset counter if needed + if this.authFailureCount > 0 { + this.migrationContext.Log.Infof("%s successful, resetting auth failure count from %d to 0", context, this.authFailureCount) + this.authFailureCount = 0 + } + return nil + } + + // Check if this is an authentication error + if !this.isAuthenticationError(err) { + return err // Not an auth error, return as-is + } + + // Authentication error - increment counter and check circuit breaker + this.authFailureCount++ + + if this.migrationContext.MaxAuthFailures > 0 && this.authFailureCount >= this.migrationContext.MaxAuthFailures { + return fmt.Errorf("authentication failed %d times (max: %d) during %s, aborting to prevent firewall blocking: %w", + this.authFailureCount, this.migrationContext.MaxAuthFailures, context, err) + } + + this.migrationContext.Log.Errorf("Authentication failure #%d during %s (max: %d): %v", + this.authFailureCount, context, this.migrationContext.MaxAuthFailures, err) + + return err +} + // ConnectBinlogStreamer func (this *GoMySQLReader) ConnectBinlogStreamer(coordinates mysql.BinlogCoordinates) (err error) { if coordinates.IsEmpty() { @@ -66,7 +99,8 @@ func (this *GoMySQLReader) ConnectBinlogStreamer(coordinates mysql.BinlogCoordin Pos: uint32(this.currentCoordinates.LogPos), }) - return err + // Handle the error (or success) with circuit breaker logic + return this.handleAuthError(err, "connection") } func (this *GoMySQLReader) GetCurrentBinlogCoordinates() *mysql.BinlogCoordinates { @@ -79,7 +113,7 @@ func (this *GoMySQLReader) GetCurrentBinlogCoordinates() *mysql.BinlogCoordinate // StreamEvents func (this *GoMySQLReader) handleRowsEvent(ev *replication.BinlogEvent, rowsEvent *replication.RowsEvent, entriesChannel chan<- *BinlogEntry) error { if this.currentCoordinates.IsLogPosOverflowBeyond4Bytes(&this.LastAppliedRowsEventHint) { - return fmt.Errorf("Unexpected rows event at %+v, the binlog end_log_pos is overflow 4 bytes", this.currentCoordinates) + return fmt.Errorf("unexpected rows event at %+v, the binlog end_log_pos is overflow 4 bytes", this.currentCoordinates) } if this.currentCoordinates.SmallerThanOrEquals(&this.LastAppliedRowsEventHint) { @@ -89,7 +123,7 @@ func (this *GoMySQLReader) handleRowsEvent(ev *replication.BinlogEvent, rowsEven dml := ToEventDML(ev.Header.EventType.String()) if dml == NotDML { - return fmt.Errorf("Unknown DML type: %s", ev.Header.EventType.String()) + return fmt.Errorf("unknown DML type: %s", ev.Header.EventType.String()) } for i, row := range rowsEvent.Rows { if dml == UpdateDML && i%2 == 1 { @@ -133,14 +167,16 @@ func (this *GoMySQLReader) StreamEvents(canStopStreaming func() bool, entriesCha if canStopStreaming() { return nil } - for { - if canStopStreaming() { - break - } + for !canStopStreaming() { ev, err := this.binlogStreamer.GetEvent(context.Background()) if err != nil { - return err + // Handle authentication errors with circuit breaker + return this.handleAuthError(err, "streaming") } + + // Reset counter on successful event (using handleAuthError with nil) + this.handleAuthError(nil, "event retrieval") + func() { this.currentCoordinatesMutex.Lock() defer this.currentCoordinatesMutex.Unlock() @@ -171,3 +207,38 @@ func (this *GoMySQLReader) Close() error { this.binlogSyncer.Close() return nil } + +// MySQL error codes for authentication failures +const ( + ER_DBACCESS_DENIED_ERROR = 1044 // Access denied for user to database + ER_ACCESS_DENIED_ERROR = 1045 // Access denied for user (using password: YES/NO) + ER_HOST_NOT_ALLOWED = 1130 // Host is not allowed to connect + ER_ACCESS_DENIED_NO_PASSWORD = 1698 // Access denied (no password provided) + ER_ACCOUNT_HAS_BEEN_LOCKED = 3118 // Account has been locked +) + +// isAuthenticationError checks if the error is an authentication failure +func (this *GoMySQLReader) isAuthenticationError(err error) bool { + if err == nil { + return false + } + + // Check for MySQL protocol errors using proper type assertion + var myErr *gomysql.MyError + if errors.As(err, &myErr) { + switch myErr.Code { + case ER_ACCESS_DENIED_ERROR, + ER_DBACCESS_DENIED_ERROR, + ER_HOST_NOT_ALLOWED, + ER_ACCESS_DENIED_NO_PASSWORD, + ER_ACCOUNT_HAS_BEEN_LOCKED: + return true + } + } + + // Fallback: Check error string for compatibility with errors + // that might not be properly typed (e.g., from proxy or older versions) + errStr := strings.ToLower(err.Error()) + return strings.Contains(errStr, "access denied") || + strings.Contains(errStr, "authentication failed") +} diff --git a/go/binlog/gomysql_reader_test.go b/go/binlog/gomysql_reader_test.go new file mode 100644 index 000000000..9e5881810 --- /dev/null +++ b/go/binlog/gomysql_reader_test.go @@ -0,0 +1,313 @@ +package binlog + +import ( + "errors" + "fmt" + "testing" + + "github.com/github/gh-ost/go/base" + "github.com/github/gh-ost/go/mysql" + gomysql "github.com/go-mysql-org/go-mysql/mysql" + "github.com/stretchr/testify/require" +) + +func TestIsAuthenticationError(t *testing.T) { + migrationContext := base.NewMigrationContext() + reader := &GoMySQLReader{ + migrationContext: migrationContext, + } + + tests := []struct { + name string + err error + expected bool + }{ + { + name: "nil error", + err: nil, + expected: false, + }, + { + name: "MySQL 1045 typed error", + err: &gomysql.MyError{Code: 1045, Message: "Access denied for user 'bytebase'@'10.20.5.203'"}, + expected: true, + }, + { + name: "MySQL 1130 typed error", + err: &gomysql.MyError{Code: 1130, Message: "Host '10.20.5.203' is not allowed to connect to this MySQL server"}, + expected: true, + }, + { + name: "MySQL 1044 typed error", + err: &gomysql.MyError{Code: 1044, Message: "Access denied for user 'bytebase'@'%' to database 'mysql'"}, + expected: true, + }, + { + name: "MySQL 1698 typed error", + err: &gomysql.MyError{Code: 1698, Message: "Access denied for user 'root'@'localhost'"}, + expected: true, + }, + { + name: "MySQL 3118 account locked", + err: &gomysql.MyError{Code: 3118, Message: "Account has been locked"}, + expected: true, + }, + { + name: "Wrapped MySQL error", + err: fmt.Errorf("connection failed: %w", &gomysql.MyError{Code: 1045, Message: "Access denied"}), + expected: true, + }, + { + name: "String fallback - access denied", + err: errors.New("Access denied for user attempting to connect"), + expected: true, + }, + { + name: "String fallback - authentication failed", + err: errors.New("authentication failed for user"), + expected: true, + }, + { + name: "Non-auth MySQL error", + err: &gomysql.MyError{Code: 1146, Message: "Table doesn't exist"}, + expected: false, + }, + { + name: "unrelated error", + err: errors.New("connection timeout"), + expected: false, + }, + { + name: "network error", + err: errors.New("dial tcp: connection refused"), + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := reader.isAuthenticationError(tt.err) + require.Equal(t, tt.expected, result) + }) + } +} + +func TestAuthFailureCircuitBreaker(t *testing.T) { + tests := []struct { + name string + maxAuthFailures int + authFailures int + expectError bool + }{ + { + name: "no limit set", + maxAuthFailures: 0, + authFailures: 100, + expectError: false, // No limit means no circuit breaker + }, + { + name: "under limit", + maxAuthFailures: 5, + authFailures: 3, + expectError: false, + }, + { + name: "at limit", + maxAuthFailures: 5, + authFailures: 5, + expectError: true, + }, + { + name: "over limit", + maxAuthFailures: 5, + authFailures: 10, + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + migrationContext := base.NewMigrationContext() + migrationContext.MaxAuthFailures = tt.maxAuthFailures + + connectionConfig := &mysql.ConnectionConfig{ + Key: mysql.InstanceKey{ + Hostname: "test-host", + Port: 3306, + }, + User: "test-user", + Password: "test-password", + } + + reader := &GoMySQLReader{ + migrationContext: migrationContext, + connectionConfig: connectionConfig, + authFailureCount: tt.authFailures - 1, // Simulate previous failures + } + + // Simulate an authentication error + authErr := errors.New("ERROR 1045 (28000): Access denied for user") + + // Check if circuit breaker triggers + if reader.isAuthenticationError(authErr) { + reader.authFailureCount++ + if reader.migrationContext.MaxAuthFailures > 0 && reader.authFailureCount >= reader.migrationContext.MaxAuthFailures { + if !tt.expectError { + t.Errorf("Expected no error but circuit breaker triggered at %d failures", reader.authFailureCount) + } + } else { + if tt.expectError { + t.Errorf("Expected circuit breaker to trigger at %d failures but it didn't", reader.authFailureCount) + } + } + } + }) + } +} + +func TestAuthFailureCounterIncrement(t *testing.T) { + migrationContext := base.NewMigrationContext() + migrationContext.MaxAuthFailures = 10 + + reader := &GoMySQLReader{ + migrationContext: migrationContext, + authFailureCount: 0, + } + + // Test that counter increments only for auth errors + testCases := []struct { + err error + shouldCount bool + description string + }{ + {&gomysql.MyError{Code: 1045, Message: "Access denied"}, true, "MySQL 1045 error"}, + {errors.New("connection timeout"), false, "Non-auth error"}, + {&gomysql.MyError{Code: 1130, Message: "Host not allowed"}, true, "MySQL 1130 error"}, + {errors.New("syntax error"), false, "SQL syntax error"}, + {nil, false, "Nil error"}, + {errors.New("access denied for user"), true, "String fallback auth error"}, + } + + for _, tc := range testCases { + initialCount := reader.authFailureCount + + // For nil errors, handleAuthError would reset the counter + // So we test isAuthenticationError directly for nil + if tc.err == nil { + if reader.isAuthenticationError(tc.err) { + t.Errorf("%s: nil should not be detected as auth error", tc.description) + } + continue + } + + // Use handleAuthError which manages the counter + reader.handleAuthError(tc.err, "test") + + if tc.shouldCount { + if reader.authFailureCount != initialCount+1 { + t.Errorf("%s: Counter did not increment for auth error: %v", tc.description, tc.err) + } + } else { + if reader.authFailureCount != initialCount { + t.Errorf("%s: Counter incorrectly incremented for non-auth error: %v", tc.description, tc.err) + } + } + } + + // Test that successful operation resets counter + reader.authFailureCount = 5 + reader.handleAuthError(nil, "test success") + if reader.authFailureCount != 0 { + t.Errorf("Counter not reset on success, got %d", reader.authFailureCount) + } +} + +func TestAuthFailureCounterReset(t *testing.T) { + migrationContext := base.NewMigrationContext() + migrationContext.MaxAuthFailures = 10 + + reader := &GoMySQLReader{ + migrationContext: migrationContext, + authFailureCount: 0, + } + + // Simulate auth failures + authError := errors.New("ERROR 1045: Access denied") + for i := 0; i < 3; i++ { + if reader.isAuthenticationError(authError) { + reader.authFailureCount++ + } + } + + if reader.authFailureCount != 3 { + t.Errorf("Expected auth failure count 3, got %d", reader.authFailureCount) + } + + // Simulate successful connection - should reset counter + // In real code, this happens in ConnectBinlogStreamer on success + reader.authFailureCount = 0 + + if reader.authFailureCount != 0 { + t.Errorf("Expected auth failure count to be reset to 0, got %d", reader.authFailureCount) + } + + // Simulate more failures after reset + for i := 0; i < 2; i++ { + if reader.isAuthenticationError(authError) { + reader.authFailureCount++ + } + } + + if reader.authFailureCount != 2 { + t.Errorf("Expected auth failure count 2 after reset, got %d", reader.authFailureCount) + } +} + +func TestAuthFailureRecoveryScenario(t *testing.T) { + // Test a realistic scenario: + // 1. Some auth failures + // 2. Successful connection (counter reset) + // 3. More auth failures + // 4. Should only trigger circuit breaker based on consecutive failures + + migrationContext := base.NewMigrationContext() + migrationContext.MaxAuthFailures = 5 + + reader := &GoMySQLReader{ + migrationContext: migrationContext, + authFailureCount: 0, + } + + authError := errors.New("ERROR 1045: Access denied") + + // First round: 3 failures + for i := 0; i < 3; i++ { + if reader.isAuthenticationError(authError) { + reader.authFailureCount++ + } + } + require.Equal(t, 3, reader.authFailureCount, "Should have 3 failures") + + // Successful connection - reset + reader.authFailureCount = 0 + require.Equal(t, 0, reader.authFailureCount, "Should reset to 0 after success") + + // Second round: 4 more failures (under limit) + for i := 0; i < 4; i++ { + if reader.isAuthenticationError(authError) { + reader.authFailureCount++ + } + } + require.Equal(t, 4, reader.authFailureCount, "Should have 4 failures after reset") + + // Circuit breaker should not trigger yet (4 < 5) + shouldTrigger := reader.migrationContext.MaxAuthFailures > 0 && + reader.authFailureCount >= reader.migrationContext.MaxAuthFailures + require.False(t, shouldTrigger, "Circuit breaker should not trigger at 4 failures with limit 5") + + // One more failure should trigger + reader.authFailureCount++ + shouldTrigger = reader.migrationContext.MaxAuthFailures > 0 && + reader.authFailureCount >= reader.migrationContext.MaxAuthFailures + require.True(t, shouldTrigger, "Circuit breaker should trigger at 5 failures with limit 5") +} diff --git a/go/logic/applier.go b/go/logic/applier.go index d15122007..46e5fc700 100644 --- a/go/logic/applier.go +++ b/go/logic/applier.go @@ -116,7 +116,7 @@ func (this *Applier) InitDBConnections() (err error) { func (this *Applier) prepareQueries() (err error) { if this.dmlDeleteQueryBuilder, err = sql.NewDMLDeleteQueryBuilder( - this.migrationContext.GhostDatabaseName, + this.migrationContext.GetGhostDatabaseName(), this.migrationContext.GetGhostTableName(), this.migrationContext.OriginalTableColumns, &this.migrationContext.UniqueKey.Columns, @@ -124,7 +124,7 @@ func (this *Applier) prepareQueries() (err error) { return err } if this.dmlInsertQueryBuilder, err = sql.NewDMLInsertQueryBuilder( - this.migrationContext.GhostDatabaseName, + this.migrationContext.GetGhostDatabaseName(), this.migrationContext.GetGhostTableName(), this.migrationContext.OriginalTableColumns, this.migrationContext.SharedColumns, @@ -133,7 +133,7 @@ func (this *Applier) prepareQueries() (err error) { return err } if this.dmlUpdateQueryBuilder, err = sql.NewDMLUpdateQueryBuilder( - this.migrationContext.GhostDatabaseName, + this.migrationContext.GetGhostDatabaseName(), this.migrationContext.GetGhostTableName(), this.migrationContext.OriginalTableColumns, this.migrationContext.SharedColumns, @@ -220,8 +220,8 @@ func (this *Applier) ValidateOrDropExistingTables() error { return err } } - if this.tableExists(this.migrationContext.GhostDatabaseName, this.migrationContext.GetGhostTableName()) { - return fmt.Errorf("Table %s.%s already exists. Panicking. Use --initially-drop-ghost-table to force dropping it, though I really prefer that you drop it or rename it away", sql.EscapeName(this.migrationContext.GhostDatabaseName), sql.EscapeName(this.migrationContext.GetGhostTableName())) + if this.tableExists(this.migrationContext.GetGhostDatabaseName(), this.migrationContext.GetGhostTableName()) { + return fmt.Errorf("Table %s.%s already exists. Panicking. Use --initially-drop-ghost-table to force dropping it, though I really prefer that you drop it or rename it away", sql.EscapeName(this.migrationContext.GetGhostDatabaseName()), sql.EscapeName(this.migrationContext.GetGhostTableName())) } if this.migrationContext.InitiallyDropOldTable { if err := this.DropOldTable(); err != nil { @@ -232,8 +232,8 @@ func (this *Applier) ValidateOrDropExistingTables() error { this.migrationContext.Log.Fatalf("--timestamp-old-table defined, but resulting table name (%s) is too long (only %d characters allowed)", this.migrationContext.GetOldTableName(), mysql.MaxTableNameLength) } - if this.tableExists(this.migrationContext.GhostDatabaseName, this.migrationContext.GetOldTableName()) { - return fmt.Errorf("Table %s.%s already exists. Panicking. Use --initially-drop-old-table to force dropping it, though I really prefer that you drop it or rename it away", sql.EscapeName(this.migrationContext.GhostDatabaseName), sql.EscapeName(this.migrationContext.GetOldTableName())) + if this.tableExists(this.migrationContext.GetGhostDatabaseName(), this.migrationContext.GetOldTableName()) { + return fmt.Errorf("Table %s.%s already exists. Panicking. Use --initially-drop-old-table to force dropping it, though I really prefer that you drop it or rename it away", sql.EscapeName(this.migrationContext.GetGhostDatabaseName()), sql.EscapeName(this.migrationContext.GetOldTableName())) } return nil @@ -272,13 +272,13 @@ func (this *Applier) AttemptInstantDDL() error { // CreateGhostTable creates the ghost table on the applier host func (this *Applier) CreateGhostTable() error { query := fmt.Sprintf(`create /* gh-ost */ table %s.%s like %s.%s`, - sql.EscapeName(this.migrationContext.GhostDatabaseName), + sql.EscapeName(this.migrationContext.GetGhostDatabaseName()), sql.EscapeName(this.migrationContext.GetGhostTableName()), sql.EscapeName(this.migrationContext.DatabaseName), sql.EscapeName(this.migrationContext.OriginalTableName), ) this.migrationContext.Log.Infof("Creating ghost table %s.%s", - sql.EscapeName(this.migrationContext.GhostDatabaseName), + sql.EscapeName(this.migrationContext.GetGhostDatabaseName()), sql.EscapeName(this.migrationContext.GetGhostTableName()), ) @@ -313,12 +313,12 @@ func (this *Applier) CreateGhostTable() error { // AlterGhost applies `alter` statement on ghost table func (this *Applier) AlterGhost() error { query := fmt.Sprintf(`alter /* gh-ost */ table %s.%s %s`, - sql.EscapeName(this.migrationContext.GhostDatabaseName), + sql.EscapeName(this.migrationContext.GetGhostDatabaseName()), sql.EscapeName(this.migrationContext.GetGhostTableName()), this.migrationContext.AlterStatementOptions, ) this.migrationContext.Log.Infof("Altering ghost table %s.%s", - sql.EscapeName(this.migrationContext.GhostDatabaseName), + sql.EscapeName(this.migrationContext.GetGhostDatabaseName()), sql.EscapeName(this.migrationContext.GetGhostTableName()), ) this.migrationContext.Log.Debugf("ALTER statement: %s", query) @@ -354,12 +354,12 @@ func (this *Applier) AlterGhost() error { // AlterGhost applies `alter` statement on ghost table func (this *Applier) AlterGhostAutoIncrement() error { query := fmt.Sprintf(`alter /* gh-ost */ table %s.%s AUTO_INCREMENT=%d`, - sql.EscapeName(this.migrationContext.GhostDatabaseName), + sql.EscapeName(this.migrationContext.GetGhostDatabaseName()), sql.EscapeName(this.migrationContext.GetGhostTableName()), this.migrationContext.OriginalTableAutoIncrement, ) this.migrationContext.Log.Infof("Altering ghost table AUTO_INCREMENT value %s.%s", - sql.EscapeName(this.migrationContext.GhostDatabaseName), + sql.EscapeName(this.migrationContext.GetGhostDatabaseName()), sql.EscapeName(this.migrationContext.GetGhostTableName()), ) this.migrationContext.Log.Debugf("AUTO_INCREMENT ALTER statement: %s", query) @@ -383,12 +383,12 @@ func (this *Applier) CreateChangelogTable() error { primary key(id), unique key hint_uidx(hint) ) auto_increment=256 comment='%s'`, - sql.EscapeName(this.migrationContext.GhostDatabaseName), + sql.EscapeName(this.migrationContext.GetGhostDatabaseName()), sql.EscapeName(this.migrationContext.GetChangelogTableName()), GhostChangelogTableComment, ) this.migrationContext.Log.Infof("Creating changelog table %s.%s", - sql.EscapeName(this.migrationContext.GhostDatabaseName), + sql.EscapeName(this.migrationContext.GetGhostDatabaseName()), sql.EscapeName(this.migrationContext.GetChangelogTableName()), ) if _, err := sqlutils.ExecNoPrepare(this.db, query); err != nil { @@ -400,12 +400,15 @@ func (this *Applier) CreateChangelogTable() error { // dropTable drops a given table on the applied host func (this *Applier) dropTable(tableName string) error { + // Use the helper method to get database name with fallback + databaseName := this.migrationContext.GetGhostDatabaseName() + query := fmt.Sprintf(`drop /* gh-ost */ table if exists %s.%s`, - sql.EscapeName(this.migrationContext.GhostDatabaseName), + sql.EscapeName(databaseName), sql.EscapeName(tableName), ) this.migrationContext.Log.Infof("Dropping table %s.%s", - sql.EscapeName(this.migrationContext.GhostDatabaseName), + sql.EscapeName(databaseName), sql.EscapeName(tableName), ) if _, err := sqlutils.ExecNoPrepare(this.db, query); err != nil { @@ -452,7 +455,7 @@ func (this *Applier) WriteChangelog(hint, value string) (string, error) { on duplicate key update last_update=NOW(), value=VALUES(value)`, - sql.EscapeName(this.migrationContext.GhostDatabaseName), + sql.EscapeName(this.migrationContext.GetGhostDatabaseName()), sql.EscapeName(this.migrationContext.GetChangelogTableName()), ) _, err := sqlutils.ExecNoPrepare(this.db, query, explicitId, hint, value) @@ -670,7 +673,7 @@ func (this *Applier) ApplyIterationInsertQuery() (chunkSize int64, rowsAffected query, explodedArgs, err := sql.BuildRangeInsertPreparedQuery( this.migrationContext.DatabaseName, this.migrationContext.OriginalTableName, - this.migrationContext.GhostDatabaseName, + this.migrationContext.GetGhostDatabaseName(), this.migrationContext.GetGhostTableName(), this.migrationContext.SharedColumns.Names(), this.migrationContext.MappedSharedColumns.Names(), @@ -761,7 +764,7 @@ func (this *Applier) SwapTablesQuickAndBumpy() error { query := fmt.Sprintf(`alter /* gh-ost */ table %s.%s rename %s.%s`, sql.EscapeName(this.migrationContext.DatabaseName), sql.EscapeName(this.migrationContext.OriginalTableName), - sql.EscapeName(this.migrationContext.GhostDatabaseName), + sql.EscapeName(this.migrationContext.GetGhostDatabaseName()), sql.EscapeName(this.migrationContext.GetOldTableName()), ) this.migrationContext.Log.Infof("Renaming original table") @@ -770,7 +773,7 @@ func (this *Applier) SwapTablesQuickAndBumpy() error { return err } query = fmt.Sprintf(`alter /* gh-ost */ table %s.%s rename%s. %s`, - sql.EscapeName(this.migrationContext.GhostDatabaseName), + sql.EscapeName(this.migrationContext.GetGhostDatabaseName()), sql.EscapeName(this.migrationContext.GetGhostTableName()), sql.EscapeName(this.migrationContext.DatabaseName), sql.EscapeName(this.migrationContext.OriginalTableName), @@ -793,9 +796,9 @@ func (this *Applier) RenameTablesRollback() (renameError error) { query := fmt.Sprintf(`rename /* gh-ost */ table %s.%s to %s.%s, %s.%s to %s.%s`, sql.EscapeName(this.migrationContext.DatabaseName), sql.EscapeName(this.migrationContext.OriginalTableName), - sql.EscapeName(this.migrationContext.GhostDatabaseName), + sql.EscapeName(this.migrationContext.GetGhostDatabaseName()), sql.EscapeName(this.migrationContext.GetGhostTableName()), - sql.EscapeName(this.migrationContext.GhostDatabaseName), + sql.EscapeName(this.migrationContext.GetGhostDatabaseName()), sql.EscapeName(this.migrationContext.GetOldTableName()), sql.EscapeName(this.migrationContext.DatabaseName), sql.EscapeName(this.migrationContext.OriginalTableName), @@ -808,7 +811,7 @@ func (this *Applier) RenameTablesRollback() (renameError error) { query = fmt.Sprintf(`rename /* gh-ost */ table %s.%s to %s.%s`, sql.EscapeName(this.migrationContext.DatabaseName), sql.EscapeName(this.migrationContext.OriginalTableName), - sql.EscapeName(this.migrationContext.GhostDatabaseName), + sql.EscapeName(this.migrationContext.GetGhostDatabaseName()), sql.EscapeName(this.migrationContext.GetGhostTableName()), ) this.migrationContext.Log.Infof("Renaming back to ghost table") @@ -816,7 +819,7 @@ func (this *Applier) RenameTablesRollback() (renameError error) { renameError = err } query = fmt.Sprintf(`rename /* gh-ost */ table %s.%s to %s.%s`, - sql.EscapeName(this.migrationContext.GhostDatabaseName), + sql.EscapeName(this.migrationContext.GetGhostDatabaseName()), sql.EscapeName(this.migrationContext.GetOldTableName()), sql.EscapeName(this.migrationContext.DatabaseName), sql.EscapeName(this.migrationContext.OriginalTableName), @@ -954,7 +957,7 @@ func (this *Applier) ExpectProcess(sessionId int64, stateHint, infoHint string) func (this *Applier) DropAtomicCutOverSentryTableIfExists() error { this.migrationContext.Log.Infof("Looking for magic cut-over table") tableName := this.migrationContext.GetOldTableName() - rowMap := this.showTableStatus(this.migrationContext.GhostDatabaseName, tableName) + rowMap := this.showTableStatus(this.migrationContext.GetGhostDatabaseName(), tableName) if rowMap == nil { // Table does not exist return nil @@ -977,13 +980,13 @@ func (this *Applier) CreateAtomicCutOverSentryTable() error { create /* gh-ost */ table %s.%s ( id int auto_increment primary key ) engine=%s comment='%s'`, - sql.EscapeName(this.migrationContext.GhostDatabaseName), + sql.EscapeName(this.migrationContext.GetGhostDatabaseName()), sql.EscapeName(tableName), this.migrationContext.TableEngine, atomicCutOverMagicHint, ) this.migrationContext.Log.Infof("Creating magic cut-over table %s.%s", - sql.EscapeName(this.migrationContext.GhostDatabaseName), + sql.EscapeName(this.migrationContext.GetGhostDatabaseName()), sql.EscapeName(tableName), ) if _, err := sqlutils.ExecNoPrepare(this.db, query); err != nil { @@ -1069,13 +1072,13 @@ func (this *Applier) AtomicCutOverMagicLock(sessionIdChan chan int64, tableLocke query = fmt.Sprintf(`lock /* gh-ost */ tables %s.%s write, %s.%s write`, sql.EscapeName(this.migrationContext.DatabaseName), sql.EscapeName(this.migrationContext.OriginalTableName), - sql.EscapeName(this.migrationContext.GhostDatabaseName), + sql.EscapeName(this.migrationContext.GetGhostDatabaseName()), sql.EscapeName(this.migrationContext.GetOldTableName()), ) this.migrationContext.Log.Infof("Locking %s.%s, %s.%s", sql.EscapeName(this.migrationContext.DatabaseName), sql.EscapeName(this.migrationContext.OriginalTableName), - sql.EscapeName(this.migrationContext.GhostDatabaseName), + sql.EscapeName(this.migrationContext.GetGhostDatabaseName()), sql.EscapeName(this.migrationContext.GetOldTableName()), ) this.migrationContext.LockTablesStartTime = time.Now() @@ -1098,7 +1101,7 @@ func (this *Applier) AtomicCutOverMagicLock(sessionIdChan chan int64, tableLocke // And in fact, we will: this.migrationContext.Log.Infof("Dropping magic cut-over table") query = fmt.Sprintf(`drop /* gh-ost */ table if exists %s.%s`, - sql.EscapeName(this.migrationContext.GhostDatabaseName), + sql.EscapeName(this.migrationContext.GetGhostDatabaseName()), sql.EscapeName(this.migrationContext.GetOldTableName()), ) @@ -1111,7 +1114,7 @@ func (this *Applier) AtomicCutOverMagicLock(sessionIdChan chan int64, tableLocke this.migrationContext.Log.Infof("Releasing lock from %s.%s, %s.%s", sql.EscapeName(this.migrationContext.DatabaseName), sql.EscapeName(this.migrationContext.OriginalTableName), - sql.EscapeName(this.migrationContext.GhostDatabaseName), + sql.EscapeName(this.migrationContext.GetGhostDatabaseName()), sql.EscapeName(this.migrationContext.GetOldTableName()), ) query = `unlock /* gh-ost */ tables` @@ -1150,9 +1153,9 @@ func (this *Applier) AtomicCutoverRename(sessionIdChan chan int64, tablesRenamed query = fmt.Sprintf(`rename /* gh-ost */ table %s.%s to %s.%s, %s.%s to %s.%s`, sql.EscapeName(this.migrationContext.DatabaseName), sql.EscapeName(this.migrationContext.OriginalTableName), - sql.EscapeName(this.migrationContext.GhostDatabaseName), + sql.EscapeName(this.migrationContext.GetGhostDatabaseName()), sql.EscapeName(this.migrationContext.GetOldTableName()), - sql.EscapeName(this.migrationContext.GhostDatabaseName), + sql.EscapeName(this.migrationContext.GetGhostDatabaseName()), sql.EscapeName(this.migrationContext.GetGhostTableName()), sql.EscapeName(this.migrationContext.DatabaseName), sql.EscapeName(this.migrationContext.OriginalTableName),