Skip to content

Commit

Permalink
allow to mount host volumes in AWS Batch
Browse files Browse the repository at this point in the history
  • Loading branch information
oavdeev committed Aug 20, 2021
1 parent 9f832e6 commit 02fc1cd
Show file tree
Hide file tree
Showing 5 changed files with 35 additions and 11 deletions.
11 changes: 7 additions & 4 deletions metaflow/plugins/aws/batch/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,8 @@ def create_job(
max_swap=None,
swappiness=None,
env={},
attrs={}
attrs={},
host_volumes=None,
):
job_name = self._job_name(
attrs.get('metaflow.user'),
Expand All @@ -186,7 +187,7 @@ def create_job(
.execution_role(execution_role) \
.job_def(image, iam_role,
queue, execution_role, shared_memory,
max_swap, swappiness) \
max_swap, swappiness, host_volumes=host_volumes) \
.cpu(cpu) \
.gpu(gpu) \
.memory(memory) \
Expand Down Expand Up @@ -244,6 +245,7 @@ def launch_job(
shared_memory=None,
max_swap=None,
swappiness=None,
host_volumes=None,
env={},
attrs={},
):
Expand Down Expand Up @@ -272,8 +274,9 @@ def launch_job(
shared_memory,
max_swap,
swappiness,
env,
attrs
env=env,
attrs=attrs,
host_volumes=host_volumes
)
self.job = job.execute()

Expand Down
5 changes: 4 additions & 1 deletion metaflow/plugins/aws/batch/batch_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,7 @@ def kill(ctx, run_id, user, my_runs):
@click.option("--swappiness", help="Swappiness requirement for AWS Batch.")
#TODO: Maybe remove it altogether since it's not used here
@click.option('--ubf-context', default=None, type=click.Choice([None]))
@click.option('--mount-host-volumes', default=None)
@click.pass_context
def step(
ctx,
Expand All @@ -187,6 +188,7 @@ def step(
shared_memory=None,
max_swap=None,
swappiness=None,
host_volumes=None,
**kwargs
):
def echo(msg, stream='stderr', batch_id=None):
Expand Down Expand Up @@ -294,7 +296,8 @@ def echo(msg, stream='stderr', batch_id=None):
max_swap=max_swap,
swappiness=swappiness,
env=env,
attrs=attrs
attrs=attrs,
host_volumes=host_volumes,
)
except Exception as e:
print(e)
Expand Down
24 changes: 20 additions & 4 deletions metaflow/plugins/aws/batch/batch_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,8 @@ def _register_job_definition(self,
execution_role,
shared_memory,
max_swap,
swappiness):
swappiness,
host_volumes):
# identify platform from any compute environment associated with the
# queue
if AWS_SANDBOX_ENABLED:
Expand Down Expand Up @@ -189,6 +190,19 @@ def _register_job_definition(self,
job_definition['containerProperties'] \
['linuxParameters']['maxSwap'] = int(max_swap)

if host_volumes:
volume_paths = host_volumes.split(',')
job_definition['containerProperties']['volumes'] = []
job_definition['containerProperties']['mountPoints'] = []
for host_path in volume_paths:
name = host_path.replace('/', '_')
job_definition['containerProperties']['volumes'].append(
{'name': name, 'host': {'sourcePath': host_path}}
)
job_definition['containerProperties']['mountPoints'].append(
{"sourceVolume": name, "containerPath": host_path}
)

# check if job definition already exists
def_name = 'metaflow_%s' % \
hashlib.sha224(str(job_definition).encode('utf-8')).hexdigest()
Expand Down Expand Up @@ -219,15 +233,17 @@ def job_def(self,
execution_role,
shared_memory,
max_swap,
swappiness):
swappiness,
host_volumes):
self.payload['jobDefinition'] = \
self._register_job_definition(image,
iam_role,
job_queue,
execution_role,
shared_memory,
max_swap,
swappiness)
swappiness,
host_volumes)
return self

def job_name(self, job_name):
Expand Down Expand Up @@ -629,4 +645,4 @@ def _fill_buf(self):
events = self._get_events()
for event in events:
self._buf.append(event['message'])
self._pos = event['timestamp']
self._pos = event['timestamp']
3 changes: 2 additions & 1 deletion metaflow/plugins/aws/batch/batch_decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,8 @@ def my_step(self):
'execution_role': ECS_FARGATE_EXECUTION_ROLE,
'shared_memory': None,
'max_swap': None,
'swappiness': None
'swappiness': None,
'host_volumes': None,
}
package_url = None
package_sha = None
Expand Down
3 changes: 2 additions & 1 deletion metaflow/plugins/aws/step_functions/step_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -613,7 +613,8 @@ def _batch(self, node):
max_swap=resources['max_swap'],
swappiness=resources['swappiness'],
env=env,
attrs=attrs
attrs=attrs,
host_volumes=resources['host_volumes'],
) \
.attempts(total_retries + 1)

Expand Down

0 comments on commit 02fc1cd

Please sign in to comment.