Skip to content

Commit

Permalink
Support monthly and weekly data for preprocess module (#173)
Browse files Browse the repository at this point in the history
* add support for monthly data

* fix tests

* add support for weekly data and update notebook

* add tests for monthly and weekly data

* update tests

* test data resolution and given timescale

* make linter happy

* make isort happy

* Apply suggestions from code review

Co-authored-by: Bart Schilperoort <[email protected]>

* address comments and remove python 3.8 support

* use TypeAlias for timescale type

* fixed issue with conditional: if not max_days >= temporal_resolution >= min_days

* Add future annotation to enable typealias on python 3.9

* fix import

* spell out literal for timescale type

---------

Co-authored-by: Bart Schilperoort <[email protected]>
Co-authored-by: semvijverberg <[email protected]>
  • Loading branch information
3 people authored Aug 18, 2023
1 parent 77ec4a6 commit c060bc6
Show file tree
Hide file tree
Showing 5 changed files with 1,290 additions and 35 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ jobs:
fail-fast: false
matrix:
os: ['ubuntu-latest', 'macos-latest', 'windows-latest']
python-version: ['3.8', '3.9', '3.10']
python-version: ['3.9', '3.10']
steps:
- uses: actions/checkout@v3
- name: Set up Python ${{ matrix.python-version }}
Expand Down
1,170 changes: 1,154 additions & 16 deletions docs/notebooks/tutorial_preprocessing.ipynb

Large diffs are not rendered by default.

3 changes: 1 addition & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ dynamic = ["version"]
description = "python package for s2s forecasts with ai"
readme = "README.md"
license = "Apache-2.0"
requires-python = ">=3.8,<3.11"
# Python 3.8 is no longer supported. ">3.8" would still admit 3.8.x patch
# releases (3.8.1 > 3.8 under PEP 440), so state the lower bound explicitly.
requires-python = ">=3.9,<3.11"
authors = [
{email = "[email protected]"},
{name = "Yang Liu, Bart Schilperoort, Peter Kalverla, Jannes van Ingen, Sem Vijverberg"}
Expand All @@ -30,7 +30,6 @@ classifiers = [
"License :: OSI Approved :: Apache Software License",
"Natural Language :: English",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
]
Expand Down
102 changes: 91 additions & 11 deletions s2spy/preprocess.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
"""Preprocessor for s2spy workflow."""
import warnings
from typing import Literal
from typing import Tuple
from typing import Union
import numpy as np
Expand Down Expand Up @@ -61,9 +63,39 @@ def _subtract_trend(data: Union[xr.DataArray, xr.Dataset], method: str, trend: d
raise NotImplementedError


def _get_climatology(data: Union[xr.Dataset, xr.DataArray]):
def _get_climatology(
data: Union[xr.Dataset, xr.DataArray],
timescale: Literal["monthly", "weekly", "daily"],
):
"""Calculate the climatology of timeseries data."""
return data.groupby("time.dayofyear").mean("time")
_check_data_resolution_match(data, timescale)
if timescale == "monthly":
climatology = data.groupby("time.month").mean("time")
elif timescale == "weekly":
climatology = data.groupby(data["time"].dt.isocalendar().week).mean("time")
elif timescale == "daily":
climatology = data.groupby("time.dayofyear").mean("time")
else:
raise ValueError("Given timescale is not supported.")

return climatology


def _subtract_climatology(
    data: Union[xr.Dataset, xr.DataArray],
    timescale: Literal["monthly", "weekly", "daily"],
    climatology: Union[xr.Dataset, xr.DataArray],
):
    """Remove the climatology from the data, grouped by the given timescale.

    Args:
        data: Timeseries data with a "time" coordinate.
        timescale: Grouping used to align data and climatology: calendar month
            for "monthly", ISO calendar week for "weekly", day of year for
            "daily".
        climatology: Climatology to subtract from each group of the data.

    Returns:
        The data with the climatology subtracted group by group.

    Raises:
        ValueError: If timescale is not "monthly", "weekly" or "daily".
    """
    if timescale not in ("monthly", "weekly", "daily"):
        raise ValueError("Given timescale is not supported.")

    if timescale == "weekly":
        # There is no "time.<field>" shorthand used for ISO weeks here, so the
        # week numbers are passed to groupby as an explicit grouping variable.
        group_key = data["time"].dt.isocalendar().week
    elif timescale == "monthly":
        group_key = "time.month"
    else:
        group_key = "time.dayofyear"

    return data.groupby(group_key) - climatology


def _check_input_data(data: Union[xr.DataArray, xr.Dataset]):
Expand All @@ -88,12 +120,56 @@ def _check_input_data(data: Union[xr.DataArray, xr.Dataset]):
)


def _check_temporal_resolution(
    timescale: Literal["monthly", "weekly", "daily"]
) -> Literal["monthly", "weekly", "daily"]:
    """Validate the requested timescale.

    Args:
        timescale: Temporal resolution requested by the caller.

    Returns:
        The given timescale, unchanged, when it is supported.

    Raises:
        ValueError: If timescale is not "monthly", "weekly" or "daily".
    """
    support_temporal_resolution = ["monthly", "weekly", "daily"]
    if timescale not in support_temporal_resolution:
        # The two literals are concatenated, so the first one must end with a
        # space to keep the sentences apart in the rendered message.
        raise ValueError(
            "Given temporal resolution is not supported. "
            "Please choose from 'monthly', 'weekly', 'daily'."
        )
    return timescale


def _check_data_resolution_match(
    data: Union[xr.DataArray, xr.Dataset],
    timescale: Literal["monthly", "weekly", "daily"],
):
    """Warn if the time step of the input does not fit the given timescale.

    The time step is taken as the median spacing between consecutive "time"
    values, converted to whole days.

    Args:
        data: Timeseries data with a "time" coordinate.
        timescale: Temporal resolution the data is expected to have. A
            "monthly" step may be 28 to 31 days, "weekly" must be 7 days and
            "daily" must be 1 day. Other values are not checked.
    """
    spacing = np.diff(data["time"].to_numpy())
    median_step = np.median(spacing).astype("timedelta64[D]")

    if timescale == "monthly":
        # Calendar months differ in length, so a range of day counts is
        # accepted instead of one fixed step.
        n_days = median_step.astype(int)
        mismatch = not 28 <= n_days <= 31
    elif timescale == "weekly":
        mismatch = np.timedelta64(1, "W").astype("timedelta64[D]") != median_step
    elif timescale == "daily":
        mismatch = np.timedelta64(1, "D").astype("timedelta64[D]") != median_step
    else:
        # Unknown timescales are not validated by this function.
        mismatch = False

    if mismatch:
        warnings.warn(
            "The temporal resolution of data does not completely match "
            "the target timescale. Please check your input data.",
            stacklevel=1,
        )


class Preprocessor:
"""Preprocessor for s2s data."""

def __init__(
def __init__( # noqa: PLR0913
self,
rolling_window_size: Union[int, None],
timescale: Literal["monthly", "weekly", "daily"],
rolling_min_periods: int = 1,
subtract_climatology: bool = True,
detrend: Union[str, None] = "linear",
Expand Down Expand Up @@ -121,11 +197,14 @@ def __init__(
climatology of the data. Defaults to True.
detrend (optional): Which method to use for detrending. Currently the only method
supported is "linear". If you want to skip detrending, set this to None.
timescale: Temporal resolution of input data.
"""
self._window_size = rolling_window_size
self._min_periods = rolling_min_periods
self._detrend = detrend
self._subtract_climatology = subtract_climatology
if subtract_climatology:
self._timescale = _check_temporal_resolution(timescale)

self._climatology: Union[xr.DataArray, xr.Dataset]
self._trend: dict
Expand All @@ -149,15 +228,16 @@ def fit(self, data: Union[xr.DataArray, xr.Dataset]) -> None:
data_rolling = data

if self._subtract_climatology:
self._climatology = _get_climatology(data_rolling)
self._climatology = _get_climatology(data_rolling, self._timescale)

if self._detrend is not None:
self._trend = _get_trend(
data_rolling.groupby("time.dayofyear") - self._climatology
if self._subtract_climatology
else data_rolling,
self._detrend,
)
if self._subtract_climatology:
deseasonalized = _subtract_climatology(
data_rolling, self._timescale, self._climatology
)
self._trend = _get_trend(deseasonalized, self._detrend)
else:
self._trend = _get_trend(data_rolling, self._detrend)

self._is_fit = True

Expand All @@ -179,7 +259,7 @@ def transform(
)

if self._subtract_climatology:
d = data.groupby("time.dayofyear") - self._climatology
d = _subtract_climatology(data, self._timescale, self._climatology)
else:
d = data

Expand Down
48 changes: 43 additions & 5 deletions tests/test_preprocess.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
"""Tests for the s2spy.preprocess module.
"""
"""Tests for the s2spy.preprocess module."""
import numpy as np
import pytest
import scipy.signal
Expand Down Expand Up @@ -55,14 +54,47 @@ def test_get_and_subtract_linear_trend(self, raw_field):
result = preprocess._subtract_trend(raw_field, "linear", trend)
np.testing.assert_array_almost_equal(result["sst"], expected["sst"])

def test_get_climatology(self, raw_field):
result = preprocess._get_climatology(raw_field)
def test_check_temporal_resolution(self):
    """An unsupported timescale name must be rejected with a ValueError."""
    unsupported = "hourly"
    with pytest.raises(ValueError):
        preprocess._check_temporal_resolution(unsupported)  # type: ignore

def test_get_climatology_daily(self, raw_field):
    """Daily climatology must equal the mean of the 2010 and 2011 slices."""
    result = preprocess._get_climatology(raw_field, timescale="daily")

    sst = raw_field["sst"]
    year_2010 = sst.sel(time=slice("2010-01-01", "2010-12-31")).data
    year_2011 = sst.sel(time=slice("2011-01-01", "2011-12-31")).data
    expected = (year_2010 + year_2011) / 2

    np.testing.assert_array_almost_equal(result["sst"], expected)

def test_get_climatology_weekly(self, raw_field):
    """Weekly climatology must equal the per-ISO-week mean of weekly data."""
    weekly = raw_field.resample(time="W").mean()
    result = preprocess._get_climatology(weekly, timescale="weekly")

    # The expected climatology is keyed by the actual calendar week number,
    # so the time axis is replaced by ISO week numbers before grouping.
    iso_week = weekly["time"].dt.isocalendar().week
    weekly["time"] = iso_week
    expected = weekly.groupby("time").mean()

    np.testing.assert_array_almost_equal(result["sst"], expected["sst"])

def test_get_climatology_monthly(self, raw_field):
    """Monthly climatology must equal the mean of the 2010 and 2011 slices."""
    monthly = raw_field.resample(time="M").mean()
    result = preprocess._get_climatology(monthly, timescale="monthly")

    sst = monthly["sst"]
    year_2010 = sst.sel(time=slice("2010-01-01", "2010-12-31")).data
    year_2011 = sst.sel(time=slice("2011-01-01", "2011-12-31")).data
    expected = (year_2010 + year_2011) / 2

    np.testing.assert_array_almost_equal(result["sst"], expected)

def test_get_climatology_wrong_timescale(self, raw_field):
    """Requesting a climatology for an unsupported timescale must raise."""
    unsupported = "hourly"
    with pytest.raises(ValueError):
        preprocess._get_climatology(raw_field, timescale=unsupported)  # type: ignore

@pytest.mark.parametrize("timescale", ("weekly", "monthly"))
def test_check_data_resolution_mismatch(self, raw_field, timescale):
    """Checking the fixture against a weekly or monthly timescale must warn."""
    check = preprocess._check_data_resolution_match
    with pytest.warns(UserWarning):
        check(raw_field, timescale)

def test_check_data_resolution_match(self, raw_field):
    """Checking the fixture against the "daily" timescale must not warn."""
    import warnings

    with warnings.catch_warnings():
        # Escalate UserWarning to an error: without this the test passes even
        # when the resolution check emits its mismatch warning.
        warnings.simplefilter("error", UserWarning)
        preprocess._check_data_resolution_match(raw_field, "daily")


class TestPreprocessor:
"""Test preprocessor."""
Expand All @@ -71,6 +103,7 @@ class TestPreprocessor:
def preprocessor(self):
prep = preprocess.Preprocessor(
rolling_window_size=25,
timescale="daily",
detrend="linear",
subtract_climatology=True,
)
Expand All @@ -80,6 +113,7 @@ def preprocessor(self):
def preprocessor_no_rolling(self, request):
prep = preprocess.Preprocessor(
rolling_window_size=request.param,
timescale="monthly",
detrend=None,
subtract_climatology=True,
)
Expand All @@ -89,6 +123,7 @@ def preprocessor_no_rolling(self, request):
def preprocessor_no_climatology(self):
prep = preprocess.Preprocessor(
rolling_window_size=25,
timescale="daily",
detrend="linear",
subtract_climatology=False,
)
Expand All @@ -98,6 +133,7 @@ def preprocessor_no_climatology(self):
def preprocessor_no_detrend(self):
prep = preprocess.Preprocessor(
rolling_window_size=25,
timescale="daily",
detrend=None,
subtract_climatology=True,
)
Expand All @@ -116,6 +152,7 @@ def raw_field(self):
def test_init(self):
prep = preprocess.Preprocessor(
rolling_window_size=25,
timescale="weekly",
detrend="linear",
subtract_climatology=True,
)
Expand All @@ -130,7 +167,7 @@ def test_fit(self, preprocessor, raw_field):
def test_fit_no_rolling(self, preprocessor_no_rolling, raw_field):
preprocessor_no_rolling.fit(raw_field)
assert preprocessor_no_rolling.climatology == preprocess._get_climatology(
raw_field
raw_field, timescale="daily"
)

def test_transform(self, preprocessor, raw_field):
Expand Down Expand Up @@ -163,6 +200,7 @@ def test_fit_and_transform_no_detrend(self, preprocessor_no_detrend, raw_field):
def test_fit_and_transform_no_climatology_and_detrend(self, raw_field):
prep = preprocess.Preprocessor(
rolling_window_size=10,
timescale="daily",
detrend=None,
subtract_climatology=False,
)
Expand Down

0 comments on commit c060bc6

Please sign in to comment.