From 01b85119ab7827cc391217a4e148147bcae54bec Mon Sep 17 00:00:00 2001 From: Nabendu Maiti Date: Thu, 19 Mar 2026 09:35:27 +0000 Subject: [PATCH 1/2] fix: optimize LME activation sequence and local transport time optimize lme accss Signed-off-by: Nabendu Maiti --- internal/commands/activate/activate.go | 29 +++++++-- internal/commands/activate/activate_test.go | 12 ++++ internal/commands/activate/local.go | 37 ++++++++++- internal/lm/engine.go | 23 ++++++- internal/local/amt/localTransport.go | 65 +++++++++++++++++-- internal/local/amt/wsman.go | 38 +++++++---- internal/rps/executor.go | 72 +++++++++++++++++++-- internal/rps/rps.go | 1 + pkg/utils/constants.go | 12 ++-- 9 files changed, 248 insertions(+), 41 deletions(-) diff --git a/internal/commands/activate/activate.go b/internal/commands/activate/activate.go index cd0d825e..d5d1d457 100644 --- a/internal/commands/activate/activate.go +++ b/internal/commands/activate/activate.go @@ -73,6 +73,8 @@ func (cmd *ActivateCmd) Validate() error { log.Trace("Entering Validate method of ActivateCmd") // Determine if caller intends local activation (explicit --local or local-only flags) + cmd.URL = normalizeActivateURL(cmd.URL) + localIntent := cmd.Local || cmd.hasLocalActivationFlags() // Resolve local-vs-remote precedence when both are present. @@ -81,7 +83,6 @@ func (cmd *ActivateCmd) Validate() error { if localIntent && cmd.URL != "" { lowerURL := strings.ToLower(cmd.URL) if strings.HasPrefix(lowerURL, "http://") || strings.HasPrefix(lowerURL, "https://") { - log.Warn("Both --url and local activation flags detected; proceeding with local activation via http://") // Clear URL so we don't trigger HTTP profile fullflow during local runs (prevents recursion) cmd.URL = "" } @@ -201,6 +202,25 @@ func (cmd *ActivateCmd) hasLocalActivationFlags() bool { cmd.ProvisioningCert != "" || cmd.ProvisioningCertPwd != "" || cmd.SkipIPRenew } +func normalizeActivateURL(raw string) string { + value := strings.TrimSpace(raw) + if value == "" { + return "" + } + + u, err := url.Parse(value) + if err != nil { + return value + } + + scheme := strings.ToLower(u.Scheme) + if (scheme == "http" || scheme == "https" || scheme == "ws" || scheme == "wss") && u.Host == "" { + return "" + } + + return value +} + // Run executes the activate command based on detected mode func (cmd *ActivateCmd) Run(ctx *commands.Context) error { log.Tracef("Entering Run method of ActivateCmd. Context: %s", ctx.AuthEndpoint) @@ -210,10 +230,9 @@ func (cmd *ActivateCmd) Run(ctx *commands.Context) error { if err := cmd.EnsureAMTPassword(ctx, cmd); err != nil { return err } - - if err := cmd.EnsureWSMAN(ctx); err != nil { - return err - } + // Do not pre-create WSMAN for local activation here. + // LocalActivateCmd sets up its own local WSMAN transport, and doing both + // can trigger an extra LME/APF initialize cycle. } // Determine activation mode based on flags if cmd.URL != "" { diff --git a/internal/commands/activate/activate_test.go b/internal/commands/activate/activate_test.go index 048dbeb3..52ed0ef5 100644 --- a/internal/commands/activate/activate_test.go +++ b/internal/commands/activate/activate_test.go @@ -416,6 +416,18 @@ func TestActivateCmd_Validate_PrecendenceMatrix(t *testing.T) { wantErr: false, wantClearedURL: true, }, + { + name: "Local CCM with placeholder HTTP URL clears URL", + cmd: ActivateCmd{Local: true, CCM: true, URL: "http://"}, + wantErr: false, + wantClearedURL: true, + }, + { + name: "Local CCM with spaced placeholder HTTPS URL clears URL", + cmd: ActivateCmd{Local: true, CCM: true, URL: " https:// "}, + wantErr: false, + wantClearedURL: true, + }, { name: "HTTP URL remote only (no local flags) retains URL", cmd: ActivateCmd{URL: "https://server/p3"}, diff --git a/internal/commands/activate/local.go b/internal/commands/activate/local.go index 3295a099..7425af1e 100644 --- a/internal/commands/activate/local.go +++ b/internal/commands/activate/local.go @@ -23,7 +23,9 @@ import ( "errors" "fmt" "strings" + "time" + "github.com/device-management-toolkit/go-wsman-messages/v2/pkg/wsman/amt/general" "github.com/device-management-toolkit/go-wsman-messages/v2/pkg/wsman/client" "github.com/device-management-toolkit/rpc-go/v2/internal/certs" "github.com/device-management-toolkit/rpc-go/v2/internal/commands" @@ -479,7 +481,7 @@ func (service *LocalActivationService) activateCCM() error { // Get general settings for digest realm - generalSettings, err := service.wsman.GetGeneralSettings() + generalSettings, err := service.getGeneralSettingsWithRetry() if err != nil { return utils.ActivationFailedGeneralSettings } @@ -749,7 +751,7 @@ func (service *LocalActivationService) activateACMWithTLS(tlsConfig *tls.Config) // Get general settings to obtain digest realm for password hashing - generalSettings, err := service.wsman.GetGeneralSettings() + generalSettings, err := service.getGeneralSettingsWithRetry() if err != nil { return fmt.Errorf("failed to get AMT general settings: %w", err) } @@ -969,7 +971,7 @@ func (service *LocalActivationService) activateACMLegacy(tlsConfig *tls.Config) // Get general settings for digest realm - generalSettings, err := service.wsman.GetGeneralSettings() + generalSettings, err := service.getGeneralSettingsWithRetry() if err != nil { return utils.ActivationFailedGeneralSettings } @@ -1044,6 +1046,35 @@ func (service *LocalActivationService) handleSetupErrorWithControlModeVerificati return utils.ActivationFailedControlMode } +func (service *LocalActivationService) getGeneralSettingsWithRetry() (general.Response, error) { + const maxRetries = 3 + + var lastErr error + + for attempt := 0; attempt <= maxRetries; attempt++ { + response, err := service.wsman.GetGeneralSettings() + if err == nil { + return response, nil + } + + lastErr = err + errText := strings.ToLower(err.Error()) + transientBusy := strings.Contains(errText, "device or resource busy") || + strings.Contains(errText, "resource busy") || + strings.Contains(errText, "no such device") || + strings.Contains(errText, "device unavailable") + if !transientBusy || attempt == maxRetries { + break + } + + delay := time.Duration(attempt+1) * time.Duration(utils.HeciConnectRetryBackoff) * time.Millisecond + log.WithError(err).Warnf("GetGeneralSettings busy, retrying (%d/%d)", attempt+1, maxRetries) + time.Sleep(delay) + } + + return general.Response{}, lastErr +} + // Certificate handling methods for ACM activation // convertPfxToObject converts a base64 PFX certificate to a CertsAndKeys object diff --git a/internal/lm/engine.go b/internal/lm/engine.go index 216a39a7..4d048c12 100644 --- a/internal/lm/engine.go +++ b/internal/lm/engine.go @@ -19,6 +19,8 @@ import ( log "github.com/sirupsen/logrus" ) +const lmeAPFChannelDataFlushOverride = 500 * time.Millisecond + // LMConnection is struct for managing connection to LMS type LMEConnection struct { Command pthi.Command @@ -179,7 +181,7 @@ func (lme *LMEConnection) execute(bin_buf bytes.Buffer) error { return err } - bin_buf = apf.Process(result, lme.Session) + bin_buf = lme.processWithLocalTimerOverride(result) if bin_buf.Len() == 0 { log.Debug("done EXECUTING.........") @@ -244,7 +246,7 @@ func (lme *LMEConnection) Listen() { break } else { - result := apf.Process(result2, lme.Session) + result := lme.processWithLocalTimerOverride(result2) if result.Len() != 0 { err2 = lme.execute(result) if err2 != nil { @@ -257,6 +259,23 @@ func (lme *LMEConnection) Listen() { } } +func (lme *LMEConnection) processWithLocalTimerOverride(message []byte) bytes.Buffer { + processed := apf.Process(message, lme.Session) + + if len(message) > 0 && message[0] == apf.APF_CHANNEL_DATA && lme.Session.Timer != nil { + if !lme.Session.Timer.Stop() { + select { + case <-lme.Session.Timer.C: + default: + } + } + + lme.Session.Timer.Reset(lmeAPFChannelDataFlushOverride) + } + + return processed +} + // Close closes the LME connection func (lme *LMEConnection) Close() error { diff --git a/internal/local/amt/localTransport.go b/internal/local/amt/localTransport.go index ae007974..0663a553 100644 --- a/internal/local/amt/localTransport.go +++ b/internal/local/amt/localTransport.go @@ -14,9 +14,11 @@ import ( "net/http" "strings" "sync" + "time" "github.com/device-management-toolkit/rpc-go/v2/internal/lm" "github.com/device-management-toolkit/rpc-go/v2/pkg/heci" + "github.com/device-management-toolkit/rpc-go/v2/pkg/utils" "github.com/sirupsen/logrus" ) @@ -28,6 +30,8 @@ type LocalTransport struct { waitGroup *sync.WaitGroup } +const maxChannelOpenBusyRetries = 2 + func NewLocalTransport() *LocalTransport { lmDataChannel := make(chan []byte) lmErrorChannel := make(chan error) @@ -66,18 +70,47 @@ func (l *LocalTransport) Close() error { // Custom dialer function func (l *LocalTransport) RoundTrip(r *http.Request) (*http.Response, error) { - // send channel open - err := l.local.Connect() + var err error + for attempt := 0; attempt <= maxChannelOpenBusyRetries; attempt++ { + err = l.local.Connect() + if err == nil { + break + } + + if !isMEIDeviceBusyError(err) || attempt == maxChannelOpenBusyRetries { + logrus.Error(err) + + return nil, err + } + + wait := time.Duration(attempt+1) * time.Duration(utils.HeciConnectRetryBackoff) * time.Millisecond + logrus.Warnf("mei busy during channel open, retry %d/%d", attempt+1, maxChannelOpenBusyRetries) + time.Sleep(wait) + } + go l.local.Listen() - if err != nil { - logrus.Error(err) + channelOpenTimeout := time.Duration(utils.LMETimerTimeout) * time.Second + if channelOpenTimeout <= 0 || channelOpenTimeout > utils.AMTResponseTimeout*time.Second { + channelOpenTimeout = utils.AMTResponseTimeout * time.Second + } - return nil, err + channelOpenTimer := time.After(channelOpenTimeout) + + channelOpenDone := make(chan struct{}) + + go func() { + defer close(channelOpenDone) + + l.waitGroup.Wait() + }() + + select { + case <-channelOpenDone: + case <-channelOpenTimer: + return nil, fmt.Errorf("timeout waiting for LME channel open confirmation after %s", channelOpenTimeout) } - // wait for channel open confirmation - l.waitGroup.Wait() logrus.Trace("Channel open confirmation received") // Serialize the HTTP request to raw form rawRequest, err := serializeHTTPRequest(r) @@ -99,6 +132,10 @@ func (l *LocalTransport) RoundTrip(r *http.Request) (*http.Response, error) { return nil, err } + responseTimeout := utils.AMTResponseTimeout * time.Second + + responseTimer := time.After(responseTimeout) + Loop: for { select { @@ -117,6 +154,10 @@ Loop: respErr = errFromLMS } + break Loop + case <-responseTimer: + respErr = fmt.Errorf("timeout waiting for LME response after %s", responseTimeout) + break Loop } } @@ -171,3 +212,13 @@ func serializeHTTPRequest(r *http.Request) ([]byte, error) { return reqBuffer.Bytes(), nil } + +func isMEIDeviceBusyError(err error) bool { + if err == nil { + return false + } + + errMsg := strings.ToLower(err.Error()) + + return strings.Contains(errMsg, "device or resource busy") || strings.Contains(errMsg, "resource busy") +} diff --git a/internal/local/amt/wsman.go b/internal/local/amt/wsman.go index 20488a35..9017f006 100644 --- a/internal/local/amt/wsman.go +++ b/internal/local/amt/wsman.go @@ -50,6 +50,8 @@ type GoWSMANMessages struct { wsmanMessages wsman.Messages target string localTransport *LocalTransport + plainProbeDone bool + useLocalLMX bool } func NewGoWSMANMessages(lmsAddress string) *GoWSMANMessages { @@ -78,10 +80,11 @@ func (g *GoWSMANMessages) SetupWsmanClient(username, password string, useTLS, lo defer cancel() dialer := &cryptotls.Dialer{ - Config: tlsConfig, + Config: tlsConfig, + NetDialer: &net.Dialer{Timeout: probeTimeout}, } - conn, err := dialer.DialContext(ctx, "tcp", utils.LMSAddress+":"+utils.LMSTLSPort) + conn, err := dialer.DialContext(ctx, "tcp", net.JoinHostPort(g.target, utils.LMSTLSPort)) if err != nil { logrus.Info("Failed to connect to LMS. We're probably going to fail now. Sorry!") logrus.Error(err) @@ -97,20 +100,31 @@ func (g *GoWSMANMessages) SetupWsmanClient(username, password string, useTLS, lo defer conn.Close() } } else { - ctx, cancel := context.WithTimeout(context.Background(), probeTimeout) - defer cancel() + if !g.plainProbeDone { + ctx, cancel := context.WithTimeout(context.Background(), probeTimeout) + defer cancel() - dialer := &net.Dialer{} + dialer := &net.Dialer{Timeout: probeTimeout} - con, err := dialer.DialContext(ctx, "tcp4", utils.LMSAddress+":"+utils.LMSPort) - if err != nil { - logrus.Info("LMS not active, using local transport instead.") + con, err := dialer.DialContext(ctx, "tcp4", net.JoinHostPort(g.target, utils.LMSPort)) + if err != nil { + logrus.Info("LMS not active, using local transport instead.") + + g.useLocalLMX = true + } else { + logrus.Info("Successfully connected to LMS.") + con.Close() + } + + g.plainProbeDone = true + } + + if g.useLocalLMX { + if g.localTransport == nil { + g.localTransport = NewLocalTransport() + } - g.localTransport = NewLocalTransport() clientParams.Transport = g.localTransport - } else { - logrus.Info("Successfully connected to LMS.") - con.Close() } } diff --git a/internal/rps/executor.go b/internal/rps/executor.go index 0072be97..65800c15 100644 --- a/internal/rps/executor.go +++ b/internal/rps/executor.go @@ -5,11 +5,11 @@ package rps import ( - "context" "errors" "fmt" "os" "os/signal" + "strings" "sync" "syscall" "time" @@ -170,7 +170,7 @@ func (e *Executor) HandleDataFromRPS(dataFromServer []byte) bool { // LME: Open fresh channel for each request (AMT closes after each response) log.Debug("LME: Opening new APF channel for this request") - err := e.localManagement.Connect() + err := e.connectLMEWithRetry() if err != nil { e.lastError = fmt.Errorf("failed to open LME channel: %w", err) log.Error(err) @@ -189,7 +189,31 @@ func (e *Executor) HandleDataFromRPS(dataFromServer []byte) bool { // Wait for APF channel-open confirmation before sending request data. // This avoids sending APF channel data on a channel AMT has not confirmed yet. - e.waitGroup.Wait() + channelOpenTimeout := time.Duration(utils.LMETimerTimeout) * time.Second + if channelOpenTimeout <= 0 || channelOpenTimeout > utils.AMTResponseTimeout*time.Second { + channelOpenTimeout = utils.AMTResponseTimeout * time.Second + } + + channelOpenTimer := time.After(channelOpenTimeout) + + channelOpenDone := make(chan struct{}) + + go func() { + defer close(channelOpenDone) + + e.waitGroup.Wait() + }() + + select { + case <-channelOpenDone: + case <-channelOpenTimer: + log.Error("Timeout waiting for LME channel open confirmation - AMT not responding") + + e.lastError = fmt.Errorf("timeout waiting for AMT channel open confirmation after %s", channelOpenTimeout) + + return true + } + log.Trace("Channel open confirmation received") } else { // LMS: open/close connection for every request @@ -214,8 +238,8 @@ func (e *Executor) HandleDataFromRPS(dataFromServer []byte) bool { return true } - timeoutCtx, cancel := context.WithTimeout(context.Background(), utils.AMTResponseTimeout*time.Second) - defer cancel() + responseTimeout := utils.AMTResponseTimeout * time.Second + responseTimer := time.After(responseTimeout) for { select { @@ -242,7 +266,7 @@ func (e *Executor) HandleDataFromRPS(dataFromServer []byte) bool { return true } - case <-timeoutCtx.Done(): + case <-responseTimer: // Timeout waiting for response from AMT/LME // This indicates AMT is not responding - treat as an error log.Error("Timeout waiting for LME response - AMT not responding") @@ -254,6 +278,42 @@ func (e *Executor) HandleDataFromRPS(dataFromServer []byte) bool { } } +func (e *Executor) connectLMEWithRetry() error { + const maxRetries = 2 + + var err error + + for attempt := 0; attempt <= maxRetries; attempt++ { + err = e.localManagement.Connect() + if err == nil { + return nil + } + + if !isTransientMEIBusy(err) || attempt == maxRetries { + return err + } + + wait := time.Duration(attempt+1) * time.Duration(utils.HeciConnectRetryBackoff) * time.Millisecond + log.Warnf("LME channel open busy, retry %d/%d", attempt+1, maxRetries) + time.Sleep(wait) + } + + return err +} + +func isTransientMEIBusy(err error) bool { + if err == nil { + return false + } + + errMsg := strings.ToLower(err.Error()) + + return strings.Contains(errMsg, "device or resource busy") || + strings.Contains(errMsg, "resource busy") || + strings.Contains(errMsg, "no such device") || + strings.Contains(errMsg, "device unavailable") +} + func (e Executor) HandleDataFromLM(data []byte) { if len(data) > 0 { log.Debug("received data from LMX") diff --git a/internal/rps/rps.go b/internal/rps/rps.go index a03e5229..79445f78 100644 --- a/internal/rps/rps.go +++ b/internal/rps/rps.go @@ -229,6 +229,7 @@ func (amt *AMTActivationServer) ProcessMessage(message []byte) ([]byte, bool, er case "error": err := json.Unmarshal([]byte(activation.Message), &statusMessage) errMessage := activation.Message + if err == nil { log.Error(statusMessage.Status) errMessage = statusMessage.Status diff --git a/pkg/utils/constants.go b/pkg/utils/constants.go index 2588385a..f4fe7183 100644 --- a/pkg/utils/constants.go +++ b/pkg/utils/constants.go @@ -28,15 +28,15 @@ const ( MPSServerMaxLength = 256 // LMSConnectionTimeout is the maximum wait for LMS TCP connection setup. - LMSConnectionTimeout = 10 // seconds - LMSDialerTimeout = 5 // seconds - HeciReadTimeout = 30 // seconds - HeciRetryDelay = 3000 // milliseconds + LMSConnectionTimeout = 6 // seconds + LMSDialerTimeout = 4 // seconds + HeciReadTimeout = 15 // seconds + HeciRetryDelay = 2000 // milliseconds HeciReinitDelay = 500 // milliseconds - HeciConnectRetryBackoff = 500 // milliseconds + HeciConnectRetryBackoff = 300 // milliseconds LMETimerTimeout = 10 // seconds WebSocketTimeout = 60 // seconds - AMTResponseTimeout = 30 // seconds + AMTResponseTimeout = 10 // seconds HelpHeader = "\nRemote Provisioning Client (RPC) - used for activation, deactivation, maintenance and status of AMT\n\n" From 2efae28b9e370d8b26b574c7454026a90b88e9c0 Mon Sep 17 00:00:00 2001 From: Nabendu Maiti Date: Fri, 20 Mar 2026 12:07:47 +0000 Subject: [PATCH 2/2] fix: lintissues and copilot issues fixed lint and copilot issues Signed-off-by: Nabendu Maiti --- internal/commands/activate/local.go | 2 ++ internal/lm/engine.go | 1 + internal/local/amt/localTransport.go | 32 ++++++++++++++++++++++++---- internal/rps/executor.go | 14 +++++++++--- 4 files changed, 42 insertions(+), 7 deletions(-) diff --git a/internal/commands/activate/local.go b/internal/commands/activate/local.go index 7425af1e..193c9210 100644 --- a/internal/commands/activate/local.go +++ b/internal/commands/activate/local.go @@ -1046,6 +1046,7 @@ func (service *LocalActivationService) handleSetupErrorWithControlModeVerificati return utils.ActivationFailedControlMode } +// TODO: Move retry logic in wsman pkg func (service *LocalActivationService) getGeneralSettingsWithRetry() (general.Response, error) { const maxRetries = 3 @@ -1059,6 +1060,7 @@ func (service *LocalActivationService) getGeneralSettingsWithRetry() (general.Re lastErr = err errText := strings.ToLower(err.Error()) + transientBusy := strings.Contains(errText, "device or resource busy") || strings.Contains(errText, "resource busy") || strings.Contains(errText, "no such device") || diff --git a/internal/lm/engine.go b/internal/lm/engine.go index 4d048c12..2df23b38 100644 --- a/internal/lm/engine.go +++ b/internal/lm/engine.go @@ -259,6 +259,7 @@ func (lme *LMEConnection) Listen() { } } +// TODO: Optimize/test changes if wsman pkg can handle it func (lme *LMEConnection) processWithLocalTimerOverride(message []byte) bytes.Buffer { processed := apf.Process(message, lme.Session) diff --git a/internal/local/amt/localTransport.go b/internal/local/amt/localTransport.go index 0663a553..904e1977 100644 --- a/internal/local/amt/localTransport.go +++ b/internal/local/amt/localTransport.go @@ -95,7 +95,16 @@ func (l *LocalTransport) RoundTrip(r *http.Request) (*http.Response, error) { channelOpenTimeout = utils.AMTResponseTimeout * time.Second } - channelOpenTimer := time.After(channelOpenTimeout) + channelOpenTimer := time.NewTimer(channelOpenTimeout) + + defer func() { + if !channelOpenTimer.Stop() { + select { + case <-channelOpenTimer.C: + default: + } + } + }() channelOpenDone := make(chan struct{}) @@ -107,7 +116,13 @@ func (l *LocalTransport) RoundTrip(r *http.Request) (*http.Response, error) { select { case <-channelOpenDone: - case <-channelOpenTimer: + case <-channelOpenTimer.C: + // Close the LME connection so the goroutine waiting on WaitGroup can + // unblock and the next request can start with a clean state. + if closeErr := l.Close(); closeErr != nil { + logrus.Errorf("failed to close LME connection after channel open timeout: %v", closeErr) + } + return nil, fmt.Errorf("timeout waiting for LME channel open confirmation after %s", channelOpenTimeout) } @@ -134,7 +149,16 @@ func (l *LocalTransport) RoundTrip(r *http.Request) (*http.Response, error) { responseTimeout := utils.AMTResponseTimeout * time.Second - responseTimer := time.After(responseTimeout) + responseTimer := time.NewTimer(responseTimeout) + + defer func() { + if !responseTimer.Stop() { + select { + case <-responseTimer.C: + default: + } + } + }() Loop: for { @@ -155,7 +179,7 @@ Loop: } break Loop - case <-responseTimer: + case <-responseTimer.C: respErr = fmt.Errorf("timeout waiting for LME response after %s", responseTimeout) break Loop diff --git a/internal/rps/executor.go b/internal/rps/executor.go index 65800c15..2259fc1f 100644 --- a/internal/rps/executor.go +++ b/internal/rps/executor.go @@ -238,8 +238,16 @@ func (e *Executor) HandleDataFromRPS(dataFromServer []byte) bool { return true } - responseTimeout := utils.AMTResponseTimeout * time.Second - responseTimer := time.After(responseTimeout) + responseTimer := time.NewTimer(utils.AMTResponseTimeout * time.Second) + + defer func() { + if !responseTimer.Stop() { + select { + case <-responseTimer.C: + default: + } + } + }() for { select { @@ -266,7 +274,7 @@ func (e *Executor) HandleDataFromRPS(dataFromServer []byte) bool { return true } - case <-responseTimer: + case <-responseTimer.C: // Timeout waiting for response from AMT/LME // This indicates AMT is not responding - treat as an error log.Error("Timeout waiting for LME response - AMT not responding")