Skip to content

Commit

Permalink
Amazon Bedrock - Fix system test (#38887)
Browse files Browse the repository at this point in the history
  • Loading branch information
ferruzzi authored Apr 10, 2024
1 parent 3d80435 commit b6ff085
Showing 1 changed file with 46 additions and 46 deletions.
92 changes: 46 additions & 46 deletions tests/system/providers/amazon/aws/example_bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,12 @@

import json
from datetime import datetime
from os import environ

from botocore.exceptions import ClientError

from airflow.decorators import task
from airflow.decorators import task, task_group
from airflow.models.baseoperator import chain
from airflow.models.dag import DAG
from airflow.operators.empty import EmptyOperator
from airflow.providers.amazon.aws.hooks.bedrock import BedrockHook
from airflow.providers.amazon.aws.operators.bedrock import (
BedrockCustomizeModelOperator,
Expand All @@ -35,6 +35,7 @@
S3DeleteBucketOperator,
)
from airflow.providers.amazon.aws.sensors.bedrock import BedrockCustomizeModelCompletedSensor
from airflow.utils.edgemodifier import Label
from airflow.utils.trigger_rule import TriggerRule
from tests.system.providers.amazon.aws.utils import SystemTestContextBuilder

Expand All @@ -44,10 +45,10 @@

DAG_ID = "example_bedrock"

# Creating a custom model takes nearly two hours. If SKIP_LONG_TASKS is True then set
# the trigger rule to an improbable state. This way we can still have the code snippets
# for docs, and we can manually run the full tests occasionally.
SKIP_LONG_TASKS = True
# Creating a custom model takes nearly two hours. If SKIP_LONG_TASKS
# is True then these tasks will be skipped. This way we can still have
# the code snippets for docs, and we can manually run the full tests.
SKIP_LONG_TASKS = environ.get("SKIP_LONG_SYSTEM_TEST_TASKS", default=True)

LLAMA_MODEL_ID = "meta.llama2-13b-chat-v1"
PROMPT = "What color is an orange?"
Expand All @@ -61,15 +62,41 @@
}


@task
def delete_custom_model(model_name: str):
try:
BedrockHook().conn.delete_custom_model(modelIdentifier=model_name)
except ClientError as e:
if SKIP_LONG_TASKS and (e.response["Error"]["Code"] == "ValidationException"):
# There is no model to delete. Since we skipped making one, that's fine.
return
raise e
@task_group
def customize_model_workflow():
# [START howto_operator_customize_model]
customize_model = BedrockCustomizeModelOperator(
task_id="customize_model",
job_name=custom_model_job_name,
custom_model_name=custom_model_name,
role_arn=test_context[ROLE_ARN_KEY],
base_model_id=f"arn:aws:bedrock:us-east-1::foundation-model/{TITAN_MODEL_ID}",
hyperparameters=HYPERPARAMETERS,
training_data_uri=training_data_uri,
output_data_uri=f"s3://{bucket_name}/myOutputData",
)
# [END howto_operator_customize_model]

# [START howto_sensor_customize_model]
await_custom_model_job = BedrockCustomizeModelCompletedSensor(
task_id="await_custom_model_job",
job_name=custom_model_job_name,
)
# [END howto_sensor_customize_model]

@task
def delete_custom_model():
BedrockHook().conn.delete_custom_model(modelIdentifier=custom_model_name)

@task.branch
def run_or_skip():
return end_workflow.task_id if SKIP_LONG_TASKS else customize_model.task_id

run_or_skip = run_or_skip()
end_workflow = EmptyOperator(task_id="end_workflow", trigger_rule=TriggerRule.NONE_FAILED_MIN_ONE_SUCCESS)

chain(run_or_skip, Label("Long-running tasks skipped"), end_workflow)
chain(run_or_skip, customize_model, await_custom_model_job, delete_custom_model(), end_workflow)


with DAG(
Expand All @@ -95,7 +122,7 @@ def delete_custom_model(model_name: str):
upload_training_data = S3CreateObjectOperator(
task_id="upload_data",
s3_bucket=bucket_name,
s3_key=training_data_uri,
s3_key=input_data_s3_key,
data=json.dumps(TRAIN_DATA),
)

Expand All @@ -115,30 +142,6 @@ def delete_custom_model(model_name: str):
)
# [END howto_operator_invoke_titan_model]

# [START howto_operator_customize_model]
customize_model = BedrockCustomizeModelOperator(
task_id="customize_model",
job_name=custom_model_job_name,
custom_model_name=custom_model_name,
role_arn=test_context[ROLE_ARN_KEY],
base_model_id=f"arn:aws:bedrock:us-east-1::foundation-model/{TITAN_MODEL_ID}",
hyperparameters=HYPERPARAMETERS,
training_data_uri=training_data_uri,
output_data_uri=f"s3://{bucket_name}/myOutputData",
)
# [END howto_operator_customize_model]

# [START howto_sensor_customize_model]
await_custom_model_job = BedrockCustomizeModelCompletedSensor(
task_id="await_custom_model_job",
job_name=custom_model_job_name,
)
# [END howto_sensor_customize_model]

if SKIP_LONG_TASKS:
customize_model.trigger_rule = TriggerRule.ALL_SKIPPED
await_custom_model_job.trigger_rule = TriggerRule.ALL_SKIPPED

delete_bucket = S3DeleteBucketOperator(
task_id="delete_bucket",
trigger_rule=TriggerRule.ALL_DONE,
Expand All @@ -152,12 +155,9 @@ def delete_custom_model(model_name: str):
create_bucket,
upload_training_data,
# TEST BODY
invoke_llama_model,
invoke_titan_model,
customize_model,
await_custom_model_job,
[invoke_llama_model, invoke_titan_model],
customize_model_workflow(),
# TEST TEARDOWN
delete_custom_model(custom_model_name),
delete_bucket,
)

Expand Down

0 comments on commit b6ff085

Please sign in to comment.