Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add support for kms key aliases #1537

Draft
wants to merge 4 commits into
base: main
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
131 changes: 123 additions & 8 deletions kms/keysource.go
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,41 @@ func ParseKMSContext(in interface{}) map[string]*string {
return out
}

// GetArnByAlias takes a AWS KMS Client, key.
// When argument comes with alias, convert it to arn
func GetArnByAlias(client *kms.Client, key *MasterKey) error {
input := &kms.ListAliasesInput{}

paginator := kms.NewListAliasesPaginator(client, input)
found := false
for paginator.HasMorePages() && !found {
output, err := paginator.NextPage(context.Background())
if err != nil {
return fmt.Errorf("failed to get kms key: %w", err)
}

for _, alias := range output.Aliases {
if strings.HasSuffix(*alias.AliasArn, key.Arn) {

describeInput := &kms.DescribeKeyInput{
KeyId: aws.String(*alias.TargetKeyId),
}

describeOutput, err := client.DescribeKey(context.Background(), describeInput)

if err != nil {
return fmt.Errorf("failed to describe key: %w", err)
}

key.Arn = *describeOutput.KeyMetadata.Arn
found = true
}
}
}

return nil
}

// CredentialsProvider is a wrapper around aws.CredentialsProvider used for
// authentication towards AWS KMS.
type CredentialsProvider struct {
Expand Down Expand Up @@ -209,6 +244,31 @@ func (key *MasterKey) Encrypt(dataKey []byte) error {
return err
}
client := key.createClient(cfg)

// condition that input is an alias
if !strings.HasPrefix(string(key.Arn), "arn:aws:kms") {

err = GetArnByAlias(client, key)

if err != nil {
return err
}

encryptInput := &kms.EncryptInput{
KeyId: &key.Arn,
Plaintext: dataKey,
EncryptionContext: stringPointerToStringMap(key.EncryptionContext),
}
out, err := client.Encrypt(context.TODO(), encryptInput)
if err != nil {
log.WithField("arn", key.Arn).Info("Encryption failed")
return fmt.Errorf("failed to encrypt sops data key with AWS KMS 222: %s", key.Arn)
}
key.EncryptedKey = base64.StdEncoding.EncodeToString(out.CiphertextBlob)
log.WithField("arn", key.Arn).Info("Encryption succeeded")
return nil
}

input := &kms.EncryptInput{
KeyId: &key.Arn,
Plaintext: dataKey,
Expand Down Expand Up @@ -257,6 +317,52 @@ func (key *MasterKey) Decrypt() ([]byte, error) {
return nil, err
}
client := key.createClient(cfg)

if !strings.HasPrefix(string(key.Arn), "arn:aws:kms") {
input := &kms.ListAliasesInput{}

paginator := kms.NewListAliasesPaginator(client, input)
found := false
for paginator.HasMorePages() && !found {
output, err := paginator.NextPage(context.Background())
if err != nil {
log.WithField("arn", key.Arn).Info("Error listing aliases")
break
}

for _, alias := range output.Aliases {
if strings.HasSuffix(*alias.AliasArn, key.Arn) {

describeInput := &kms.DescribeKeyInput{
KeyId: aws.String(*alias.TargetKeyId),
}

describeOutput, err := client.DescribeKey(context.Background(), describeInput)

if err != nil {
return nil, fmt.Errorf("failed to describe key: %w", err)
}

key.Arn = *describeOutput.KeyMetadata.Arn
found = true
}
}
}

decryptedInput := &kms.DecryptInput{
KeyId: &key.Arn,
CiphertextBlob: k,
EncryptionContext: stringPointerToStringMap(key.EncryptionContext),
}
decrypted, err := client.Decrypt(context.TODO(), decryptedInput)
if err != nil {
log.WithField("arn", key.Arn).Info("Decryption failed")
return nil, fmt.Errorf("failed to decrypt sops data key with AWS KMS: %w", err)
}
log.WithField("arn", key.Arn).Info("Decryption succeeded")
return decrypted.Plaintext, nil
}

input := &kms.DecryptInput{
KeyId: &key.Arn,
CiphertextBlob: k,
Expand Down Expand Up @@ -307,13 +413,7 @@ func (key *MasterKey) TypeToIdentifier() string {

// createKMSConfig returns an AWS config with the credentialsProvider of the
// MasterKey, or the default configuration sources.
func (key MasterKey) createKMSConfig() (*aws.Config, error) {
re := regexp.MustCompile(arnRegex)
matches := re.FindStringSubmatch(key.Arn)
if matches == nil {
return nil, fmt.Errorf("no valid ARN found in '%s'", key.Arn)
}
region := matches[1]
func (key *MasterKey) createKMSConfig() (*aws.Config, error) {

cfg, err := config.LoadDefaultConfig(context.TODO(), func(lo *config.LoadOptions) error {
// Use the credentialsProvider if present, otherwise default to reading credentials
Expand All @@ -324,13 +424,28 @@ func (key MasterKey) createKMSConfig() (*aws.Config, error) {
if key.AwsProfile != "" {
lo.SharedConfigProfile = key.AwsProfile
}
lo.Region = region
return nil
})

if err != nil {
return nil, fmt.Errorf("could not load AWS config: %w", err)
}

re := regexp.MustCompile(arnRegex)
matches := re.FindStringSubmatch(key.Arn)

if matches == nil {
client := key.createClient(&cfg)
err = GetArnByAlias(client, key)

if err != nil {
return nil, err
}

matches = re.FindStringSubmatch(key.Arn)
cfg.Region = matches[1]
}

if key.Role != "" {
return key.createSTSConfig(&cfg)
}
Expand Down