Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixing a problem with potential dirty read of a token document on token refresh #64031

Merged
merged 5 commits into from
Oct 26, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -82,9 +82,10 @@ protected void doExecute(Task task, DelegatePkiAuthenticationRequest request,
ActionListener.wrap(authentication -> {
assert authentication != null : "authentication should never be null at this point";
tokenService.createOAuth2Tokens(authentication, delegateeAuthentication, Map.of(), false,
ActionListener.wrap(tuple -> {
ActionListener.wrap(tokenResult -> {
final TimeValue expiresIn = tokenService.getExpirationDelay();
listener.onResponse(new DelegatePkiAuthenticationResponse(tuple.v1(), expiresIn, authentication));
listener.onResponse(new DelegatePkiAuthenticationResponse(tokenResult.getAccessToken(), expiresIn,
authentication));
}, listener::onFailure));
}, e -> {
logger.debug((Supplier<?>) () -> new ParameterizedMessage("Delegated x509Token [{}] could not be authenticated",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,9 +72,10 @@ protected void doExecute(Task task, OpenIdConnectAuthenticateRequest request,
@SuppressWarnings("unchecked") final Map<String, Object> tokenMetadata = (Map<String, Object>) result.getMetadata()
.get(OpenIdConnectRealm.CONTEXT_TOKEN_DATA);
tokenService.createOAuth2Tokens(authentication, originatingAuthentication, tokenMetadata, true,
ActionListener.wrap(tuple -> {
ActionListener.wrap(tokenResult -> {
final TimeValue expiresIn = tokenService.getExpirationDelay();
listener.onResponse(new OpenIdConnectAuthenticateResponse(authentication, tuple.v1(), tuple.v2(), expiresIn));
listener.onResponse(new OpenIdConnectAuthenticateResponse(authentication, tokenResult.getAccessToken(),
tokenResult.getRefreshToken(), expiresIn));
}, listener::onFailure));
}, e -> {
logger.debug(() -> new ParameterizedMessage("OpenIDConnectToken [{}] could not be authenticated", token), e);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,10 +65,11 @@ protected void doExecute(Task task, SamlAuthenticateRequest request, ActionListe
assert authentication != null : "authentication should never be null at this point";
final Map<String, Object> tokenMeta = (Map<String, Object>) result.getMetadata().get(SamlRealm.CONTEXT_TOKEN_DATA);
tokenService.createOAuth2Tokens(authentication, originatingAuthentication,
tokenMeta, true, ActionListener.wrap(tuple -> {
tokenMeta, true, ActionListener.wrap(tokenResult -> {
final TimeValue expiresIn = tokenService.getExpirationDelay();
listener.onResponse(
new SamlAuthenticateResponse(authentication, tuple.v1(), tuple.v2(), expiresIn));
new SamlAuthenticateResponse(authentication, tokenResult.getAccessToken(), tokenResult.getRefreshToken(),
expiresIn));
}, listener::onFailure));
}, e -> {
logger.debug(() -> new ParameterizedMessage("SamlToken [{}] could not be authenticated", saml), e);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -131,11 +131,12 @@ private void clearCredentialsFromRequest(GrantType grantType, CreateTokenRequest
private void createToken(GrantType grantType, CreateTokenRequest request, Authentication authentication, Authentication originatingAuth,
boolean includeRefreshToken, ActionListener<CreateTokenResponse> listener) {
tokenService.createOAuth2Tokens(authentication, originatingAuth, Collections.emptyMap(), includeRefreshToken,
ActionListener.wrap(tuple -> {
ActionListener.wrap(tokenResult -> {
final String scope = getResponseScopeValue(request.getScope());
final String base64AuthenticateResponse = (grantType == GrantType.KERBEROS) ? extractOutToken() : null;
final CreateTokenResponse response = new CreateTokenResponse(tuple.v1(), tokenService.getExpirationDelay(), scope,
tuple.v2(), base64AuthenticateResponse, authentication);
final CreateTokenResponse response = new CreateTokenResponse(tokenResult.getAccessToken(),
tokenService.getExpirationDelay(), scope, tokenResult.getRefreshToken(), base64AuthenticateResponse,
authentication);
listener.onResponse(response);
}, listener::onFailure));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
import org.elasticsearch.action.support.ActionFilters;
import org.elasticsearch.action.support.HandledTransportAction;
import org.elasticsearch.common.inject.Inject;
import org.elasticsearch.common.settings.SecureString;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.transport.TransportService;
import org.elasticsearch.xpack.core.security.action.token.CreateTokenRequest;
Expand All @@ -31,13 +30,12 @@ public TransportRefreshTokenAction(TransportService transportService, ActionFilt

@Override
protected void doExecute(Task task, CreateTokenRequest request, ActionListener<CreateTokenResponse> listener) {
tokenService.refreshToken(request.getRefreshToken(), ActionListener.wrap(tuple -> {
tokenService.refreshToken(request.getRefreshToken(), ActionListener.wrap(tokenResult -> {
final String scope = getResponseScopeValue(request.getScope());
tokenService.authenticateToken(new SecureString(tuple.v1()), ActionListener.wrap(authentication -> {
listener.onResponse(new CreateTokenResponse(tuple.v1(), tokenService.getExpirationDelay(), scope, tuple.v2(), null,
authentication));
},
listener::onFailure));
final CreateTokenResponse response =
new CreateTokenResponse(tokenResult.getAccessToken(), tokenService.getExpirationDelay(), scope,
tokenResult.getRefreshToken(), null, tokenResult.getAuthentication());
BigPandaToo marked this conversation as resolved.
Show resolved Hide resolved
listener.onResponse(response);
}, listener::onFailure));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,7 @@ public TokenService(Settings settings, Clock clock, Client client, XPackLicenseS
* {@link #VERSION_TOKENS_INDEX_INTRODUCED} and to a specific security tokens index for later versions.
*/
public void createOAuth2Tokens(Authentication authentication, Authentication originatingClientAuth, Map<String, Object> metadata,
boolean includeRefreshToken, ActionListener<Tuple<String, String>> listener) {
boolean includeRefreshToken, ActionListener<CreateTokenResult> listener) {
// the created token is compatible with the oldest node version in the cluster
final Version tokenVersion = getTokenVersionCompatibility();
// tokens moved to a separate index in newer versions
Expand All @@ -273,7 +273,7 @@ public void createOAuth2Tokens(Authentication authentication, Authentication ori
//public for testing
public void createOAuth2Tokens(String accessToken, String refreshToken, Authentication authentication,
Authentication originatingClientAuth,
Map<String, Object> metadata, ActionListener<Tuple<String, String>> listener) {
Map<String, Object> metadata, ActionListener<CreateTokenResult> listener) {
// the created token is compatible with the oldest node version in the cluster
final Version tokenVersion = getTokenVersionCompatibility();
// tokens moved to a separate index in newer versions
Expand Down Expand Up @@ -306,12 +306,13 @@ public void createOAuth2Tokens(String accessToken, String refreshToken, Authenti
* @param authentication The authentication object representing the user for which the tokens are created
* @param originatingClientAuth The authentication object representing the client that called the related API
* @param metadata A map with metadata to be stored in the token document
* @param listener The listener to call upon completion with a {@link Tuple} containing the
* serialized access token and serialized refresh token as these will be returned to the client
* @param listener The listener to call upon completion with a {@link CreateTokenResult} containing the
* serialized access token, serialized refresh token and authentication for which the token is created
* as these will be returned to the client
*/
private void createOAuth2Tokens(String accessToken, String refreshToken, Version tokenVersion, SecurityIndexManager tokensIndex,
BigPandaToo marked this conversation as resolved.
Show resolved Hide resolved
Authentication authentication, Authentication originatingClientAuth, Map<String, Object> metadata,
ActionListener<Tuple<String, String>> listener) {
ActionListener<CreateTokenResult> listener) {
assert accessToken.length() == TOKEN_LENGTH : "We assume token ids have a fixed length for nodes of a certain version."
+ " When changing the token length, be careful that the inferences about its length still hold.";
ensureEnabled();
Expand Down Expand Up @@ -351,12 +352,13 @@ private void createOAuth2Tokens(String accessToken, String refreshToken, Version
final String versionedRefreshToken = refreshToken != null
? prependVersionAndEncodeRefreshToken(tokenVersion, refreshToken)
: null;
listener.onResponse(new Tuple<>(versionedAccessToken, versionedRefreshToken));
listener.onResponse(new CreateTokenResult(versionedAccessToken, versionedRefreshToken,
authentication));
} else {
// prior versions of the refresh token are not version-prepended, as nodes on those
// versions don't expect it.
// Such nodes might exist in a mixed cluster during a rolling upgrade.
listener.onResponse(new Tuple<>(versionedAccessToken, refreshToken));
listener.onResponse(new CreateTokenResult(versionedAccessToken, refreshToken,authentication));
}
} else {
listener.onFailure(traceLog("create token",
Expand Down Expand Up @@ -859,10 +861,11 @@ private void indexInvalidation(Collection<String> tokenIds, SecurityIndexManager
* Called by the transport action in order to start the process of refreshing a token.
*
* @param refreshToken The refresh token as provided by the client
* @param listener The listener to call upon completion with a {@link Tuple} containing the
* serialized access token and serialized refresh token as these will be returned to the client
* @param listener The listener to call upon completion with a {@link CreateTokenResult} containing the
* serialized access token, serialized refresh token and authentication for which the token is created
* as these will be returned to the client
*/
public void refreshToken(String refreshToken, ActionListener<Tuple<String, String>> listener) {
public void refreshToken(String refreshToken, ActionListener<CreateTokenResult> listener) {
BigPandaToo marked this conversation as resolved.
Show resolved Hide resolved
ensureEnabled();
final Instant refreshRequested = clock.instant();
final Iterator<TimeValue> backoff = DEFAULT_BACKOFF.iterator();
Expand Down Expand Up @@ -995,7 +998,7 @@ private void findTokenFromRefreshToken(String refreshToken, SecurityIndexManager
*/
private void innerRefresh(String refreshToken, String tokenDocId, Map<String, Object> source, long seqNo, long primaryTerm,
Authentication clientAuth, Iterator<TimeValue> backoff, Instant refreshRequested,
ActionListener<Tuple<String, String>> listener) {
ActionListener<CreateTokenResult> listener) {
logger.debug("Attempting to refresh token stored in token document [{}]", tokenDocId);
final Consumer<Exception> onFailure = ex -> listener.onFailure(traceLog("refresh token", tokenDocId, ex));
final Tuple<RefreshTokenStatus, Optional<ElasticsearchSecurityException>> checkRefreshResult;
Expand All @@ -1014,7 +1017,9 @@ private void innerRefresh(String refreshToken, String tokenDocId, Map<String, Ob
if (refreshTokenStatus.isRefreshed()) {
logger.debug("Token document [{}] was recently refreshed, when a new token document was generated. Reusing that result.",
tokenDocId);
decryptAndReturnSupersedingTokens(refreshToken, refreshTokenStatus, refreshedTokenIndex, listener);
final Tuple<UserToken, String> parsedTokens = parseTokensFromDocument(source, null);
Authentication authentication = parsedTokens.v1().getAuthentication();
decryptAndReturnSupersedingTokens(refreshToken, refreshTokenStatus, refreshedTokenIndex, authentication, listener);
} else {
final String newAccessTokenString = UUIDs.randomBase64UUID();
final String newRefreshTokenString = UUIDs.randomBase64UUID();
Expand Down Expand Up @@ -1126,11 +1131,13 @@ public void onFailure(Exception e) {
* @param refreshTokenStatus The {@link RefreshTokenStatus} containing information about the superseding tokens as retrieved from the
* index
* @param tokensIndex the manager for the index where the tokens are stored
* @param listener The listener to call upon completion with a {@link Tuple} containing the
* serialized access token and serialized refresh token as these will be returned to the client
* @param authentication The authentication object representing the user for which the tokens are created
* @param listener The listener to call upon completion with a {@link CreateTokenResult} containing the
* serialized access token, serialized refresh token and authentication for which the token is created
* as these will be returned to the client
*/
void decryptAndReturnSupersedingTokens(String refreshToken, RefreshTokenStatus refreshTokenStatus, SecurityIndexManager tokensIndex,
ActionListener<Tuple<String, String>> listener) {
Authentication authentication, ActionListener<CreateTokenResult> listener) {
BigPandaToo marked this conversation as resolved.
Show resolved Hide resolved

final byte[] iv = Base64.getDecoder().decode(refreshTokenStatus.getIv());
final byte[] salt = Base64.getDecoder().decode(refreshTokenStatus.getSalt());
Expand Down Expand Up @@ -1166,8 +1173,10 @@ public void onResponse(GetResponse response) {
if (response.isExists()) {
try {
listener.onResponse(
new Tuple<>(prependVersionAndEncodeAccessToken(refreshTokenStatus.getVersion(), decryptedTokens[0]),
prependVersionAndEncodeRefreshToken(refreshTokenStatus.getVersion(), decryptedTokens[1])));
new CreateTokenResult(prependVersionAndEncodeAccessToken(refreshTokenStatus.getVersion(),
decryptedTokens[0]),
prependVersionAndEncodeRefreshToken(refreshTokenStatus.getVersion(), decryptedTokens[1]),
authentication));
} catch (GeneralSecurityException | IOException e) {
logger.warn("Could not format stored superseding token values", e);
onFailure.accept(invalidGrantException("could not refresh the requested token"));
Expand Down Expand Up @@ -1910,6 +1919,30 @@ boolean isExpirationInProgress() {
return expiredTokenRemover.isExpirationInProgress();
}

public static final class CreateTokenResult {
private final String accessToken;
private final String refreshToken;
private final Authentication authentication;

public CreateTokenResult(String accessToken, String refreshToken, Authentication authentication) {
this.accessToken = accessToken;
this.refreshToken = refreshToken;
this.authentication = authentication;
}

public String getAccessToken() {
return accessToken;
}

public String getRefreshToken() {
return refreshToken;
}

public Authentication getAuthentication() {
return authentication;
}
}

private class KeyComputingRunnable extends AbstractRunnable {

private final BytesKey decodedSalt;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
import org.elasticsearch.client.Client;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.UUIDs;
import org.elasticsearch.common.collect.Tuple;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.env.Environment;
Expand Down Expand Up @@ -203,11 +202,11 @@ public void testLogoutInvalidatesTokens() throws Exception {
final Authentication authentication = new Authentication(user, realmRef, null, null, Authentication.AuthenticationType.REALM,
tokenMetadata);

final PlainActionFuture<Tuple<String, String>> future = new PlainActionFuture<>();
final PlainActionFuture<TokenService.CreateTokenResult> future = new PlainActionFuture<>();
final String userTokenId = UUIDs.randomBase64UUID();
final String refreshToken = UUIDs.randomBase64UUID();
tokenService.createOAuth2Tokens(userTokenId, refreshToken, authentication, authentication, tokenMetadata, future);
final String accessToken = future.actionGet().v1();
final String accessToken = future.actionGet().getAccessToken();
mockGetTokenFromId(tokenService, userTokenId, authentication, false, client);

final OpenIdConnectLogoutRequest request = new OpenIdConnectLogoutRequest();
Expand Down
Loading