diff --git a/server/src/main/java/org/cloudfoundry/identity/uaa/oauth/UaaTokenServices.java b/server/src/main/java/org/cloudfoundry/identity/uaa/oauth/UaaTokenServices.java index 117d44a2f0a..2a224f960e6 100644 --- a/server/src/main/java/org/cloudfoundry/identity/uaa/oauth/UaaTokenServices.java +++ b/server/src/main/java/org/cloudfoundry/identity/uaa/oauth/UaaTokenServices.java @@ -95,8 +95,10 @@ import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; +import static java.util.Collections.emptyList; import static java.util.Collections.emptyMap; import static java.util.Collections.emptySet; +import static java.util.Collections.singletonList; import static java.util.Optional.ofNullable; import static org.cloudfoundry.identity.uaa.oauth.client.ClientConstants.REQUIRED_USER_GROUPS; import static org.cloudfoundry.identity.uaa.oauth.openid.IdToken.ACR_VALUES_KEY; @@ -143,7 +145,7 @@ public class UaaTokenServices implements AuthorizationServerTokenServices, Resou private final TokenPolicy tokenPolicy; private final RevocableTokenProvisioning tokenProvisioning; private Set excludedClaims; - private UaaTokenEnhancer uaaTokenEnhancer; + private List uaaTokenEnhancers = new ArrayList<>(); private final IdTokenCreator idTokenCreator; private final RefreshTokenCreator refreshTokenCreator; private TokenEndpointBuilder tokenEndpointBuilder; @@ -192,8 +194,13 @@ public void setExcludedClaims(Set excludedClaims) { } @Autowired(required = false) + public void setUaaTokenEnhancers(List uaaTokenEnhancers) { + this.uaaTokenEnhancers = new ArrayList<>(uaaTokenEnhancers == null ? emptyList() : uaaTokenEnhancers); + } + + @Deprecated public void setUaaTokenEnhancer(UaaTokenEnhancer uaaTokenEnhancer) { - this.uaaTokenEnhancer = uaaTokenEnhancer; + this.setUaaTokenEnhancers(uaaTokenEnhancer == null ? emptyList() : singletonList(uaaTokenEnhancer)); } @Override @@ -349,7 +356,7 @@ Claims getClaims(Map refreshTokenClaims) { private Map getAdditionalRootClaims(Map refreshTokenClaims) { Map additionalRootClaims = new HashMap<>(); - if (uaaTokenEnhancer != null) { + if (!uaaTokenEnhancers.isEmpty()) { refreshTokenClaims.entrySet() .stream() .filter(entry -> !NON_ADDITIONAL_ROOT_CLAIMS.contains(entry.getKey())) @@ -625,8 +632,16 @@ public OAuth2AccessToken createAccessToken(OAuth2Authentication authentication) boolean isRefreshTokenRevocable = isAccessTokenRevocable || OPAQUE.getStringValue().equals(getActiveTokenPolicy().getRefreshTokenFormat()); Map additionalRootClaims = null; - if (uaaTokenEnhancer != null) { - additionalRootClaims = new HashMap<>(uaaTokenEnhancer.enhance(emptyMap(), authentication)); + if (!uaaTokenEnhancers.isEmpty()) { + additionalRootClaims = new HashMap<>(); + for (UaaTokenEnhancer enhancer : uaaTokenEnhancers) { + if (enhancer != null) { + Map claims = enhancer.enhance(additionalRootClaims, authentication); + if (claims != null) { + additionalRootClaims.putAll(claims); + } + } + } } String clientAuthentication = getAuthenticationMethod(oAuth2Request); diff --git a/server/src/test/java/org/cloudfoundry/identity/uaa/oauth/DeprecatedUaaTokenServicesTests.java b/server/src/test/java/org/cloudfoundry/identity/uaa/oauth/DeprecatedUaaTokenServicesTests.java index 7f087ef22c0..336c0201934 100644 --- a/server/src/test/java/org/cloudfoundry/identity/uaa/oauth/DeprecatedUaaTokenServicesTests.java +++ b/server/src/test/java/org/cloudfoundry/identity/uaa/oauth/DeprecatedUaaTokenServicesTests.java @@ -582,6 +582,54 @@ void createOpaqueAccessTokenForAClient(TestTokenEnhancer enhancer) { assertThat(accessToken.getRefreshToken()).isNull(); } + @MethodSource("data") + @ParameterizedTest(name = "{index}: {0}") + void multipleTokenEnhancersAreSupported(TestTokenEnhancer enhancer) { + initDeprecatedUaaTokenServicesTests(enhancer); + UaaTokenEnhancer enhancer1 = new UaaTokenEnhancer() { + @Override + public Map getExternalAttributes(OAuth2Authentication authentication) { + return Map.of(); + } + + @Override + public Map enhance(Map claims, OAuth2Authentication authentication) { + return Map.of("claim1", "value1"); + } + }; + + UaaTokenEnhancer enhancer2 = new UaaTokenEnhancer() { + @Override + public Map getExternalAttributes(OAuth2Authentication authentication) { + return Map.of(); + } + + @Override + public Map enhance(Map claims, OAuth2Authentication authentication) { + if (claims.containsKey("claim1")) { + return Map.of("claim2", claims.get("claim1") + "_modified"); + } + return Map.of("claim2", "value2"); + } + }; + + tokenServices.setUaaTokenEnhancers(java.util.Arrays.asList(enhancer1, enhancer2)); + + AuthorizationRequest authorizationRequest = new AuthorizationRequest(CLIENT_ID, tokenSupport.clientScopes); + authorizationRequest.setResourceIds(new java.util.HashSet<>(tokenSupport.resourceIds)); + authorizationRequest.setRequestParameters(new java.util.HashMap<>()); + OAuth2Authentication authentication = new OAuth2Authentication(authorizationRequest.createOAuth2Request(), null); + + OAuth2AccessToken accessToken = tokenServices.createAccessToken(authentication); + + String jwt = accessToken.getValue(); + org.cloudfoundry.identity.uaa.oauth.jwt.Jwt parsedToken = org.cloudfoundry.identity.uaa.oauth.jwt.JwtHelper.decode(jwt); + Map claims = org.cloudfoundry.identity.uaa.util.JsonUtils.readValue(parsedToken.getClaims(), new com.fasterxml.jackson.core.type.TypeReference>() {}); + + assertThat(claims).containsEntry("claim1", "value1"); + assertThat(claims).containsEntry("claim2", "value1_modified"); + } + @MethodSource("data") @ParameterizedTest(name = "{index}: {0}") void createAccessTokenForAClientInAnotherIdentityZone(TestTokenEnhancer enhancer) { diff --git a/uaa/src/test/java/org/cloudfoundry/identity/uaa/oauth/UaaTokenServicesTests.java b/uaa/src/test/java/org/cloudfoundry/identity/uaa/oauth/UaaTokenServicesTests.java index 99ca2ba04eb..fd65d147f81 100644 --- a/uaa/src/test/java/org/cloudfoundry/identity/uaa/oauth/UaaTokenServicesTests.java +++ b/uaa/src/test/java/org/cloudfoundry/identity/uaa/oauth/UaaTokenServicesTests.java @@ -224,6 +224,110 @@ void ensureAnIdTokenIsNotReturned(String grantType) { } } + @Nested + @DisplayName("when multiple token enhancers are provided") + @DefaultTestContext + @TestPropertySource(properties = {"uaa.url=https://uaa.some.test.domain.com:555/uaa"}) + class WhenMultipleTokenEnhancersAreProvided { + + @DisplayName("claims from all enhancers are merged into the token") + @ParameterizedTest + @ValueSource(strings = {GRANT_TYPE_PASSWORD, GRANT_TYPE_AUTHORIZATION_CODE}) + void claimsAreMerged(String grantType) { + UaaTokenEnhancer enhancer1 = new UaaTokenEnhancer() { + @Override + public Map getExternalAttributes(OAuth2Authentication authentication) { + return Map.of(); + } + + @Override + public Map enhance(Map claims, OAuth2Authentication authentication) { + return Map.of("claim1", "value1"); + } + }; + + UaaTokenEnhancer enhancer2 = new UaaTokenEnhancer() { + @Override + public Map getExternalAttributes(OAuth2Authentication authentication) { + return Map.of(); + } + + @Override + public Map enhance(Map claims, OAuth2Authentication authentication) { + return Map.of("claim2", "value2"); + } + }; + + tokenServices.setUaaTokenEnhancers(Arrays.asList(enhancer1, enhancer2)); + + try { + AuthorizationRequest authorizationRequest = constructAuthorizationRequest(clientId, grantType, "openid", "user_attributes"); + OAuth2Authentication auth2Authentication = constructUserAuthenticationFromAuthzRequest(authorizationRequest, "admin", "uaa"); + + CompositeToken accessToken = (CompositeToken) tokenServices.createAccessToken(auth2Authentication); + + String jwt = accessToken.getValue(); + Jwt parsedToken = JwtHelper.decode(jwt); + Map claims = JsonUtils.readValue(parsedToken.getClaims(), new TypeReference>() {}); + + assertThat(claims).containsEntry("claim1", "value1"); + assertThat(claims).containsEntry("claim2", "value2"); + } finally { + tokenServices.setUaaTokenEnhancers(new ArrayList<>()); + } + } + + @DisplayName("claims are passed to subsequent enhancers") + @ParameterizedTest + @ValueSource(strings = {GRANT_TYPE_PASSWORD, GRANT_TYPE_AUTHORIZATION_CODE}) + void claimsArePassedToSubsequentEnhancers(String grantType) { + UaaTokenEnhancer enhancer1 = new UaaTokenEnhancer() { + @Override + public Map getExternalAttributes(OAuth2Authentication authentication) { + return Map.of(); + } + + @Override + public Map enhance(Map claims, OAuth2Authentication authentication) { + return Map.of("claim1", "value1"); + } + }; + + UaaTokenEnhancer enhancer2 = new UaaTokenEnhancer() { + @Override + public Map getExternalAttributes(OAuth2Authentication authentication) { + return Map.of(); + } + + @Override + public Map enhance(Map claims, OAuth2Authentication authentication) { + if (claims.containsKey("claim1")) { + return Map.of("claim2", claims.get("claim1") + "_modified"); + } + return Map.of("claim2", "value2"); + } + }; + + tokenServices.setUaaTokenEnhancers(Arrays.asList(enhancer1, enhancer2)); + + try { + AuthorizationRequest authorizationRequest = constructAuthorizationRequest(clientId, grantType, "openid", "user_attributes"); + OAuth2Authentication auth2Authentication = constructUserAuthenticationFromAuthzRequest(authorizationRequest, "admin", "uaa"); + + CompositeToken accessToken = (CompositeToken) tokenServices.createAccessToken(auth2Authentication); + + String jwt = accessToken.getValue(); + Jwt parsedToken = JwtHelper.decode(jwt); + Map claims = JsonUtils.readValue(parsedToken.getClaims(), new TypeReference>() {}); + + assertThat(claims).containsEntry("claim1", "value1"); + assertThat(claims).containsEntry("claim2", "value1_modified"); + } finally { + tokenServices.setUaaTokenEnhancers(new ArrayList<>()); + } + } + } + @Nested @DisplayName("when the hasn't approved the 'openid' scope") @DefaultTestContext