From 0d052e848f2a2f574d33319fed574cde0b0e4e38 Mon Sep 17 00:00:00 2001 From: Oleg Avdeev Date: Fri, 20 Aug 2021 08:40:55 -0700 Subject: [PATCH] allow to mount host volumes in AWS Batch (#640) --- metaflow/plugins/aws/batch/batch.py | 11 +++++---- metaflow/plugins/aws/batch/batch_cli.py | 5 +++- metaflow/plugins/aws/batch/batch_client.py | 24 +++++++++++++++---- metaflow/plugins/aws/batch/batch_decorator.py | 3 ++- .../aws/step_functions/step_functions.py | 3 ++- 5 files changed, 35 insertions(+), 11 deletions(-) diff --git a/metaflow/plugins/aws/batch/batch.py b/metaflow/plugins/aws/batch/batch.py index 9d72b41b82..e887508c38 100644 --- a/metaflow/plugins/aws/batch/batch.py +++ b/metaflow/plugins/aws/batch/batch.py @@ -164,7 +164,8 @@ def create_job( max_swap=None, swappiness=None, env={}, - attrs={} + attrs={}, + host_volumes=None, ): job_name = self._job_name( attrs.get('metaflow.user'), @@ -186,7 +187,7 @@ def create_job( .execution_role(execution_role) \ .job_def(image, iam_role, queue, execution_role, shared_memory, - max_swap, swappiness) \ + max_swap, swappiness, host_volumes=host_volumes) \ .cpu(cpu) \ .gpu(gpu) \ .memory(memory) \ @@ -244,6 +245,7 @@ def launch_job( shared_memory=None, max_swap=None, swappiness=None, + host_volumes=None, env={}, attrs={}, ): @@ -272,8 +274,9 @@ def launch_job( shared_memory, max_swap, swappiness, - env, - attrs + env=env, + attrs=attrs, + host_volumes=host_volumes ) self.job = job.execute() diff --git a/metaflow/plugins/aws/batch/batch_cli.py b/metaflow/plugins/aws/batch/batch_cli.py index 0f80188a23..7c359f3b53 100644 --- a/metaflow/plugins/aws/batch/batch_cli.py +++ b/metaflow/plugins/aws/batch/batch_cli.py @@ -169,6 +169,7 @@ def kill(ctx, run_id, user, my_runs): @click.option("--swappiness", help="Swappiness requirement for AWS Batch.") #TODO: Maybe remove it altogether since it's not used here @click.option('--ubf-context', default=None, type=click.Choice([None])) +@click.option('--mount-host-volumes', default=None) @click.pass_context def step( ctx, @@ -187,6 +188,7 @@ def step( shared_memory=None, max_swap=None, swappiness=None, + host_volumes=None, **kwargs ): def echo(msg, stream='stderr', batch_id=None): @@ -294,7 +296,8 @@ def echo(msg, stream='stderr', batch_id=None): max_swap=max_swap, swappiness=swappiness, env=env, - attrs=attrs + attrs=attrs, + host_volumes=host_volumes, ) except Exception as e: print(e) diff --git a/metaflow/plugins/aws/batch/batch_client.py b/metaflow/plugins/aws/batch/batch_client.py index c4c1ed6533..8022980267 100644 --- a/metaflow/plugins/aws/batch/batch_client.py +++ b/metaflow/plugins/aws/batch/batch_client.py @@ -101,7 +101,8 @@ def _register_job_definition(self, execution_role, shared_memory, max_swap, - swappiness): + swappiness, + host_volumes): # identify platform from any compute environment associated with the # queue if AWS_SANDBOX_ENABLED: @@ -189,6 +190,19 @@ def _register_job_definition(self, job_definition['containerProperties'] \ ['linuxParameters']['maxSwap'] = int(max_swap) + if host_volumes: + volume_paths = host_volumes.split(',') + job_definition['containerProperties']['volumes'] = [] + job_definition['containerProperties']['mountPoints'] = [] + for host_path in volume_paths: + name = host_path.replace('/', '_') + job_definition['containerProperties']['volumes'].append( + {'name': name, 'host': {'sourcePath': host_path}} + ) + job_definition['containerProperties']['mountPoints'].append( + {"sourceVolume": name, "containerPath": host_path} + ) + # check if job definition already exists def_name = 'metaflow_%s' % \ hashlib.sha224(str(job_definition).encode('utf-8')).hexdigest() @@ -219,7 +233,8 @@ def job_def(self, execution_role, shared_memory, max_swap, - swappiness): + swappiness, + host_volumes): self.payload['jobDefinition'] = \ self._register_job_definition(image, iam_role, @@ -227,7 +242,8 @@ def job_def(self, execution_role, shared_memory, max_swap, - swappiness) + swappiness, + host_volumes) return self def job_name(self, job_name): @@ -629,4 +645,4 @@ def _fill_buf(self): events = self._get_events() for event in events: self._buf.append(event['message']) - self._pos = event['timestamp'] \ No newline at end of file + self._pos = event['timestamp'] diff --git a/metaflow/plugins/aws/batch/batch_decorator.py b/metaflow/plugins/aws/batch/batch_decorator.py index 1be46d1ed7..389ee490a2 100644 --- a/metaflow/plugins/aws/batch/batch_decorator.py +++ b/metaflow/plugins/aws/batch/batch_decorator.py @@ -99,7 +99,8 @@ def my_step(self): 'execution_role': ECS_FARGATE_EXECUTION_ROLE, 'shared_memory': None, 'max_swap': None, - 'swappiness': None + 'swappiness': None, + 'host_volumes': None, } package_url = None package_sha = None diff --git a/metaflow/plugins/aws/step_functions/step_functions.py b/metaflow/plugins/aws/step_functions/step_functions.py index 167cd96fce..0b489c30c3 100644 --- a/metaflow/plugins/aws/step_functions/step_functions.py +++ b/metaflow/plugins/aws/step_functions/step_functions.py @@ -613,7 +613,8 @@ def _batch(self, node): max_swap=resources['max_swap'], swappiness=resources['swappiness'], env=env, - attrs=attrs + attrs=attrs, + host_volumes=resources['host_volumes'], ) \ .attempts(total_retries + 1)