Skip to content

Commit 35e5ddd

Browse files
authored
Merge pull request #6 from bytebase/circuit-breaker
chore: implement circuit breaker
2 parents 33e16ca + 5c09a59 commit 35e5ddd

4 files changed

Lines changed: 443 additions & 42 deletions

File tree

go/base/context.go

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,10 @@ type MigrationContext struct {
104104
AzureMySQL bool
105105
AttemptInstantDDL bool
106106

107+
// MaxAuthFailures is the maximum number of authentication failures before aborting
108+
// This prevents retry storms that can trigger firewall rules
109+
MaxAuthFailures int
110+
107111
// SkipPortValidation allows skipping the port validation in `ValidateConnection`
108112
// This is useful when connecting to a MySQL instance where the external port
109113
// may not match the internal port.
@@ -360,6 +364,16 @@ func (this *MigrationContext) GetOldTableName() string {
360364
return getSafeTableName(tableName, "del")
361365
}
362366

367+
// GetGhostDatabaseName returns the database name for ghost/changelog tables
368+
// If GhostDatabaseName is set (for separate schema), use it
369+
// Otherwise, use the same database as the original table
370+
func (this *MigrationContext) GetGhostDatabaseName() string {
371+
if this.GhostDatabaseName != "" {
372+
return this.GhostDatabaseName
373+
}
374+
return this.DatabaseName
375+
}
376+
363377
// GetChangelogTableName generates the name of changelog table, based on original table name
364378
// or a given table name.
365379
func (this *MigrationContext) GetChangelogTableName() string {

go/binlog/gomysql_reader.go

Lines changed: 79 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,13 @@ package binlog
77

88
import (
99
"fmt"
10+
"strings"
1011
"sync"
1112

1213
"github.com/github/gh-ost/go/base"
1314
"github.com/github/gh-ost/go/mysql"
1415
"github.com/github/gh-ost/go/sql"
16+
"github.com/pkg/errors"
1517

1618
"time"
1719

@@ -28,6 +30,7 @@ type GoMySQLReader struct {
2830
currentCoordinates mysql.BinlogCoordinates
2931
currentCoordinatesMutex *sync.Mutex
3032
LastAppliedRowsEventHint mysql.BinlogCoordinates
33+
authFailureCount int
3134
}
3235

3336
func NewGoMySQLReader(migrationContext *base.MigrationContext) *GoMySQLReader {
@@ -52,6 +55,36 @@ func NewGoMySQLReader(migrationContext *base.MigrationContext) *GoMySQLReader {
5255
}
5356
}
5457

58+
// handleAuthError processes authentication errors and applies circuit breaker logic
59+
func (this *GoMySQLReader) handleAuthError(err error, context string) error {
60+
if err == nil {
61+
// Success case - reset counter if needed
62+
if this.authFailureCount > 0 {
63+
this.migrationContext.Log.Infof("%s successful, resetting auth failure count from %d to 0", context, this.authFailureCount)
64+
this.authFailureCount = 0
65+
}
66+
return nil
67+
}
68+
69+
// Check if this is an authentication error
70+
if !this.isAuthenticationError(err) {
71+
return err // Not an auth error, return as-is
72+
}
73+
74+
// Authentication error - increment counter and check circuit breaker
75+
this.authFailureCount++
76+
77+
if this.migrationContext.MaxAuthFailures > 0 && this.authFailureCount >= this.migrationContext.MaxAuthFailures {
78+
return fmt.Errorf("authentication failed %d times (max: %d) during %s, aborting to prevent firewall blocking: %w",
79+
this.authFailureCount, this.migrationContext.MaxAuthFailures, context, err)
80+
}
81+
82+
this.migrationContext.Log.Errorf("Authentication failure #%d during %s (max: %d): %v",
83+
this.authFailureCount, context, this.migrationContext.MaxAuthFailures, err)
84+
85+
return err
86+
}
87+
5588
// ConnectBinlogStreamer
5689
func (this *GoMySQLReader) ConnectBinlogStreamer(coordinates mysql.BinlogCoordinates) (err error) {
5790
if coordinates.IsEmpty() {
@@ -66,7 +99,8 @@ func (this *GoMySQLReader) ConnectBinlogStreamer(coordinates mysql.BinlogCoordin
6699
Pos: uint32(this.currentCoordinates.LogPos),
67100
})
68101

69-
return err
102+
// Handle the error (or success) with circuit breaker logic
103+
return this.handleAuthError(err, "connection")
70104
}
71105

72106
func (this *GoMySQLReader) GetCurrentBinlogCoordinates() *mysql.BinlogCoordinates {
@@ -79,7 +113,7 @@ func (this *GoMySQLReader) GetCurrentBinlogCoordinates() *mysql.BinlogCoordinate
79113
// StreamEvents
80114
func (this *GoMySQLReader) handleRowsEvent(ev *replication.BinlogEvent, rowsEvent *replication.RowsEvent, entriesChannel chan<- *BinlogEntry) error {
81115
if this.currentCoordinates.IsLogPosOverflowBeyond4Bytes(&this.LastAppliedRowsEventHint) {
82-
return fmt.Errorf("Unexpected rows event at %+v, the binlog end_log_pos is overflow 4 bytes", this.currentCoordinates)
116+
return fmt.Errorf("unexpected rows event at %+v, the binlog end_log_pos is overflow 4 bytes", this.currentCoordinates)
83117
}
84118

85119
if this.currentCoordinates.SmallerThanOrEquals(&this.LastAppliedRowsEventHint) {
@@ -89,7 +123,7 @@ func (this *GoMySQLReader) handleRowsEvent(ev *replication.BinlogEvent, rowsEven
89123

90124
dml := ToEventDML(ev.Header.EventType.String())
91125
if dml == NotDML {
92-
return fmt.Errorf("Unknown DML type: %s", ev.Header.EventType.String())
126+
return fmt.Errorf("unknown DML type: %s", ev.Header.EventType.String())
93127
}
94128
for i, row := range rowsEvent.Rows {
95129
if dml == UpdateDML && i%2 == 1 {
@@ -133,14 +167,16 @@ func (this *GoMySQLReader) StreamEvents(canStopStreaming func() bool, entriesCha
133167
if canStopStreaming() {
134168
return nil
135169
}
136-
for {
137-
if canStopStreaming() {
138-
break
139-
}
170+
for !canStopStreaming() {
140171
ev, err := this.binlogStreamer.GetEvent(context.Background())
141172
if err != nil {
142-
return err
173+
// Handle authentication errors with circuit breaker
174+
return this.handleAuthError(err, "streaming")
143175
}
176+
177+
// Reset counter on successful event (using handleAuthError with nil)
178+
this.handleAuthError(nil, "event retrieval")
179+
144180
func() {
145181
this.currentCoordinatesMutex.Lock()
146182
defer this.currentCoordinatesMutex.Unlock()
@@ -171,3 +207,38 @@ func (this *GoMySQLReader) Close() error {
171207
this.binlogSyncer.Close()
172208
return nil
173209
}
210+
211+
// MySQL error codes for authentication failures
212+
const (
213+
ER_DBACCESS_DENIED_ERROR = 1044 // Access denied for user to database
214+
ER_ACCESS_DENIED_ERROR = 1045 // Access denied for user (using password: YES/NO)
215+
ER_HOST_NOT_ALLOWED = 1130 // Host is not allowed to connect
216+
ER_ACCESS_DENIED_NO_PASSWORD = 1698 // Access denied (no password provided)
217+
ER_ACCOUNT_HAS_BEEN_LOCKED = 3118 // Account has been locked
218+
)
219+
220+
// isAuthenticationError checks if the error is an authentication failure
221+
func (this *GoMySQLReader) isAuthenticationError(err error) bool {
222+
if err == nil {
223+
return false
224+
}
225+
226+
// Check for MySQL protocol errors using proper type assertion
227+
var myErr *gomysql.MyError
228+
if errors.As(err, &myErr) {
229+
switch myErr.Code {
230+
case ER_ACCESS_DENIED_ERROR,
231+
ER_DBACCESS_DENIED_ERROR,
232+
ER_HOST_NOT_ALLOWED,
233+
ER_ACCESS_DENIED_NO_PASSWORD,
234+
ER_ACCOUNT_HAS_BEEN_LOCKED:
235+
return true
236+
}
237+
}
238+
239+
// Fallback: Check error string for compatibility with errors
240+
// that might not be properly typed (e.g., from proxy or older versions)
241+
errStr := strings.ToLower(err.Error())
242+
return strings.Contains(errStr, "access denied") ||
243+
strings.Contains(errStr, "authentication failed")
244+
}

0 commit comments

Comments
 (0)