Skip to content

Commit

Permalink
Fixing a problem with potential dirty read of a token document on tok…
Browse files Browse the repository at this point in the history
…en refresh (#64031)

* Fixing a problem with potential dirty read of a token document.
Related to #59685

* Fixing a problem with potential dirty read of a token document.
Adding CreateTokenResult to hold authentication object

* Fixing a problem with potential dirty read of a token document.
Adding CreateTokenResult to hold authentication object

Co-authored-by: Elastic Machine <[email protected]>
  • Loading branch information
BigPandaToo and elasticmachine authored Oct 26, 2020
1 parent 6093518 commit c23c057
Show file tree
Hide file tree
Showing 12 changed files with 116 additions and 80 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -82,9 +82,10 @@ protected void doExecute(Task task, DelegatePkiAuthenticationRequest request,
ActionListener.wrap(authentication -> {
assert authentication != null : "authentication should never be null at this point";
tokenService.createOAuth2Tokens(authentication, delegateeAuthentication, Map.of(), false,
ActionListener.wrap(tuple -> {
ActionListener.wrap(tokenResult -> {
final TimeValue expiresIn = tokenService.getExpirationDelay();
listener.onResponse(new DelegatePkiAuthenticationResponse(tuple.v1(), expiresIn, authentication));
listener.onResponse(new DelegatePkiAuthenticationResponse(tokenResult.getAccessToken(), expiresIn,
authentication));
}, listener::onFailure));
}, e -> {
logger.debug((Supplier<?>) () -> new ParameterizedMessage("Delegated x509Token [{}] could not be authenticated",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,9 +72,10 @@ protected void doExecute(Task task, OpenIdConnectAuthenticateRequest request,
@SuppressWarnings("unchecked") final Map<String, Object> tokenMetadata = (Map<String, Object>) result.getMetadata()
.get(OpenIdConnectRealm.CONTEXT_TOKEN_DATA);
tokenService.createOAuth2Tokens(authentication, originatingAuthentication, tokenMetadata, true,
ActionListener.wrap(tuple -> {
ActionListener.wrap(tokenResult -> {
final TimeValue expiresIn = tokenService.getExpirationDelay();
listener.onResponse(new OpenIdConnectAuthenticateResponse(authentication, tuple.v1(), tuple.v2(), expiresIn));
listener.onResponse(new OpenIdConnectAuthenticateResponse(authentication, tokenResult.getAccessToken(),
tokenResult.getRefreshToken(), expiresIn));
}, listener::onFailure));
}, e -> {
logger.debug(() -> new ParameterizedMessage("OpenIDConnectToken [{}] could not be authenticated", token), e);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,10 +65,11 @@ protected void doExecute(Task task, SamlAuthenticateRequest request, ActionListe
assert authentication != null : "authentication should never be null at this point";
final Map<String, Object> tokenMeta = (Map<String, Object>) result.getMetadata().get(SamlRealm.CONTEXT_TOKEN_DATA);
tokenService.createOAuth2Tokens(authentication, originatingAuthentication,
tokenMeta, true, ActionListener.wrap(tuple -> {
tokenMeta, true, ActionListener.wrap(tokenResult -> {
final TimeValue expiresIn = tokenService.getExpirationDelay();
listener.onResponse(
new SamlAuthenticateResponse(authentication, tuple.v1(), tuple.v2(), expiresIn));
new SamlAuthenticateResponse(authentication, tokenResult.getAccessToken(), tokenResult.getRefreshToken(),
expiresIn));
}, listener::onFailure));
}, e -> {
logger.debug(() -> new ParameterizedMessage("SamlToken [{}] could not be authenticated", saml), e);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -131,11 +131,12 @@ private void clearCredentialsFromRequest(GrantType grantType, CreateTokenRequest
private void createToken(GrantType grantType, CreateTokenRequest request, Authentication authentication, Authentication originatingAuth,
boolean includeRefreshToken, ActionListener<CreateTokenResponse> listener) {
tokenService.createOAuth2Tokens(authentication, originatingAuth, Collections.emptyMap(), includeRefreshToken,
ActionListener.wrap(tuple -> {
ActionListener.wrap(tokenResult -> {
final String scope = getResponseScopeValue(request.getScope());
final String base64AuthenticateResponse = (grantType == GrantType.KERBEROS) ? extractOutToken() : null;
final CreateTokenResponse response = new CreateTokenResponse(tuple.v1(), tokenService.getExpirationDelay(), scope,
tuple.v2(), base64AuthenticateResponse, authentication);
final CreateTokenResponse response = new CreateTokenResponse(tokenResult.getAccessToken(),
tokenService.getExpirationDelay(), scope, tokenResult.getRefreshToken(), base64AuthenticateResponse,
authentication);
listener.onResponse(response);
}, listener::onFailure));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
import org.elasticsearch.action.support.ActionFilters;
import org.elasticsearch.action.support.HandledTransportAction;
import org.elasticsearch.common.inject.Inject;
import org.elasticsearch.common.settings.SecureString;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.transport.TransportService;
import org.elasticsearch.xpack.core.security.action.token.CreateTokenRequest;
Expand All @@ -31,13 +30,12 @@ public TransportRefreshTokenAction(TransportService transportService, ActionFilt

@Override
protected void doExecute(Task task, CreateTokenRequest request, ActionListener<CreateTokenResponse> listener) {
tokenService.refreshToken(request.getRefreshToken(), ActionListener.wrap(tuple -> {
tokenService.refreshToken(request.getRefreshToken(), ActionListener.wrap(tokenResult -> {
final String scope = getResponseScopeValue(request.getScope());
tokenService.authenticateToken(new SecureString(tuple.v1()), ActionListener.wrap(authentication -> {
listener.onResponse(new CreateTokenResponse(tuple.v1(), tokenService.getExpirationDelay(), scope, tuple.v2(), null,
authentication));
},
listener::onFailure));
final CreateTokenResponse response =
new CreateTokenResponse(tokenResult.getAccessToken(), tokenService.getExpirationDelay(), scope,
tokenResult.getRefreshToken(), null, tokenResult.getAuthentication());
listener.onResponse(response);
}, listener::onFailure));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,7 @@ public TokenService(Settings settings, Clock clock, Client client, XPackLicenseS
* {@link #VERSION_TOKENS_INDEX_INTRODUCED} and to a specific security tokens index for later versions.
*/
public void createOAuth2Tokens(Authentication authentication, Authentication originatingClientAuth, Map<String, Object> metadata,
boolean includeRefreshToken, ActionListener<Tuple<String, String>> listener) {
boolean includeRefreshToken, ActionListener<CreateTokenResult> listener) {
// the created token is compatible with the oldest node version in the cluster
final Version tokenVersion = getTokenVersionCompatibility();
// tokens moved to a separate index in newer versions
Expand All @@ -273,7 +273,7 @@ public void createOAuth2Tokens(Authentication authentication, Authentication ori
//public for testing
public void createOAuth2Tokens(String accessToken, String refreshToken, Authentication authentication,
Authentication originatingClientAuth,
Map<String, Object> metadata, ActionListener<Tuple<String, String>> listener) {
Map<String, Object> metadata, ActionListener<CreateTokenResult> listener) {
// the created token is compatible with the oldest node version in the cluster
final Version tokenVersion = getTokenVersionCompatibility();
// tokens moved to a separate index in newer versions
Expand Down Expand Up @@ -306,12 +306,13 @@ public void createOAuth2Tokens(String accessToken, String refreshToken, Authenti
* @param authentication The authentication object representing the user for which the tokens are created
* @param originatingClientAuth The authentication object representing the client that called the related API
* @param metadata A map with metadata to be stored in the token document
* @param listener The listener to call upon completion with a {@link Tuple} containing the
* serialized access token and serialized refresh token as these will be returned to the client
* @param listener The listener to call upon completion with a {@link CreateTokenResult} containing the
* serialized access token, serialized refresh token and authentication for which the token is created
* as these will be returned to the client
*/
private void createOAuth2Tokens(String accessToken, String refreshToken, Version tokenVersion, SecurityIndexManager tokensIndex,
Authentication authentication, Authentication originatingClientAuth, Map<String, Object> metadata,
ActionListener<Tuple<String, String>> listener) {
ActionListener<CreateTokenResult> listener) {
assert accessToken.length() == TOKEN_LENGTH : "We assume token ids have a fixed length for nodes of a certain version."
+ " When changing the token length, be careful that the inferences about its length still hold.";
ensureEnabled();
Expand Down Expand Up @@ -351,12 +352,13 @@ private void createOAuth2Tokens(String accessToken, String refreshToken, Version
final String versionedRefreshToken = refreshToken != null
? prependVersionAndEncodeRefreshToken(tokenVersion, refreshToken)
: null;
listener.onResponse(new Tuple<>(versionedAccessToken, versionedRefreshToken));
listener.onResponse(new CreateTokenResult(versionedAccessToken, versionedRefreshToken,
authentication));
} else {
// prior versions of the refresh token are not version-prepended, as nodes on those
// versions don't expect it.
// Such nodes might exist in a mixed cluster during a rolling upgrade.
listener.onResponse(new Tuple<>(versionedAccessToken, refreshToken));
listener.onResponse(new CreateTokenResult(versionedAccessToken, refreshToken,authentication));
}
} else {
listener.onFailure(traceLog("create token",
Expand Down Expand Up @@ -859,10 +861,11 @@ private void indexInvalidation(Collection<String> tokenIds, SecurityIndexManager
* Called by the transport action in order to start the process of refreshing a token.
*
* @param refreshToken The refresh token as provided by the client
* @param listener The listener to call upon completion with a {@link Tuple} containing the
* serialized access token and serialized refresh token as these will be returned to the client
* @param listener The listener to call upon completion with a {@link CreateTokenResult} containing the
* serialized access token, serialized refresh token and authentication for which the token is created
* as these will be returned to the client
*/
public void refreshToken(String refreshToken, ActionListener<Tuple<String, String>> listener) {
public void refreshToken(String refreshToken, ActionListener<CreateTokenResult> listener) {
ensureEnabled();
final Instant refreshRequested = clock.instant();
final Iterator<TimeValue> backoff = DEFAULT_BACKOFF.iterator();
Expand Down Expand Up @@ -995,7 +998,7 @@ private void findTokenFromRefreshToken(String refreshToken, SecurityIndexManager
*/
private void innerRefresh(String refreshToken, String tokenDocId, Map<String, Object> source, long seqNo, long primaryTerm,
Authentication clientAuth, Iterator<TimeValue> backoff, Instant refreshRequested,
ActionListener<Tuple<String, String>> listener) {
ActionListener<CreateTokenResult> listener) {
logger.debug("Attempting to refresh token stored in token document [{}]", tokenDocId);
final Consumer<Exception> onFailure = ex -> listener.onFailure(traceLog("refresh token", tokenDocId, ex));
final Tuple<RefreshTokenStatus, Optional<ElasticsearchSecurityException>> checkRefreshResult;
Expand All @@ -1014,7 +1017,9 @@ private void innerRefresh(String refreshToken, String tokenDocId, Map<String, Ob
if (refreshTokenStatus.isRefreshed()) {
logger.debug("Token document [{}] was recently refreshed, when a new token document was generated. Reusing that result.",
tokenDocId);
decryptAndReturnSupersedingTokens(refreshToken, refreshTokenStatus, refreshedTokenIndex, listener);
final Tuple<UserToken, String> parsedTokens = parseTokensFromDocument(source, null);
Authentication authentication = parsedTokens.v1().getAuthentication();
decryptAndReturnSupersedingTokens(refreshToken, refreshTokenStatus, refreshedTokenIndex, authentication, listener);
} else {
final String newAccessTokenString = UUIDs.randomBase64UUID();
final String newRefreshTokenString = UUIDs.randomBase64UUID();
Expand Down Expand Up @@ -1126,11 +1131,13 @@ public void onFailure(Exception e) {
* @param refreshTokenStatus The {@link RefreshTokenStatus} containing information about the superseding tokens as retrieved from the
* index
* @param tokensIndex the manager for the index where the tokens are stored
* @param listener The listener to call upon completion with a {@link Tuple} containing the
* serialized access token and serialized refresh token as these will be returned to the client
* @param authentication The authentication object representing the user for which the tokens are created
* @param listener The listener to call upon completion with a {@link CreateTokenResult} containing the
* serialized access token, serialized refresh token and authentication for which the token is created
* as these will be returned to the client
*/
void decryptAndReturnSupersedingTokens(String refreshToken, RefreshTokenStatus refreshTokenStatus, SecurityIndexManager tokensIndex,
ActionListener<Tuple<String, String>> listener) {
Authentication authentication, ActionListener<CreateTokenResult> listener) {

final byte[] iv = Base64.getDecoder().decode(refreshTokenStatus.getIv());
final byte[] salt = Base64.getDecoder().decode(refreshTokenStatus.getSalt());
Expand Down Expand Up @@ -1166,8 +1173,10 @@ public void onResponse(GetResponse response) {
if (response.isExists()) {
try {
listener.onResponse(
new Tuple<>(prependVersionAndEncodeAccessToken(refreshTokenStatus.getVersion(), decryptedTokens[0]),
prependVersionAndEncodeRefreshToken(refreshTokenStatus.getVersion(), decryptedTokens[1])));
new CreateTokenResult(prependVersionAndEncodeAccessToken(refreshTokenStatus.getVersion(),
decryptedTokens[0]),
prependVersionAndEncodeRefreshToken(refreshTokenStatus.getVersion(), decryptedTokens[1]),
authentication));
} catch (GeneralSecurityException | IOException e) {
logger.warn("Could not format stored superseding token values", e);
onFailure.accept(invalidGrantException("could not refresh the requested token"));
Expand Down Expand Up @@ -1910,6 +1919,30 @@ boolean isExpirationInProgress() {
return expiredTokenRemover.isExpirationInProgress();
}

public static final class CreateTokenResult {
private final String accessToken;
private final String refreshToken;
private final Authentication authentication;

public CreateTokenResult(String accessToken, String refreshToken, Authentication authentication) {
this.accessToken = accessToken;
this.refreshToken = refreshToken;
this.authentication = authentication;
}

public String getAccessToken() {
return accessToken;
}

public String getRefreshToken() {
return refreshToken;
}

public Authentication getAuthentication() {
return authentication;
}
}

private class KeyComputingRunnable extends AbstractRunnable {

private final BytesKey decodedSalt;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
import org.elasticsearch.client.Client;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.UUIDs;
import org.elasticsearch.common.collect.Tuple;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.env.Environment;
Expand Down Expand Up @@ -203,11 +202,11 @@ public void testLogoutInvalidatesTokens() throws Exception {
final Authentication authentication = new Authentication(user, realmRef, null, null, Authentication.AuthenticationType.REALM,
tokenMetadata);

final PlainActionFuture<Tuple<String, String>> future = new PlainActionFuture<>();
final PlainActionFuture<TokenService.CreateTokenResult> future = new PlainActionFuture<>();
final String userTokenId = UUIDs.randomBase64UUID();
final String refreshToken = UUIDs.randomBase64UUID();
tokenService.createOAuth2Tokens(userTokenId, refreshToken, authentication, authentication, tokenMetadata, future);
final String accessToken = future.actionGet().v1();
final String accessToken = future.actionGet().getAccessToken();
mockGetTokenFromId(tokenService, userTokenId, authentication, false, client);

final OpenIdConnectLogoutRequest request = new OpenIdConnectLogoutRequest();
Expand Down
Loading

0 comments on commit c23c057

Please sign in to comment.