Skip to content

Commit

Permalink
sagemaker workforce vpc support
Browse files Browse the repository at this point in the history
  • Loading branch information
DrFaust92 committed Oct 28, 2022
1 parent 32955d8 commit e1951bc
Show file tree
Hide file tree
Showing 7 changed files with 285 additions and 3 deletions.
1 change: 1 addition & 0 deletions internal/service/ec2/sweep.go
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,7 @@ func init() {
F: sweepVPCEndpoints,
Dependencies: []string{
"aws_route_table",
"aws_sagemaker_workforce",
},
})

Expand Down
1 change: 1 addition & 0 deletions internal/service/sagemaker/sagemaker_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ func TestAccSageMaker_serial(t *testing.T) {
"CognitoConfig": testAccWorkforce_cognitoConfig,
"OidcConfig": testAccWorkforce_oidcConfig,
"SourceIpConfig": testAccWorkforce_sourceIPConfig,
"VPC": testAccWorkforce_vpc,
},
"Workteam": {
"disappears": testAccWorkteam_disappears,
Expand Down
16 changes: 16 additions & 0 deletions internal/service/sagemaker/status.go
Original file line number Diff line number Diff line change
Expand Up @@ -248,3 +248,19 @@ func StatusProject(conn *sagemaker.SageMaker, name string) resource.StateRefresh
return output, aws.StringValue(output.ProjectStatus), nil
}
}

func StatusWorkforce(conn *sagemaker.SageMaker, name string) resource.StateRefreshFunc {
return func() (interface{}, string, error) {
output, err := FindWorkforceByName(conn, name)

if tfresource.NotFound(err) {
return nil, "", nil
}

if err != nil {
return nil, "", err
}

return output, aws.StringValue(output.Status), nil
}
}
44 changes: 44 additions & 0 deletions internal/service/sagemaker/wait.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ const (
FlowDefinitionDeletedTimeout = 2 * time.Minute
ProjectCreatedTimeout = 15 * time.Minute
ProjectDeletedTimeout = 15 * time.Minute
WorkforceActiveTimeout = 10 * time.Minute
WorkforceDeletedTimeout = 10 * time.Minute
)

// WaitNotebookInstanceInService waits for a NotebookInstance to return InService
Expand Down Expand Up @@ -508,3 +510,45 @@ func WaitProjectUpdated(conn *sagemaker.SageMaker, name string) (*sagemaker.Desc

return nil, err
}

func WaitWorkforceActive(conn *sagemaker.SageMaker, name string) (*sagemaker.Workforce, error) {
stateConf := &resource.StateChangeConf{
Pending: []string{sagemaker.WorkforceStatusInitializing, sagemaker.WorkforceStatusUpdating},
Target: []string{sagemaker.WorkforceStatusActive},
Refresh: StatusWorkforce(conn, name),
Timeout: WorkforceActiveTimeout,
}

outputRaw, err := stateConf.WaitForState()

if output, ok := outputRaw.(*sagemaker.Workforce); ok {
if status, reason := aws.StringValue(output.Status), aws.StringValue(output.FailureReason); status == sagemaker.WorkforceStatusFailed && reason != "" {
tfresource.SetLastError(err, errors.New(reason))
}

return output, err
}

return nil, err
}

func WaitWorkforceDeleted(conn *sagemaker.SageMaker, name string) (*sagemaker.Workforce, error) {
stateConf := &resource.StateChangeConf{
Pending: []string{sagemaker.WorkforceStatusDeleting},
Target: []string{},
Refresh: StatusWorkforce(conn, name),
Timeout: WorkforceDeletedTimeout,
}

outputRaw, err := stateConf.WaitForState()

if output, ok := outputRaw.(*sagemaker.Workforce); ok {
if status, reason := aws.StringValue(output.Status), aws.StringValue(output.FailureReason); status == sagemaker.WorkforceStatusFailed && reason != "" {
tfresource.SetLastError(err, errors.New(reason))
}

return output, err
}

return nil, err
}
84 changes: 84 additions & 0 deletions internal/service/sagemaker/workforce.go
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,35 @@ func ResourceWorkforce() *schema.Resource {
validation.StringMatch(regexp.MustCompile(`^[a-zA-Z0-9]([a-zA-Z0-9\-])*$`), "Valid characters are a-z, A-Z, 0-9, and - (hyphen)."),
),
},
"workforce_vpc_config": {
Type: schema.TypeList,
Optional: true,
MaxItems: 1,
Elem: &schema.Resource{
Schema: map[string]*schema.Schema{
"security_group_ids": {
Type: schema.TypeSet,
Optional: true,
MaxItems: 5,
Elem: &schema.Schema{Type: schema.TypeString},
},
"subnets": {
Type: schema.TypeSet,
Optional: true,
MaxItems: 16,
Elem: &schema.Schema{Type: schema.TypeString},
},
"vpc_endpoint_id": {
Type: schema.TypeString,
Computed: true,
},
"vpc_id": {
Type: schema.TypeString,
Optional: true,
},
},
},
},
},
}
}
Expand All @@ -170,6 +199,10 @@ func resourceWorkforceCreate(d *schema.ResourceData, meta interface{}) error {
input.SourceIpConfig = expandWorkforceSourceIPConfig(v.([]interface{}))
}

if v, ok := d.GetOk("workforce_vpc_config"); ok {
input.WorkforceVpcConfig = expandWorkforceVpcConfig(v.([]interface{}))
}

log.Printf("[DEBUG] Creating SageMaker Workforce: %s", input)
_, err := conn.CreateWorkforce(input)

Expand All @@ -179,6 +212,10 @@ func resourceWorkforceCreate(d *schema.ResourceData, meta interface{}) error {

d.SetId(name)

if _, err := WaitWorkforceActive(conn, name); err != nil {
return fmt.Errorf("waiting for SageMaker Workfoce (%s) to be created: %w", d.Id(), err)
}

return resourceWorkforceRead(d, meta)
}

Expand Down Expand Up @@ -215,6 +252,10 @@ func resourceWorkforceRead(d *schema.ResourceData, meta interface{}) error {
return fmt.Errorf("setting source_ip_config: %w", err)
}

if err := d.Set("workforce_vpc_config", flattenWorkforceVpcConfig(workforce.WorkforceVpcConfig)); err != nil {
return fmt.Errorf("setting workforce_vpc_config: %w", err)
}

return nil
}

Expand All @@ -233,13 +274,21 @@ func resourceWorkforceUpdate(d *schema.ResourceData, meta interface{}) error {
input.OidcConfig = expandWorkforceOIDCConfig(d.Get("oidc_config").([]interface{}))
}

if d.HasChange("workforce_vpc_config") {
input.WorkforceVpcConfig = expandWorkforceVpcConfig(d.Get("workforce_vpc_config").([]interface{}))
}

log.Printf("[DEBUG] Updating SageMaker Workforce: %s", input)
_, err := conn.UpdateWorkforce(input)

if err != nil {
return fmt.Errorf("updating SageMaker Workforce (%s): %w", d.Id(), err)
}

if _, err := WaitWorkforceActive(conn, d.Id()); err != nil {
return fmt.Errorf("waiting for SageMaker Workfoce (%s) to be updated: %w", d.Id(), err)
}

return resourceWorkforceRead(d, meta)
}

Expand All @@ -259,6 +308,10 @@ func resourceWorkforceDelete(d *schema.ResourceData, meta interface{}) error {
return fmt.Errorf("deleting SageMaker Workforce (%s): %w", d.Id(), err)
}

if _, err := WaitWorkforceDeleted(conn, d.Id()); err != nil {
return fmt.Errorf("waiting for SageMaker Workfoce (%s) to be deleted: %w", d.Id(), err)
}

return nil
}

Expand Down Expand Up @@ -355,3 +408,34 @@ func flattenWorkforceOIDCConfig(config *sagemaker.OidcConfigForResponse, clientS

return []map[string]interface{}{m}
}

func expandWorkforceVpcConfig(l []interface{}) *sagemaker.WorkforceVpcConfigRequest {
if len(l) == 0 || l[0] == nil {
return &sagemaker.WorkforceVpcConfigRequest{}
}

m := l[0].(map[string]interface{})

config := &sagemaker.WorkforceVpcConfigRequest{
SecurityGroupIds: flex.ExpandStringSet(m["security_group_ids"].(*schema.Set)),
Subnets: flex.ExpandStringSet(m["subnets"].(*schema.Set)),
VpcId: aws.String(m["vpc_id"].(string)),
}

return config
}

func flattenWorkforceVpcConfig(config *sagemaker.WorkforceVpcConfigResponse) []map[string]interface{} {
if config == nil {
return []map[string]interface{}{}
}

m := map[string]interface{}{
"security_group_ids": flex.FlattenStringSet(config.SecurityGroupIds),
"subnets": flex.FlattenStringSet(config.Subnets),
"vpc_endpoint_id": aws.StringValue(config.VpcEndpointId),
"vpc_id": aws.StringValue(config.VpcId),
}

return []map[string]interface{}{m}
}
128 changes: 128 additions & 0 deletions internal/service/sagemaker/workforce_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ func testAccWorkforce_cognitoConfig(t *testing.T) {
resource.TestCheckResourceAttr(resourceName, "source_ip_config.#", "1"),
resource.TestCheckResourceAttr(resourceName, "source_ip_config.0.cidrs.#", "0"),
resource.TestCheckResourceAttrSet(resourceName, "subdomain"),
resource.TestCheckResourceAttr(resourceName, "workforce_vpc_config.#", "0"),
),
},
{
Expand Down Expand Up @@ -114,6 +115,7 @@ func testAccWorkforce_oidcConfig(t *testing.T) {
},
})
}

func testAccWorkforce_sourceIPConfig(t *testing.T) {
var workforce sagemaker.Workforce
rName := sdkacctest.RandomWithPrefix(acctest.ResourcePrefix)
Expand Down Expand Up @@ -161,6 +163,42 @@ func testAccWorkforce_sourceIPConfig(t *testing.T) {
})
}

func testAccWorkforce_vpc(t *testing.T) {
var workforce sagemaker.Workforce
rName := sdkacctest.RandomWithPrefix(acctest.ResourcePrefix)
resourceName := "aws_sagemaker_workforce.test"

resource.Test(t, resource.TestCase{
PreCheck: func() { acctest.PreCheck(t) },
ErrorCheck: acctest.ErrorCheck(t, sagemaker.EndpointsID),
ProtoV5ProviderFactories: acctest.ProtoV5ProviderFactories,
CheckDestroy: testAccCheckWorkforceDestroy,
Steps: []resource.TestStep{
{
Config: testAccWorkforceConfig_vpc(rName),
Check: resource.ComposeTestCheckFunc(
testAccCheckWorkforceExists(resourceName, &workforce),
resource.TestCheckResourceAttr(resourceName, "workforce_vpc_config.#", "1"),
resource.TestCheckResourceAttr(resourceName, "workforce_vpc_config.0.security_group_ids.#", "1"),
resource.TestCheckResourceAttr(resourceName, "workforce_vpc_config.0.subnets.#", "1"),
),
},
{
ResourceName: resourceName,
ImportState: true,
ImportStateVerify: true,
},
{
Config: testAccWorkforceConfig_vpcRemove(rName),
Check: resource.ComposeTestCheckFunc(
testAccCheckWorkforceExists(resourceName, &workforce),
resource.TestCheckResourceAttr(resourceName, "workforce_vpc_config.#", "0"),
),
},
},
})
}

func testAccWorkforce_disappears(t *testing.T) {
var workforce sagemaker.Workforce
rName := sdkacctest.RandomWithPrefix(acctest.ResourcePrefix)
Expand Down Expand Up @@ -317,3 +355,93 @@ resource "aws_sagemaker_workforce" "test" {
}
`, rName, endpoint)
}

func testAccWorkforceConfig_vpc(rName string) string {
return acctest.ConfigCompose(testAccWorkforceBaseConfig(rName), acctest.ConfigAvailableAZsNoOptIn(), fmt.Sprintf(`
resource "aws_vpc" "test" {
cidr_block = "10.1.0.0/16"
enable_dns_hostnames = true
tags = {
Name = %[1]q
}
}
resource "aws_subnet" "test" {
cidr_block = "10.1.1.0/24"
availability_zone = data.aws_availability_zones.available.names[0]
vpc_id = aws_vpc.test.id
tags = {
Name = %[1]q
}
}
resource "aws_security_group" "test" {
name = %[1]q
vpc_id = aws_vpc.test.id
tags = {
Name = %[1]q
}
}
resource "aws_sagemaker_workforce" "test" {
workforce_name = %[1]q
cognito_config {
client_id = aws_cognito_user_pool_client.test.id
user_pool = aws_cognito_user_pool_domain.test.user_pool_id
}
workforce_vpc_config {
security_group_ids = aws_security_group.test.*.id
subnets = [aws_subnet.test.id]
vpc_id = aws_vpc.test.id
}
}
`, rName))
}

func testAccWorkforceConfig_vpcRemove(rName string) string {
return acctest.ConfigCompose(testAccWorkforceBaseConfig(rName), acctest.ConfigAvailableAZsNoOptIn(), fmt.Sprintf(`
resource "aws_vpc" "test" {
cidr_block = "10.1.0.0/16"
enable_dns_hostnames = true
tags = {
Name = %[1]q
}
}
resource "aws_subnet" "test" {
cidr_block = "10.1.1.0/24"
availability_zone = data.aws_availability_zones.available.names[0]
vpc_id = aws_vpc.test.id
tags = {
Name = %[1]q
}
}
resource "aws_security_group" "test" {
count = 1
name = "%[1]s-${count.index}"
vpc_id = aws_vpc.test.id
tags = {
Name = %[1]q
}
}
resource "aws_sagemaker_workforce" "test" {
workforce_name = %[1]q
cognito_config {
client_id = aws_cognito_user_pool_client.test.id
user_pool = aws_cognito_user_pool_domain.test.user_pool_id
}
}
`, rName))
}
Loading

0 comments on commit e1951bc

Please sign in to comment.