feat(ingestion/SageMaker): Remove deprecated apis and add stateful ingestion capability (datahub-project#10573)
TonyOuyangGit authored and sleeperdeep committed Jun 25, 2024
1 parent 8096904 commit 7869c0d
Showing 6 changed files with 112 additions and 362 deletions.
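To see what the new capability looks like from the configuration side, here is a minimal sketch of building the updated SagemakerSourceConfig with stateful ingestion enabled. The aws_region value and the stateful_ingestion settings are illustrative assumptions; only the fields that appear in the diff below are guaranteed by this commit.

from datahub.ingestion.source.aws.sagemaker_processors.common import (
    SagemakerSourceConfig,
)

# Illustrative values only: aws_region and the stateful_ingestion block are placeholders.
config = SagemakerSourceConfig.parse_obj(
    {
        "aws_region": "us-west-2",
        "extract_feature_groups": True,
        "extract_jobs": True,
        # New in this commit: opt into stale-entity removal between runs.
        "stateful_ingestion": {"enabled": True, "remove_stale_metadata": True},
    }
)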
22 changes: 18 additions & 4 deletions metadata-ingestion/src/datahub/ingestion/source/aws/sagemaker.py
@@ -1,5 +1,5 @@
from collections import defaultdict
from typing import DefaultDict, Dict, Iterable
from typing import DefaultDict, Dict, Iterable, List, Optional

from datahub.ingestion.api.common import PipelineContext
from datahub.ingestion.api.decorators import (
@@ -10,7 +10,7 @@
platform_name,
support_status,
)
from datahub.ingestion.api.source import Source
from datahub.ingestion.api.source import MetadataWorkUnitProcessor
from datahub.ingestion.api.workunit import MetadataWorkUnit
from datahub.ingestion.source.aws.sagemaker_processors.common import (
SagemakerSourceConfig,
@@ -26,13 +26,19 @@
)
from datahub.ingestion.source.aws.sagemaker_processors.lineage import LineageProcessor
from datahub.ingestion.source.aws.sagemaker_processors.models import ModelProcessor
from datahub.ingestion.source.state.stale_entity_removal_handler import (
StaleEntityRemovalHandler,
)
from datahub.ingestion.source.state.stateful_ingestion_base import (
StatefulIngestionSourceBase,
)


@platform_name("SageMaker")
@config_class(SagemakerSourceConfig)
@support_status(SupportStatus.CERTIFIED)
@capability(SourceCapability.LINEAGE_COARSE, "Enabled by default")
class SagemakerSource(Source):
class SagemakerSource(StatefulIngestionSourceBase):
"""
This plugin extracts the following:
@@ -45,7 +51,7 @@ class SagemakerSource(Source):
report = SagemakerSourceReport()

def __init__(self, config: SagemakerSourceConfig, ctx: PipelineContext):
super().__init__(ctx)
super().__init__(config, ctx)
self.source_config = config
self.report = SagemakerSourceReport()
self.sagemaker_client = config.sagemaker_client
@@ -56,6 +62,14 @@ def create(cls, config_dict, ctx):
config = SagemakerSourceConfig.parse_obj(config_dict)
return cls(config, ctx)

def get_workunit_processors(self) -> List[Optional[MetadataWorkUnitProcessor]]:
return [
*super().get_workunit_processors(),
StaleEntityRemovalHandler.create(
self, self.source_config, self.ctx
).workunit_processor,
]

def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]:
# get common lineage graph
lineage_processor = LineageProcessor(
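The get_workunit_processors override above is what plugs stale-entity removal into the run. A rough sketch of how the framework applies that list to the work-unit stream is below; the helper name apply_processors is not part of DataHub and only illustrates the composition.

from typing import Iterable, List, Optional

from datahub.ingestion.api.source import MetadataWorkUnitProcessor
from datahub.ingestion.api.workunit import MetadataWorkUnit


def apply_processors(
    workunits: Iterable[MetadataWorkUnit],
    processors: List[Optional[MetadataWorkUnitProcessor]],
) -> Iterable[MetadataWorkUnit]:
    # Each processor wraps the stream; the stale-entity handler records the URNs it
    # sees and, at the end of the run, can emit removal work units for entities
    # present in the last committed state but missing from this run.
    for processor in processors:
        if processor is not None:
            workunits = processor(workunits)
    return workunits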
metadata-ingestion/src/datahub/ingestion/source/aws/sagemaker_processors/common.py
@@ -1,13 +1,22 @@
from dataclasses import dataclass
from typing import Dict, Optional, Union
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Union

from pydantic.fields import Field

from datahub.ingestion.api.source import SourceReport
from datahub.configuration.source_common import PlatformInstanceConfigMixin
from datahub.ingestion.source.aws.aws_common import AwsSourceConfig
from datahub.ingestion.source.state.stale_entity_removal_handler import (
StaleEntityRemovalSourceReport,
StatefulIngestionConfigBase,
StatefulStaleMetadataRemovalConfig,
)


class SagemakerSourceConfig(AwsSourceConfig):
class SagemakerSourceConfig(
AwsSourceConfig,
PlatformInstanceConfigMixin,
StatefulIngestionConfigBase,
):
extract_feature_groups: Optional[bool] = Field(
default=True, description="Whether to extract feature groups."
)
@@ -17,21 +26,24 @@ class SagemakerSourceConfig(AwsSourceConfig):
extract_jobs: Optional[Union[Dict[str, str], bool]] = Field(
default=True, description="Whether to extract AutoML jobs."
)
# Custom Stateful Ingestion settings
stateful_ingestion: Optional[StatefulStaleMetadataRemovalConfig] = None

@property
def sagemaker_client(self):
return self.get_sagemaker_client()


@dataclass
class SagemakerSourceReport(SourceReport):
class SagemakerSourceReport(StaleEntityRemovalSourceReport):
feature_groups_scanned = 0
features_scanned = 0
endpoints_scanned = 0
groups_scanned = 0
models_scanned = 0
jobs_scanned = 0
datasets_scanned = 0
filtered: List[str] = field(default_factory=list)

def report_feature_group_scanned(self) -> None:
self.feature_groups_scanned += 1
@@ -53,3 +65,6 @@ def report_job_scanned(self) -> None:

def report_dataset_scanned(self) -> None:
self.datasets_scanned += 1

def report_dropped(self, name: str) -> None:
self.filtered.append(name)
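A small illustration of the report counters and the new filtered list added above; the entity name is made up.

report = SagemakerSourceReport()
report.report_feature_group_scanned()
report.report_dropped("example-feature-group")  # hypothetical name of a skipped entity

assert report.feature_groups_scanned == 1
assert report.filtered == ["example-feature-group"]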
metadata-ingestion/src/datahub/ingestion/source/aws/sagemaker_processors/job_classes.py
@@ -1,4 +1,6 @@
from typing import Dict, Final
from typing import Dict

from typing_extensions import Final

from datahub.metadata.schema_classes import JobStatusClass

@@ -63,8 +65,9 @@ class AutoMlJobInfo(SageMakerJobInfo):
list_key: Final = "AutoMLJobSummaries"
list_name_key: Final = "AutoMLJobName"
list_arn_key: Final = "AutoMLJobArn"
# https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker.html#SageMaker.Client.describe_auto_ml_job
describe_command: Final = "describe_auto_ml_job"
# DescribeAutoMLJobV2 are new versions of CreateAutoMLJob and DescribeAutoMLJob which offer backward compatibility.
# https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_DescribeAutoMLJobV2.html
describe_command: Final = "describe_auto_ml_job_v2"
describe_name_key: Final = "AutoMLJobName"
describe_arn_key: Final = "AutoMLJobArn"
describe_status_key: Final = "AutoMLJobStatus"
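For reference, the boto3 call the new describe_command points at looks roughly like this; the region and job name are placeholders.

import boto3

client = boto3.client("sagemaker", region_name="us-west-2")  # placeholder region
response = client.describe_auto_ml_job_v2(AutoMLJobName="example-automl-job")  # placeholder name

print(response["AutoMLJobStatus"])
print(response.get("AutoMLJobInputDataConfig", []))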
@@ -101,28 +104,6 @@ class CompilationJobInfo(SageMakerJobInfo):
processor = "process_compilation_job"


class EdgePackagingJobInfo(SageMakerJobInfo):
# https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker.html#SageMaker.Client.list_edge_packaging_jobs
list_command: Final = "list_edge_packaging_jobs"
list_key: Final = "EdgePackagingJobSummaries"
list_name_key: Final = "EdgePackagingJobName"
list_arn_key: Final = "EdgePackagingJobArn"
# https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker.html#SageMaker.Client.describe_edge_packaging_job
describe_command: Final = "describe_edge_packaging_job"
describe_name_key: Final = "EdgePackagingJobName"
describe_arn_key: Final = "EdgePackagingJobArn"
describe_status_key: Final = "EdgePackagingJobStatus"
status_map = {
"INPROGRESS": JobStatusClass.IN_PROGRESS,
"COMPLETED": JobStatusClass.COMPLETED,
"FAILED": JobStatusClass.FAILED,
"STARTING": JobStatusClass.STARTING,
"STOPPING": JobStatusClass.STOPPING,
"STOPPED": JobStatusClass.STOPPED,
}
processor = "process_edge_packaging_job"


class HyperParameterTuningJobInfo(SageMakerJobInfo):
# https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker.html#SageMaker.Client.list_hyper_parameter_tuning_jobs
list_command: Final = "list_hyper_parameter_tuning_jobs"
metadata-ingestion/src/datahub/ingestion/source/aws/sagemaker_processors/jobs.py
@@ -26,7 +26,6 @@
from datahub.ingestion.source.aws.sagemaker_processors.job_classes import (
AutoMlJobInfo,
CompilationJobInfo,
EdgePackagingJobInfo,
HyperParameterTuningJobInfo,
LabelingJobInfo,
ProcessingJobInfo,
@@ -53,7 +52,6 @@
"JobInfo",
AutoMlJobInfo,
CompilationJobInfo,
EdgePackagingJobInfo,
HyperParameterTuningJobInfo,
LabelingJobInfo,
ProcessingJobInfo,
@@ -65,7 +63,6 @@
class JobType(Enum):
AUTO_ML = "auto_ml"
COMPILATION = "compilation"
EDGE_PACKAGING = "edge_packaging"
HYPER_PARAMETER_TUNING = "hyper_parameter_tuning"
LABELING = "labeling"
PROCESSING = "processing"
@@ -78,7 +75,6 @@ class JobType(Enum):
job_type_to_info: Mapping[JobType, Any] = {
JobType.AUTO_ML: AutoMlJobInfo(),
JobType.COMPILATION: CompilationJobInfo(),
JobType.EDGE_PACKAGING: EdgePackagingJobInfo(),
JobType.HYPER_PARAMETER_TUNING: HyperParameterTuningJobInfo(),
JobType.LABELING: LabelingJobInfo(),
JobType.PROCESSING: ProcessingJobInfo(),
@@ -416,23 +412,20 @@ def process_auto_ml_job(self, job: Dict[str, Any]) -> SageMakerJob:
"""
Process outputs from Boto3 describe_auto_ml_job()
See https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker.html#SageMaker.Client.describe_auto_ml_job
See https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker/client/describe_auto_ml_job_v2.html
"""

JOB_TYPE = JobType.AUTO_ML

input_datasets = {}

for input_config in job.get("InputDataConfig", []):
for input_config in job.get("AutoMLJobInputDataConfig", []):
input_data = input_config.get("DataSource", {}).get("S3DataSource")

if input_data is not None and "S3Uri" in input_data:
input_datasets[make_s3_urn(input_data["S3Uri"], self.env)] = {
"dataset_type": "s3",
"uri": input_data["S3Uri"],
"datatype": input_data.get("S3DataType"),
}

output_datasets = {}

output_s3_path = job.get("OutputDataConfig", {}).get("S3OutputPath")
@@ -448,6 +441,18 @@ def process_auto_ml_job(self, job: Dict[str, Any]) -> SageMakerJob:
JOB_TYPE,
)

metrics: Dict[str, Any] = {}
# Get job metrics from CandidateMetrics
candidate_metrics = (
job.get("BestCandidate", {})
.get("CandidateProperties", {})
.get("CandidateMetrics", [])
)
if candidate_metrics:
metrics = {
metric["MetricName"]: metric["Value"] for metric in candidate_metrics
}

model_containers = job.get("BestCandidate", {}).get("InferenceContainers", [])

for model_container in model_containers:
Expand All @@ -456,7 +461,7 @@ def process_auto_ml_job(self, job: Dict[str, Any]) -> SageMakerJob:
if model_data_url is not None:
job_key = JobKey(job_snapshot.urn, JobDirection.TRAINING)

self.update_model_image_jobs(model_data_url, job_key)
self.update_model_image_jobs(model_data_url, job_key, metrics=metrics)

return SageMakerJob(
job_name=job_name,
@@ -515,83 +520,6 @@ def process_compilation_job(self, job: Dict[str, Any]) -> SageMakerJob:
output_datasets=output_datasets,
)

def process_edge_packaging_job(
self,
job: Dict[str, Any],
) -> SageMakerJob:
"""
Process outputs from Boto3 describe_edge_packaging_job()
See https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker.html#SageMaker.Client.describe_edge_packaging_job
"""

JOB_TYPE = JobType.EDGE_PACKAGING

name: str = job["EdgePackagingJobName"]
arn: str = job["EdgePackagingJobArn"]

output_datasets = {}

model_artifact_s3_uri: Optional[str] = job.get("ModelArtifact")
output_s3_uri: Optional[str] = job.get("OutputConfig", {}).get(
"S3OutputLocation"
)

if model_artifact_s3_uri is not None:
output_datasets[make_s3_urn(model_artifact_s3_uri, self.env)] = {
"dataset_type": "s3",
"uri": model_artifact_s3_uri,
}

if output_s3_uri is not None:
output_datasets[make_s3_urn(output_s3_uri, self.env)] = {
"dataset_type": "s3",
"uri": output_s3_uri,
}

# from docs: "The name of the SageMaker Neo compilation job that is used to locate model artifacts that are being packaged."
compilation_job_name: Optional[str] = job.get("CompilationJobName")

output_jobs = set()
if compilation_job_name is not None:
# globally unique job name
full_job_name = ("compilation", compilation_job_name)

if full_job_name in self.name_to_arn:
output_jobs.add(
make_sagemaker_job_urn(
"compilation",
compilation_job_name,
self.name_to_arn[full_job_name],
self.env,
)
)
else:
self.report.report_warning(
name,
f"Unable to find ARN for compilation job {compilation_job_name} produced by edge packaging job {arn}",
)

job_snapshot, job_name, job_arn = self.create_common_job_snapshot(
job,
JOB_TYPE,
f"https://{self.aws_region}.console.aws.amazon.com/sagemaker/home?region={self.aws_region}#/edge-packaging-jobs/{job['EdgePackagingJobName']}",
)

if job.get("ModelName") is not None:
job_key = JobKey(job_snapshot.urn, JobDirection.DOWNSTREAM)

self.update_model_name_jobs(job["ModelName"], job_key)

return SageMakerJob(
job_name=job_name,
job_arn=job_arn,
job_type=JOB_TYPE,
job_snapshot=job_snapshot,
output_datasets=output_datasets,
output_jobs=output_jobs,
)

def process_hyper_parameter_tuning_job(
self,
job: Dict[str, Any],
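The CandidateMetrics handling added to process_auto_ml_job above assumes the DescribeAutoMLJobV2 response shape; a small, self-contained illustration with invented values:

# Invented response fragment mirroring BestCandidate.CandidateProperties.CandidateMetrics.
best_candidate = {
    "CandidateProperties": {
        "CandidateMetrics": [
            {"MetricName": "F1", "Value": 0.92, "Set": "Validation"},
            {"MetricName": "AUC", "Value": 0.97, "Set": "Validation"},
        ]
    }
}

metrics = {
    metric["MetricName"]: metric["Value"]
    for metric in best_candidate["CandidateProperties"]["CandidateMetrics"]
}
# metrics == {"F1": 0.92, "AUC": 0.97}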
(The diffs for the remaining changed files were not loaded.)
