From b6ff085679c283cd3ccc3edf20dd3e6b0eaec967 Mon Sep 17 00:00:00 2001
From: "D. Ferruzzi"
Date: Wed, 10 Apr 2024 10:40:34 -0700
Subject: [PATCH] Amazon Bedrock - Fix system test (#38887)

---
 .../providers/amazon/aws/example_bedrock.py   | 92 +++++++++----------
 1 file changed, 46 insertions(+), 46 deletions(-)

diff --git a/tests/system/providers/amazon/aws/example_bedrock.py b/tests/system/providers/amazon/aws/example_bedrock.py
index 12e24615473..e25bbb8ed77 100644
--- a/tests/system/providers/amazon/aws/example_bedrock.py
+++ b/tests/system/providers/amazon/aws/example_bedrock.py
@@ -18,12 +18,12 @@
 import json
 from datetime import datetime
+from os import environ
 
-from botocore.exceptions import ClientError
-
-from airflow.decorators import task
+from airflow.decorators import task, task_group
 from airflow.models.baseoperator import chain
 from airflow.models.dag import DAG
+from airflow.operators.empty import EmptyOperator
 from airflow.providers.amazon.aws.hooks.bedrock import BedrockHook
 from airflow.providers.amazon.aws.operators.bedrock import (
     BedrockCustomizeModelOperator,
@@ -35,6 +35,7 @@
     S3DeleteBucketOperator,
 )
 from airflow.providers.amazon.aws.sensors.bedrock import BedrockCustomizeModelCompletedSensor
+from airflow.utils.edgemodifier import Label
 from airflow.utils.trigger_rule import TriggerRule
 from tests.system.providers.amazon.aws.utils import SystemTestContextBuilder
 
@@ -44,10 +45,10 @@
 
 DAG_ID = "example_bedrock"
 
-# Creating a custom model takes nearly two hours. If SKIP_LONG_TASKS is True then set
-# the trigger rule to an improbable state. This way we can still have the code snippets
-# for docs, and we can manually run the full tests occasionally.
-SKIP_LONG_TASKS = True
+# Creating a custom model takes nearly two hours. If SKIP_LONG_TASKS
+# is True then these tasks will be skipped. This way we can still have
+# the code snippets for docs, and we can manually run the full tests.
+SKIP_LONG_TASKS = environ.get("SKIP_LONG_SYSTEM_TEST_TASKS", default=True)
 
 LLAMA_MODEL_ID = "meta.llama2-13b-chat-v1"
 PROMPT = "What color is an orange?"
@@ -61,15 +62,41 @@
 }
 
 
-@task
-def delete_custom_model(model_name: str):
-    try:
-        BedrockHook().conn.delete_custom_model(modelIdentifier=model_name)
-    except ClientError as e:
-        if SKIP_LONG_TASKS and (e.response["Error"]["Code"] == "ValidationException"):
-            # There is no model to delete. Since we skipped making one, that's fine.
-            return
-        raise e
+@task_group
+def customize_model_workflow():
+    # [START howto_operator_customize_model]
+    customize_model = BedrockCustomizeModelOperator(
+        task_id="customize_model",
+        job_name=custom_model_job_name,
+        custom_model_name=custom_model_name,
+        role_arn=test_context[ROLE_ARN_KEY],
+        base_model_id=f"arn:aws:bedrock:us-east-1::foundation-model/{TITAN_MODEL_ID}",
+        hyperparameters=HYPERPARAMETERS,
+        training_data_uri=training_data_uri,
+        output_data_uri=f"s3://{bucket_name}/myOutputData",
+    )
+    # [END howto_operator_customize_model]
+
+    # [START howto_sensor_customize_model]
+    await_custom_model_job = BedrockCustomizeModelCompletedSensor(
+        task_id="await_custom_model_job",
+        job_name=custom_model_job_name,
+    )
+    # [END howto_sensor_customize_model]
+
+    @task
+    def delete_custom_model():
+        BedrockHook().conn.delete_custom_model(modelIdentifier=custom_model_name)
+
+    @task.branch
+    def run_or_skip():
+        return end_workflow.task_id if SKIP_LONG_TASKS else customize_model.task_id
+
+    run_or_skip = run_or_skip()
+    end_workflow = EmptyOperator(task_id="end_workflow", trigger_rule=TriggerRule.NONE_FAILED_MIN_ONE_SUCCESS)
+
+    chain(run_or_skip, Label("Long-running tasks skipped"), end_workflow)
+    chain(run_or_skip, customize_model, await_custom_model_job, delete_custom_model(), end_workflow)
 
 
 with DAG(
@@ -95,7 +122,7 @@ def delete_custom_model(model_name: str):
     upload_training_data = S3CreateObjectOperator(
         task_id="upload_data",
         s3_bucket=bucket_name,
-        s3_key=training_data_uri,
+        s3_key=input_data_s3_key,
         data=json.dumps(TRAIN_DATA),
     )
 
@@ -115,30 +142,6 @@ def delete_custom_model(model_name: str):
     )
     # [END howto_operator_invoke_titan_model]
 
-    # [START howto_operator_customize_model]
-    customize_model = BedrockCustomizeModelOperator(
-        task_id="customize_model",
-        job_name=custom_model_job_name,
-        custom_model_name=custom_model_name,
-        role_arn=test_context[ROLE_ARN_KEY],
-        base_model_id=f"arn:aws:bedrock:us-east-1::foundation-model/{TITAN_MODEL_ID}",
-        hyperparameters=HYPERPARAMETERS,
-        training_data_uri=training_data_uri,
-        output_data_uri=f"s3://{bucket_name}/myOutputData",
-    )
-    # [END howto_operator_customize_model]
-
-    # [START howto_sensor_customize_model]
-    await_custom_model_job = BedrockCustomizeModelCompletedSensor(
-        task_id="await_custom_model_job",
-        job_name=custom_model_job_name,
-    )
-    # [END howto_sensor_customize_model]
-
-    if SKIP_LONG_TASKS:
-        customize_model.trigger_rule = TriggerRule.ALL_SKIPPED
-        await_custom_model_job.trigger_rule = TriggerRule.ALL_SKIPPED
-
     delete_bucket = S3DeleteBucketOperator(
         task_id="delete_bucket",
         trigger_rule=TriggerRule.ALL_DONE,
@@ -152,12 +155,9 @@ def delete_custom_model(model_name: str):
         create_bucket,
         upload_training_data,
         # TEST BODY
-        invoke_llama_model,
-        invoke_titan_model,
-        customize_model,
-        await_custom_model_job,
+        [invoke_llama_model, invoke_titan_model],
+        customize_model_workflow(),
         # TEST TEARDOWN
-        delete_custom_model(custom_model_name),
         delete_bucket,
    )
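
Below is a minimal standalone sketch (not part of the patch) of the branch-and-skip pattern the fix adopts: a @task.branch task decides at runtime whether to run the long path or jump straight to an EmptyOperator, whose NONE_FAILED_MIN_ONE_SUCCESS trigger rule lets the DAG finish cleanly whichever branch was taken. The dag_id, task names, and the stand-in long_running task are illustrative only.

from __future__ import annotations

from datetime import datetime
from os import environ

from airflow.decorators import task
from airflow.models.baseoperator import chain
from airflow.models.dag import DAG
from airflow.operators.empty import EmptyOperator
from airflow.utils.edgemodifier import Label
from airflow.utils.trigger_rule import TriggerRule

# Illustrative flag; in the real test this gates the ~2-hour customize-model tasks.
SKIP_LONG_TASKS = environ.get("SKIP_LONG_SYSTEM_TEST_TASKS", default=True)

with DAG(
    dag_id="branch_skip_sketch",  # hypothetical DAG id for this sketch
    schedule="@once",
    start_date=datetime(2024, 1, 1),
    catchup=False,
) as dag:
    # Stand-in for the long-running customize_model / await_custom_model_job chain.
    long_running = EmptyOperator(task_id="long_running")

    # Runs as long as no upstream failed and at least one upstream succeeded,
    # so it completes whether the long branch ran or was skipped.
    end_workflow = EmptyOperator(
        task_id="end_workflow",
        trigger_rule=TriggerRule.NONE_FAILED_MIN_ONE_SUCCESS,
    )

    @task.branch
    def run_or_skip():
        # Return the task_id to execute; every other downstream task is skipped.
        return end_workflow.task_id if SKIP_LONG_TASKS else long_running.task_id

    branch = run_or_skip()
    chain(branch, Label("Long-running tasks skipped"), end_workflow)
    chain(branch, long_running, end_workflow)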