Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Faster data transfer in weather-mv rg. #315

Merged
merged 4 commits into from
Apr 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 5 additions & 9 deletions weather_mv/loader_pipeline/regrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,8 @@
import dask
import xarray as xr
import xarray_beam as xbeam
from apache_beam.io.filesystems import FileSystems
from apache_beam.io.gcp.gcsio import WRITE_CHUNK_SIZE

from .sinks import ToDataSink, open_local
from .sinks import ToDataSink, open_local, copy

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
Expand Down Expand Up @@ -240,16 +238,15 @@ def apply(self, uri: str):
return

with _metview_op():
logger.debug(f'Copying grib from {uri!r} to local disk.')
logger.info(f'Copying grib from {uri!r} to local disk.')

with open_local(uri) as local_grib:
# TODO(alxr): Figure out way to open fieldset in memory...
logger.debug(f'Regridding {uri!r}.')
logger.info(f'Regridding {uri!r}.')
fs = mv.bindings.Fieldset(path=local_grib)
fieldset = mv.regrid(data=fs, **self.regrid_kwargs)

with tempfile.NamedTemporaryFile() as src:
logger.debug(f'Writing {self.target_from(uri)!r} to local disk.')
logger.info(f'Writing {self.target_from(uri)!r} to local disk.')
if self.to_netcdf:
fieldset.to_dataset().to_netcdf(src.name)
else:
Expand All @@ -260,8 +257,7 @@ def apply(self, uri: str):
_clear_metview()

logger.info(f'Uploading {self.target_from(uri)!r}.')
with FileSystems().create(self.target_from(uri)) as dst:
shutil.copyfileobj(src, dst, WRITE_CHUNK_SIZE)
copy(src.name, self.target_from(uri))

def expand(self, paths):
if not self.zarr:
Expand Down
40 changes: 33 additions & 7 deletions weather_mv/loader_pipeline/sinks.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import os
import re
import shutil
import subprocess
import tempfile
import typing as t

Expand All @@ -30,8 +31,7 @@
import numpy as np
import rasterio
import xarray as xr
from apache_beam.io.filesystems import FileSystems
from apache_beam.io.gcp.gcsio import DEFAULT_READ_BUFFER_SIZE
from apache_beam.io.filesystem import CompressionTypes, FileSystem, CompressedFile, DEFAULT_READ_BUFFER_SIZE
from pyproj import Transformer

TIF_TRANSFORM_CRS_TO = "EPSG:4326"
Expand Down Expand Up @@ -338,15 +338,41 @@ def __open_dataset_file(filename: str,
False)


def copy(src: str, dst: str) -> None:
    """Copy data via `gcloud alpha storage` or `gsutil`.

    Tries each CLI tool in order and returns on the first success. Both the
    tool failing (non-zero exit) and the tool being absent from PATH cause a
    fallthrough to the next tool.

    Args:
        src: Source path or URI (e.g. a local file or a `gs://...` object).
        dst: Destination path or URI.

    Raises:
        EnvironmentError: If every copy command failed or none is installed.
            Carries the human-readable message and the list of underlying
            errors as its args.
    """
    errors: t.List[Exception] = []
    for cmd in ['gcloud alpha storage cp', 'gsutil cp']:
        try:
            # List-form argv (shell=False) avoids shell-injection via src/dst.
            subprocess.run(cmd.split() + [src, dst], check=True, capture_output=True)
            return
        except FileNotFoundError as e:
            # The CLI binary itself is not installed; try the next tool.
            errors.append(e)
        except subprocess.CalledProcessError as e:
            # The tool ran but the copy failed (e.g. bad path, auth error).
            errors.append(e)

    msg = f'Failed to copy file {src!r} to {dst!r}'
    # CalledProcessError carries the tool's stderr; other errors (missing
    # binary) are rendered via repr().
    err_msgs = ', '.join(
        repr(err.stderr.decode('utf-8'))
        if isinstance(err, subprocess.CalledProcessError) else repr(err)
        for err in errors
    )
    logger.error(f'{msg} due to {err_msgs}.')
    raise EnvironmentError(msg, errors)


@contextlib.contextmanager
def open_local(uri: str) -> t.Iterator[str]:
    """Copy a cloud object (e.g. a netcdf, grib, or tif file) from cloud storage, like GCS, to local file.

    Yields the path to a local temporary copy of `uri`. If the URI's name
    indicates a compression type recognized by Beam, the data is decompressed
    into a second temporary file and that path is yielded instead. All
    temporary files are removed when the context exits.
    """
    with tempfile.NamedTemporaryFile() as dest_file:
        # Transfer data with gsutil or gcloud alpha storage (when available)
        copy(uri, dest_file.name)
        # NOTE(review): `copy` shells out to a CLI that writes to dest_file's
        # *path*. If the tool replaces the file (write-then-rename), this
        # already-open handle may still reference the original inode — TODO
        # confirm behavior on the target platforms.

        # Check if data is compressed. Decompress the data using the same methods that beam's
        # FileSystems interface uses.
        # NOTE(review): `_get_compression_type` is a private Beam API; it infers
        # the type from the URI's extension. May break on Beam upgrades.
        compression_type = FileSystem._get_compression_type(uri, CompressionTypes.AUTO)
        if compression_type == CompressionTypes.UNCOMPRESSED:
            yield dest_file.name
            return

        # Rewind the handle so CompressedFile reads from the start.
        dest_file.seek(0)
        with tempfile.NamedTemporaryFile() as dest_uncompressed:
            with CompressedFile(dest_file, compression_type=compression_type) as dcomp:
                shutil.copyfileobj(dcomp, dest_uncompressed, DEFAULT_READ_BUFFER_SIZE)
            yield dest_uncompressed.name


@contextlib.contextmanager
Expand Down