From d6305b7f1b8e30ac254ba6d57d6c77c4e0cb4694 Mon Sep 17 00:00:00 2001 From: Nabendu Maiti Date: Mon, 2 Mar 2026 06:31:41 +0000 Subject: [PATCH] fix: resolve local LMS enforcement issue for AMT>19 Added LMS certificate check for loopback connections. Return local LMS TLS status to AMT cloud/console. Signed-off-by: Nabendu Maiti --- internal/certs/lmsTls.go | 43 +++++++++++++++++++++++--- internal/certs/lmsTls_test.go | 25 +++++++++++++++ internal/cli/cli.go | 1 + internal/commands/activate/activate.go | 3 ++ internal/commands/activate/remote.go | 1 + internal/commands/deactivate.go | 1 + internal/commands/deactivate_test.go | 25 +++++++++++---- internal/commands/shared.go | 17 +++++----- internal/rps/executor.go | 8 ++--- internal/rps/message.go | 2 ++ internal/rps/message_test.go | 37 ++++++++++++++++++++++ internal/rps/rps.go | 2 +- internal/rps/types.go | 2 +- 13 files changed, 142 insertions(+), 25 deletions(-) diff --git a/internal/certs/lmsTls.go b/internal/certs/lmsTls.go index f2780b6dc..c453f638c 100644 --- a/internal/certs/lmsTls.go +++ b/internal/certs/lmsTls.go @@ -5,6 +5,7 @@ package certs import ( + "bytes" "crypto/sha256" "crypto/tls" "crypto/x509" @@ -18,11 +19,14 @@ import ( // generates a TLS configuration based on the provided mode. func GetTLSConfig(mode *int, amtCertInfo *amt.SecureHBasedResponse, skipAMTCertCheck bool) *tls.Config { - tlsConfig := &tls.Config{} - - tlsConfig.InsecureSkipVerify = skipAMTCertCheck + tlsConfig := &tls.Config{ + InsecureSkipVerify: skipAMTCertCheck, + } if *mode == 0 { // pre-provisioning mode + // Pre-provisioning uses AMT loopback/self-signed TLS; allow handshake and + // enforce certificate validation in VerifyPeerCertificate. + tlsConfig.InsecureSkipVerify = true tlsConfig.VerifyPeerCertificate = func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error { if skipAMTCertCheck { return nil @@ -32,7 +36,11 @@ func GetTLSConfig(mode *int, amtCertInfo *amt.SecureHBasedResponse, skipAMTCertC } } else { // default tls config if device is in ACM or CCM - log.Trace("Setting default TLS Config for ACM/CCM mode") + if skipAMTCertCheck { + log.Trace("Skipping AMT certificate verification for ACM/CCM mode (loopback TLS)") + } else { + log.Trace("Using default TLS config for ACM/CCM mode") + } } return tlsConfig @@ -62,7 +70,7 @@ func VerifyCertificates(rawCerts [][]byte, mode *int, amtCertInfo *amt.SecureHBa return err } - log.Infof("Cert[%d]: Subject=%s, Issuer=%s, EKU=%v", i, cert.Subject, cert.Issuer, cert.ExtKeyUsage) + log.Tracef("Cert[%d]: Subject=%s, Issuer=%s, EKU=%v", i, cert.Subject, cert.Issuer, cert.ExtKeyUsage) parsedCerts = append(parsedCerts, cert) @@ -84,6 +92,31 @@ func VerifyCertificates(rawCerts [][]byte, mode *int, amtCertInfo *amt.SecureHBa return nil case selfSignedChainLength: + // On AMT 19+ loopback TLS, the LMS/AMT certificate is typically a single + // self-signed certificate that is not rooted in the system trust store. + // In pre-provisioning mode (mode == 0), accept this only when the leaf + // certificate matches AMT loopback expectations. + if mode != nil && *mode == 0 { + cert, err := x509.ParseCertificate(rawCerts[0]) + if err != nil { + log.Error("Failed to parse self-signed AMT loopback certificate:", err) + + return err + } + + if !bytes.Equal(cert.RawSubject, cert.RawIssuer) { + return errors.New("single AMT loopback certificate is not self-signed") + } + + if err := VerifyLeafCertificate(cert, amtCertInfo); err != nil { + return err + } + + log.Trace("Accepting self-signed AMT loopback certificate in pre-provisioning mode") + + return nil + } + return HandleAMTTransition(mode) } diff --git a/internal/certs/lmsTls_test.go b/internal/certs/lmsTls_test.go index 0db9eeea4..18a24efb5 100644 --- a/internal/certs/lmsTls_test.go +++ b/internal/certs/lmsTls_test.go @@ -71,6 +71,11 @@ func TestGetTLSConfig(t *testing.T) { assert.True(t, tlsConfig.InsecureSkipVerify) assert.NotNil(t, tlsConfig.VerifyPeerCertificate) + tlsConfig = GetTLSConfig(&mode, nil, false) + assert.NotNil(t, tlsConfig) + assert.True(t, tlsConfig.InsecureSkipVerify) + assert.NotNil(t, tlsConfig.VerifyPeerCertificate) + mode = 1 tlsConfig = GetTLSConfig(&mode, nil, true) assert.NotNil(t, tlsConfig) @@ -299,3 +304,23 @@ func TestVerifyFullChain(t *testing.T) { }) } } + +func TestVerifyCertificates_SingleCertPreProvisioning(t *testing.T) { + t.Run("accepts allowed self-signed leaf CN", func(t *testing.T) { + mode := 0 + leafTemplate := createCertTemplate("AMT RCFG", false, []string{"Leaf OU"}) + leafCert, _ := createTestCert(t, leafTemplate, nil, nil) + + err := VerifyCertificates([][]byte{leafCert.Raw}, &mode, nil) + assert.NoError(t, err) + }) + + t.Run("rejects invalid self-signed leaf CN", func(t *testing.T) { + mode := 0 + leafTemplate := createCertTemplate("Invalid CN", false, []string{"Leaf OU"}) + leafCert, _ := createTestCert(t, leafTemplate, nil, nil) + + err := VerifyCertificates([][]byte{leafCert.Raw}, &mode, nil) + assert.Error(t, err) + }) +} diff --git a/internal/cli/cli.go b/internal/cli/cli.go index 31ac37645..4d785ed61 100644 --- a/internal/cli/cli.go +++ b/internal/cli/cli.go @@ -248,6 +248,7 @@ func ExecuteWithAMT(args []string, amtCommand amt.Interface) error { appCtx := &commands.Context{ AMTCommand: amtCommand, + LocalTLSEnforced: false, LogLevel: cli.LogLevel, JsonOutput: cli.JsonOutput, TableOutput: cli.TableOutput, diff --git a/internal/commands/activate/activate.go b/internal/commands/activate/activate.go index 70d05ab9a..a47eedc29 100644 --- a/internal/commands/activate/activate.go +++ b/internal/commands/activate/activate.go @@ -258,6 +258,9 @@ func (cmd *ActivateCmd) Run(ctx *commands.Context) error { // runRemoteActivation executes remote activation using the remote service func (cmd *ActivateCmd) runRemoteActivation(ctx *commands.Context) error { + // Propagate local TLS enforcement status detected in AMTBaseCmd.AfterApply + ctx.LocalTLSEnforced = cmd.LocalTLSEnforced + // Create remote activation command with current flags remoteCmd := RemoteActivateCmd{ URL: cmd.URL, diff --git a/internal/commands/activate/remote.go b/internal/commands/activate/remote.go index f3114937d..e8de3db95 100644 --- a/internal/commands/activate/remote.go +++ b/internal/commands/activate/remote.go @@ -169,6 +169,7 @@ func (service *RemoteActivationService) requestActivation(deviceInfo map[string] Verbose: service.context.Verbose, SkipCertCheck: service.context.SkipCertCheck, SkipAmtCertCheck: service.context.SkipAMTCertCheck, + LocalTLSEnforced: service.context.LocalTLSEnforced, ControlMode: service.context.ControlMode, TenantID: service.context.TenantID, Password: service.context.AMTPassword, diff --git a/internal/commands/deactivate.go b/internal/commands/deactivate.go index f1e127398..25d58157b 100644 --- a/internal/commands/deactivate.go +++ b/internal/commands/deactivate.go @@ -130,6 +130,7 @@ func (cmd *DeactivateCmd) executeRemoteDeactivate(ctx *Context) error { Verbose: ctx.Verbose, SkipCertCheck: ctx.SkipCertCheck, SkipAmtCertCheck: ctx.SkipAMTCertCheck, + LocalTLSEnforced: cmd.LocalTLSEnforced, Force: cmd.Force, TenantID: ctx.TenantID, } diff --git a/internal/commands/deactivate_test.go b/internal/commands/deactivate_test.go index 4741536d6..58c767f5f 100644 --- a/internal/commands/deactivate_test.go +++ b/internal/commands/deactivate_test.go @@ -574,27 +574,40 @@ func TestDeactivateCmd_ResolveGUID(t *testing.T) { // Test setupTLSConfig function func TestSetupTLSConfig(t *testing.T) { t.Run("TLS config with LocalTLSEnforced false", func(t *testing.T) { - cmd := &DeactivateCmd{} - cmd.LocalTLSEnforced = true - ctx := &Context{ControlMode: ControlModeACM} + cmd := &DeactivateCmd{ + AMTBaseCmd: AMTBaseCmd{ + ControlMode: ControlModeACM, + }, + } + cmd.LocalTLSEnforced = false + ctx := &Context{ + ControlMode: ControlModeACM, + SkipAMTCertCheck: true, // Should be ignored when not enforced + } tlsConfig := cmd.setupTLSConfig(ctx) assert.NotNil(t, tlsConfig) + // When TLS is not enforced locally, we expect default config which has InsecureSkipVerify false assert.False(t, tlsConfig.InsecureSkipVerify) }) t.Run("TLS config with LocalTLSEnforced true", func(t *testing.T) { - cmd := &DeactivateCmd{} + cmd := &DeactivateCmd{ + AMTBaseCmd: AMTBaseCmd{ + ControlMode: ControlModeACM, + }, + } cmd.LocalTLSEnforced = true ctx := &Context{ - SkipCertCheck: true, + SkipCertCheck: true, // Should be ignored by setupTLSConfig ControlMode: ControlModeACM, } tlsConfig := cmd.setupTLSConfig(ctx) assert.NotNil(t, tlsConfig) - // The actual config setup depends on the config.GetTLSConfig implementation + // When LocalTLSEnforced is true, we use SkipAMTCertCheck (which is false here) + assert.False(t, tlsConfig.InsecureSkipVerify) }) } diff --git a/internal/commands/shared.go b/internal/commands/shared.go index 52dfcb183..8b7f2da42 100644 --- a/internal/commands/shared.go +++ b/internal/commands/shared.go @@ -12,14 +12,15 @@ import ( // Context holds shared dependencies injected into commands type Context struct { - AMTCommand amt.Interface - ControlMode int - LogLevel string - JsonOutput bool - TableOutput bool - NoColor bool - Verbose bool - SkipCertCheck bool + AMTCommand amt.Interface + ControlMode int + LocalTLSEnforced bool + LogLevel string + JsonOutput bool + TableOutput bool + NoColor bool + Verbose bool + SkipCertCheck bool // SkipAMTCertCheck controls whether to skip TLS verification when connecting to AMT/LMS over TLS // This is distinct from SkipCertCheck which applies to remote RPS HTTPS/WSS connections. SkipAMTCertCheck bool diff --git a/internal/rps/executor.go b/internal/rps/executor.go index 0072be976..17170da6e 100644 --- a/internal/rps/executor.go +++ b/internal/rps/executor.go @@ -34,7 +34,7 @@ type Executor struct { type ExecutorConfig struct { URL string Proxy string - LocalTlsEnforced bool + LocalTLSEnforced bool SkipAmtCertCheck bool ControlMode int SkipCertCheck bool @@ -46,13 +46,13 @@ func NewExecutor(config ExecutorConfig) (Executor, error) { lmErrorChannel := make(chan error) port := utils.LMSPort - if config.LocalTlsEnforced { + if config.LocalTLSEnforced { port = utils.LMSTLSPort } client := Executor{ server: NewAMTActivationServer(config.URL, config.Proxy), - localManagement: lm.NewLMSConnection(utils.LMSAddress, port, config.LocalTlsEnforced, lmDataChannel, lmErrorChannel, config.ControlMode, config.SkipAmtCertCheck), + localManagement: lm.NewLMSConnection(utils.LMSAddress, port, config.LocalTLSEnforced, lmDataChannel, lmErrorChannel, config.ControlMode, config.SkipAmtCertCheck), data: lmDataChannel, errors: lmErrorChannel, waitGroup: &sync.WaitGroup{}, @@ -61,7 +61,7 @@ func NewExecutor(config ExecutorConfig) (Executor, error) { // TEST CONNECTION TO SEE IF LMS EXISTS err := client.localManagement.Connect() if err != nil { - if config.LocalTlsEnforced { + if config.LocalTLSEnforced { return client, utils.LMSConnectionFailed } // client.localManagement.Close() diff --git a/internal/rps/message.go b/internal/rps/message.go index dcd7dab9e..bb61fed48 100644 --- a/internal/rps/message.go +++ b/internal/rps/message.go @@ -57,6 +57,7 @@ type MessagePayload struct { CertificateHashes []string `json:"certHashes"` IPConfiguration IPConfiguration `json:"ipConfiguration"` HostnameInfo HostnameInfo `json:"hostnameInfo"` + LocalTLSEnforced bool `json:"localTlsEnforced,omitempty"` FriendlyName string `json:"friendlyName,omitempty"` } @@ -197,6 +198,7 @@ func (p Payload) CreateMessageRequest(req Request) (Message, error) { payload.IPConfiguration = req.IpConfiguration payload.HostnameInfo = req.HostnameInfo + payload.LocalTLSEnforced = req.LocalTLSEnforced if req.UUID != "" { if isKnownInvalidUUID(req.UUID) { diff --git a/internal/rps/message_test.go b/internal/rps/message_test.go index 25e0dd0d1..9ab9250c7 100644 --- a/internal/rps/message_test.go +++ b/internal/rps/message_test.go @@ -437,3 +437,40 @@ func TestCreateMessageRequestWithInvalidUUIDPattern(t *testing.T) { assert.Error(t, createErr) assert.Equal(t, utils.InvalidUUID, createErr) } + +func TestCreateMessageRequestLocalTLSEnforced(t *testing.T) { + tests := []struct { + name string + localTLSEnforced bool + expectEnforced bool + }{ + { + name: "true", + localTLSEnforced: true, + expectEnforced: true, + }, + { + name: "false", + localTLSEnforced: false, + expectEnforced: false, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + flags := Request{ + LocalTLSEnforced: tc.localTLSEnforced, + } + result, createErr := p.CreateMessageRequest(flags) + assert.NoError(t, createErr) + assert.NotEmpty(t, result.Payload) + decodedBytes, decodeErr := base64.StdEncoding.DecodeString(result.Payload) + assert.NoError(t, decodeErr) + + msgPayload := MessagePayload{} + jsonErr := json.Unmarshal(decodedBytes, &msgPayload) + assert.NoError(t, jsonErr) + assert.Equal(t, tc.expectEnforced, msgPayload.LocalTLSEnforced) + }) + } +} diff --git a/internal/rps/rps.go b/internal/rps/rps.go index 79445f781..496d7cd77 100644 --- a/internal/rps/rps.go +++ b/internal/rps/rps.go @@ -36,7 +36,7 @@ func ExecuteCommand(req *Request) error { config := ExecutorConfig{ URL: req.URL, Proxy: req.Proxy, - LocalTlsEnforced: req.LocalTlsEnforced, + LocalTLSEnforced: req.LocalTLSEnforced, SkipAmtCertCheck: req.SkipAmtCertCheck, ControlMode: req.ControlMode, SkipCertCheck: req.SkipCertCheck, diff --git a/internal/rps/types.go b/internal/rps/types.go index 77f69d625..053206f8b 100644 --- a/internal/rps/types.go +++ b/internal/rps/types.go @@ -38,7 +38,7 @@ type Request struct { // Connection and server parameters URL string Proxy string - LocalTlsEnforced bool + LocalTLSEnforced bool SkipAmtCertCheck bool ControlMode int SkipCertCheck bool