Skip to content

Commit

Permalink
Merge pull request #27538 from DrFaust92/sagemaker-workforce-vpc-config
Browse files Browse the repository at this point in the history
Sagemaker workforce - vpc support
  • Loading branch information
ewbankkit authored Oct 31, 2022
2 parents 9a808e1 + 0d500b2 commit 990f9ae
Show file tree
Hide file tree
Showing 8 changed files with 273 additions and 14 deletions.
3 changes: 3 additions & 0 deletions .changelog/27538.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
```release-note:enhancement
resource/aws_sagemaker_workforce: Add `workforce_vpc_config` argument
```
1 change: 1 addition & 0 deletions internal/service/ec2/sweep.go
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,7 @@ func init() {
F: sweepVPCEndpoints,
Dependencies: []string{
"aws_route_table",
"aws_sagemaker_workforce",
},
})

Expand Down
1 change: 1 addition & 0 deletions internal/service/sagemaker/sagemaker_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ func TestAccSageMaker_serial(t *testing.T) {
"CognitoConfig": testAccWorkforce_cognitoConfig,
"OidcConfig": testAccWorkforce_oidcConfig,
"SourceIpConfig": testAccWorkforce_sourceIPConfig,
"VPC": testAccWorkforce_vpc,
},
"Workteam": {
"disappears": testAccWorkteam_disappears,
Expand Down
16 changes: 16 additions & 0 deletions internal/service/sagemaker/status.go
Original file line number Diff line number Diff line change
Expand Up @@ -248,3 +248,19 @@ func StatusProject(conn *sagemaker.SageMaker, name string) resource.StateRefresh
return output, aws.StringValue(output.ProjectStatus), nil
}
}

func StatusWorkforce(conn *sagemaker.SageMaker, name string) resource.StateRefreshFunc {
return func() (interface{}, string, error) {
output, err := FindWorkforceByName(conn, name)

if tfresource.NotFound(err) {
return nil, "", nil
}

if err != nil {
return nil, "", err
}

return output, aws.StringValue(output.Status), nil
}
}
44 changes: 44 additions & 0 deletions internal/service/sagemaker/wait.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ const (
FlowDefinitionDeletedTimeout = 2 * time.Minute
ProjectCreatedTimeout = 15 * time.Minute
ProjectDeletedTimeout = 15 * time.Minute
WorkforceActiveTimeout = 10 * time.Minute
WorkforceDeletedTimeout = 10 * time.Minute
)

// WaitNotebookInstanceInService waits for a NotebookInstance to return InService
Expand Down Expand Up @@ -508,3 +510,45 @@ func WaitProjectUpdated(conn *sagemaker.SageMaker, name string) (*sagemaker.Desc

return nil, err
}

func WaitWorkforceActive(conn *sagemaker.SageMaker, name string) (*sagemaker.Workforce, error) {
stateConf := &resource.StateChangeConf{
Pending: []string{sagemaker.WorkforceStatusInitializing, sagemaker.WorkforceStatusUpdating},
Target: []string{sagemaker.WorkforceStatusActive},
Refresh: StatusWorkforce(conn, name),
Timeout: WorkforceActiveTimeout,
}

outputRaw, err := stateConf.WaitForState()

if output, ok := outputRaw.(*sagemaker.Workforce); ok {
if status, reason := aws.StringValue(output.Status), aws.StringValue(output.FailureReason); status == sagemaker.WorkforceStatusFailed && reason != "" {
tfresource.SetLastError(err, errors.New(reason))
}

return output, err
}

return nil, err
}

func WaitWorkforceDeleted(conn *sagemaker.SageMaker, name string) (*sagemaker.Workforce, error) {
stateConf := &resource.StateChangeConf{
Pending: []string{sagemaker.WorkforceStatusDeleting},
Target: []string{},
Refresh: StatusWorkforce(conn, name),
Timeout: WorkforceDeletedTimeout,
}

outputRaw, err := stateConf.WaitForState()

if output, ok := outputRaw.(*sagemaker.Workforce); ok {
if status, reason := aws.StringValue(output.Status), aws.StringValue(output.FailureReason); status == sagemaker.WorkforceStatusFailed && reason != "" {
tfresource.SetLastError(err, errors.New(reason))
}

return output, err
}

return nil, err
}
86 changes: 84 additions & 2 deletions internal/service/sagemaker/workforce.go
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,35 @@ func ResourceWorkforce() *schema.Resource {
validation.StringMatch(regexp.MustCompile(`^[a-zA-Z0-9]([a-zA-Z0-9\-])*$`), "Valid characters are a-z, A-Z, 0-9, and - (hyphen)."),
),
},
"workforce_vpc_config": {
Type: schema.TypeList,
Optional: true,
MaxItems: 1,
Elem: &schema.Resource{
Schema: map[string]*schema.Schema{
"security_group_ids": {
Type: schema.TypeSet,
Optional: true,
MaxItems: 5,
Elem: &schema.Schema{Type: schema.TypeString},
},
"subnets": {
Type: schema.TypeSet,
Optional: true,
MaxItems: 16,
Elem: &schema.Schema{Type: schema.TypeString},
},
"vpc_endpoint_id": {
Type: schema.TypeString,
Computed: true,
},
"vpc_id": {
Type: schema.TypeString,
Optional: true,
},
},
},
},
},
}
}
Expand All @@ -170,7 +199,10 @@ func resourceWorkforceCreate(d *schema.ResourceData, meta interface{}) error {
input.SourceIpConfig = expandWorkforceSourceIPConfig(v.([]interface{}))
}

log.Printf("[DEBUG] Creating SageMaker Workforce: %s", input)
if v, ok := d.GetOk("workforce_vpc_config"); ok {
input.WorkforceVpcConfig = expandWorkforceVPCConfig(v.([]interface{}))
}

_, err := conn.CreateWorkforce(input)

if err != nil {
Expand All @@ -179,6 +211,10 @@ func resourceWorkforceCreate(d *schema.ResourceData, meta interface{}) error {

d.SetId(name)

if _, err := WaitWorkforceActive(conn, name); err != nil {
return fmt.Errorf("waiting for SageMaker Workforce (%s) create: %w", d.Id(), err)
}

return resourceWorkforceRead(d, meta)
}

Expand Down Expand Up @@ -215,6 +251,10 @@ func resourceWorkforceRead(d *schema.ResourceData, meta interface{}) error {
return fmt.Errorf("setting source_ip_config: %w", err)
}

if err := d.Set("workforce_vpc_config", flattenWorkforceVPCConfig(workforce.WorkforceVpcConfig)); err != nil {
return fmt.Errorf("setting workforce_vpc_config: %w", err)
}

return nil
}

Expand All @@ -233,13 +273,20 @@ func resourceWorkforceUpdate(d *schema.ResourceData, meta interface{}) error {
input.OidcConfig = expandWorkforceOIDCConfig(d.Get("oidc_config").([]interface{}))
}

log.Printf("[DEBUG] Updating SageMaker Workforce: %s", input)
if d.HasChange("workforce_vpc_config") {
input.WorkforceVpcConfig = expandWorkforceVPCConfig(d.Get("workforce_vpc_config").([]interface{}))
}

_, err := conn.UpdateWorkforce(input)

if err != nil {
return fmt.Errorf("updating SageMaker Workforce (%s): %w", d.Id(), err)
}

if _, err := WaitWorkforceActive(conn, d.Id()); err != nil {
return fmt.Errorf("waiting for SageMaker Workforce (%s) update: %w", d.Id(), err)
}

return resourceWorkforceRead(d, meta)
}

Expand All @@ -259,6 +306,10 @@ func resourceWorkforceDelete(d *schema.ResourceData, meta interface{}) error {
return fmt.Errorf("deleting SageMaker Workforce (%s): %w", d.Id(), err)
}

if _, err := WaitWorkforceDeleted(conn, d.Id()); err != nil {
return fmt.Errorf("waiting for SageMaker Workforce (%s) delete: %w", d.Id(), err)
}

return nil
}

Expand Down Expand Up @@ -355,3 +406,34 @@ func flattenWorkforceOIDCConfig(config *sagemaker.OidcConfigForResponse, clientS

return []map[string]interface{}{m}
}

func expandWorkforceVPCConfig(l []interface{}) *sagemaker.WorkforceVpcConfigRequest {
if len(l) == 0 || l[0] == nil {
return &sagemaker.WorkforceVpcConfigRequest{}
}

m := l[0].(map[string]interface{})

config := &sagemaker.WorkforceVpcConfigRequest{
SecurityGroupIds: flex.ExpandStringSet(m["security_group_ids"].(*schema.Set)),
Subnets: flex.ExpandStringSet(m["subnets"].(*schema.Set)),
VpcId: aws.String(m["vpc_id"].(string)),
}

return config
}

func flattenWorkforceVPCConfig(config *sagemaker.WorkforceVpcConfigResponse) []map[string]interface{} {
if config == nil {
return []map[string]interface{}{}
}

m := map[string]interface{}{
"security_group_ids": flex.FlattenStringSet(config.SecurityGroupIds),
"subnets": flex.FlattenStringSet(config.Subnets),
"vpc_endpoint_id": aws.StringValue(config.VpcEndpointId),
"vpc_id": aws.StringValue(config.VpcId),
}

return []map[string]interface{}{m}
}
Loading

0 comments on commit 990f9ae

Please sign in to comment.