44 "encoding/json"
55 "fmt"
66 "log"
7+ "net/url"
8+ "strings"
79
810 "github.com/aws/aws-sdk-go/aws"
911 "github.com/aws/aws-sdk-go/aws/awserr"
@@ -12,48 +14,39 @@ import (
1214 "github.com/aws/aws-sdk-go/aws/endpoints"
1315 "github.com/aws/aws-sdk-go/aws/session"
1416 "github.com/aws/aws-sdk-go/service/secretsmanager"
17+ "github.com/aws/aws-sdk-go/service/sts"
18+
1519 "github.com/mongodb/terraform-provider-mongodbatlas/internal/config"
1620)
1721
1822const (
19- endPointSTSDefault = "https://sts.amazonaws.com"
23+ endPointSTSHostnameDefault = "sts.amazonaws.com"
24+ DefaultRegionSTS = "us-east-1"
25+ minSegmentsForSTSRegionalHost = 4
2026)
2127
2228func configureCredentialsSTS (cfg * config.Config , secret , region , awsAccessKeyID , awsSecretAccessKey , awsSessionToken , endpoint string ) (config.Config , error ) {
23- ep , err := endpoints .GetSTSRegionalEndpoint ("regional" )
24- if err != nil {
25- log .Printf ("GetSTSRegionalEndpoint error: %s" , err )
26- return * cfg , err
27- }
28-
2929 defaultResolver := endpoints .DefaultResolver ()
30- stsCustResolverFn := func (service , region string , optFns ... func (* endpoints.Options )) (endpoints.ResolvedEndpoint , error ) {
31- if service == endpoints .StsServiceID {
32- if endpoint == "" {
33- return endpoints.ResolvedEndpoint {
34- URL : endPointSTSDefault ,
35- SigningRegion : region ,
36- }, nil
30+ stsCustResolverFn := func (service , _ string , optFns ... func (* endpoints.Options )) (endpoints.ResolvedEndpoint , error ) {
31+ if service == sts .EndpointsID {
32+ resolved , err := ResolveSTSEndpoint (endpoint , region )
33+ if err != nil {
34+ return endpoints.ResolvedEndpoint {}, err
3735 }
38- return endpoints.ResolvedEndpoint {
39- URL : endpoint ,
40- SigningRegion : region ,
41- }, nil
36+ return resolved , nil
4237 }
43-
4438 return defaultResolver .EndpointFor (service , region , optFns ... )
4539 }
4640
4741 sess := session .Must (session .NewSession (& aws.Config {
48- Region : aws .String (region ),
49- Credentials : credentials .NewStaticCredentials (awsAccessKeyID , awsSecretAccessKey , awsSessionToken ),
50- STSRegionalEndpoint : ep ,
51- EndpointResolver : endpoints .ResolverFunc (stsCustResolverFn ),
42+ Region : aws .String (region ),
43+ Credentials : credentials .NewStaticCredentials (awsAccessKeyID , awsSecretAccessKey , awsSessionToken ),
44+ EndpointResolver : endpoints .ResolverFunc (stsCustResolverFn ),
5245 }))
5346
5447 creds := stscreds .NewCredentials (sess , cfg .AssumeRole .RoleARN )
5548
56- _ , err = sess .Config .Credentials .Get ()
49+ _ , err : = sess .Config .Credentials .Get ()
5750 if err != nil {
5851 log .Printf ("Session get credentials error: %s" , err )
5952 return * cfg , err
@@ -87,6 +80,45 @@ func configureCredentialsSTS(cfg *config.Config, secret, region, awsAccessKeyID,
8780 return * cfg , nil
8881}
8982
83+ func DeriveSTSRegionFromEndpoint (ep string ) string {
84+ if ep == "" {
85+ return ""
86+ }
87+ u , err := url .Parse (ep )
88+ if err != nil {
89+ return DefaultRegionSTS
90+ }
91+ host := u .Hostname () // valid values: sts.us-west-2.amazonaws.com or sts.amazonaws.com
92+
93+ if host == endPointSTSHostnameDefault {
94+ return DefaultRegionSTS
95+ }
96+
97+ parts := strings .Split (host , "." )
98+ if len (parts ) >= minSegmentsForSTSRegionalHost && parts [0 ] == "sts" {
99+ return parts [1 ]
100+ }
101+ return DefaultRegionSTS
102+ }
103+
104+ func ResolveSTSEndpoint (stsEndpoint , secretsRegion string ) (endpoints.ResolvedEndpoint , error ) {
105+ ep := stsEndpoint
106+ if ep == "" {
107+ r := secretsRegion
108+ if r == "" {
109+ r = DefaultRegionSTS
110+ }
111+ ep = fmt .Sprintf ("https://sts.%s.amazonaws.com/" , r )
112+ }
113+
114+ signingRegion := DeriveSTSRegionFromEndpoint (ep )
115+
116+ return endpoints.ResolvedEndpoint {
117+ URL : ep ,
118+ SigningRegion : signingRegion ,
119+ }, nil
120+ }
121+
90122func secretsManagerGetSecretValue (sess * session.Session , creds * aws.Config , secret string ) (string , error ) {
91123 svc := secretsmanager .New (sess , creds )
92124 input := & secretsmanager.GetSecretValueInput {
0 commit comments