From 3667f9ecb6d8e3d9805f9462d68ed184fd1d9e44 Mon Sep 17 00:00:00 2001 From: Bohdan Bilonoh Date: Thu, 11 Apr 2024 22:49:05 +0300 Subject: [PATCH 1/4] Auto batch size for torch model --- .../forecasting/torch_forecasting_model.py | 132 ++++++++++++++++++ .../test_torch_forecasting_model.py | 16 +++ 2 files changed, 148 insertions(+) diff --git a/darts/models/forecasting/torch_forecasting_model.py b/darts/models/forecasting/torch_forecasting_model.py index 955e8fc2de..08942600eb 100644 --- a/darts/models/forecasting/torch_forecasting_model.py +++ b/darts/models/forecasting/torch_forecasting_model.py @@ -1206,6 +1206,138 @@ def lr_find( update_attr=False, ) + @random_method + def scale_batch_size( + self, + series: Union[TimeSeries, Sequence[TimeSeries]], + val_series: Union[TimeSeries, Sequence[TimeSeries]], + past_covariates: Optional[Union[TimeSeries, Sequence[TimeSeries]]] = None, + future_covariates: Optional[Union[TimeSeries, Sequence[TimeSeries]]] = None, + val_past_covariates: Optional[Union[TimeSeries, Sequence[TimeSeries]]] = None, + val_future_covariates: Optional[Union[TimeSeries, Sequence[TimeSeries]]] = None, + trainer: Optional[pl.Trainer] = None, + verbose: Optional[bool] = None, + epochs: int = 0, + max_samples_per_ts: Optional[int] = None, + num_loader_workers: int = 0, + method: Literal["fit", "validate", "test", "predict"] = "fit", + mode: str = "power", + steps_per_trial: int = 3, + init_val: int = 2, + max_trials: int = 25, + ) -> Optional[int]: + """ + A wrapper around PyTorch Lightning's `Tuner.scale_batch_size()`. Scales the batch size of the model to find the + largest batch size that can be used without running out of memory. For more information on PyTorch Lightning's + Tuner check out + `this link `_. + + Parameters + ---------- + series + A series or sequence of series serving as target (i.e. what the model will be trained to forecast) + past_covariates + Optionally, a series or sequence of series specifying past-observed covariates + future_covariates + Optionally, a series or sequence of series specifying future-known covariates + val_series + Optionally, one or a sequence of validation target series, which will be used to compute the validation + loss throughout training and keep track of the best performing models. + val_past_covariates + Optionally, the past covariates corresponding to the validation series (must match ``covariates``) + val_future_covariates + Optionally, the future covariates corresponding to the validation series (must match ``covariates``) + trainer + Optionally, a custom PyTorch-Lightning Trainer object to perform training. Using a custom ``trainer`` will + override Darts' default trainer. + verbose + Optionally, whether to print the progress. Ignored if there is a `ProgressBar` callback in + `pl_trainer_kwargs`. + epochs + If specified, will train the model for ``epochs`` (additional) epochs, irrespective of what ``n_epochs`` + was provided to the model constructor. + max_samples_per_ts + Optionally, a maximum number of samples to use per time series. Models are trained in a supervised fashion + by constructing slices of (input, output) examples. On long time series, this can result in unnecessarily + large number of training samples. This parameter upper-bounds the number of training samples per time + series (taking only the most recent samples in each series). Leaving to None does not apply any + upper bound. 
+ num_loader_workers + Optionally, an integer specifying the ``num_workers`` to use in PyTorch ``DataLoader`` instances, + both for the training and validation loaders (if any). + A larger number of workers can sometimes increase performance, but can also incur extra overheads + and increase memory usage, as more batches are loaded in parallel. + method + The method to use for scaling the batch size. Can be one of 'fit', 'validate', 'test', or 'predict'. + mode + The mode to use for scaling the batch size. Can be one of 'power' or 'linear'. + steps_per_trial + The number of steps to try for each trial. + init_val + The initial value to start the search with. + max_trials + The maximum number of trials to run. + batch_arg_name + The name of the argument to scale in the model. Defaults to 'batch_size'. + + Returns + ------- + int + The largest batch size that can be used without running out of memory. + """ + _, params = self._setup_for_fit_from_dataset( + series=series, + past_covariates=past_covariates, + future_covariates=future_covariates, + val_series=val_series, + val_past_covariates=val_past_covariates, + val_future_covariates=val_future_covariates, + trainer=trainer, + verbose=verbose, + epochs=epochs, + max_samples_per_ts=max_samples_per_ts, + num_loader_workers=num_loader_workers, + ) + trainer, model, train_loader, val_loader = self._setup_for_train(*params) + + class DataModule(pl.LightningDataModule): + def __init__(self, batch_size): + super().__init__() + self.save_hyperparameters() + self.batch_size = batch_size + + def train_dataloader(self): + return DataLoader( + train_loader.dataset, + batch_size=self.batch_size, + shuffle=True, + num_workers=train_loader.num_workers, + pin_memory=True, + drop_last=False, + collate_fn=train_loader.collate_fn, + ) + + def val_dataloader(self): + return DataLoader( + val_loader.dataset, + batch_size=self.batch_size, + shuffle=False, + num_workers=val_loader.num_workers, + pin_memory=True, + drop_last=False, + collate_fn=val_loader.collate_fn, + ) + + return Tuner(trainer).scale_batch_size( + model=model, + datamodule=DataModule(batch_size=init_val), + method=method, + mode=mode, + steps_per_trial=steps_per_trial, + init_val=init_val, + max_trials=max_trials, + ) + @random_method def predict( self, diff --git a/darts/tests/models/forecasting/test_torch_forecasting_model.py b/darts/tests/models/forecasting/test_torch_forecasting_model.py index a3cb7d6c9c..81a9f2368c 100644 --- a/darts/tests/models/forecasting/test_torch_forecasting_model.py +++ b/darts/tests/models/forecasting/test_torch_forecasting_model.py @@ -1402,6 +1402,22 @@ def test_lr_find(self): ) assert scores["worst"] > scores["suggested"] + @pytest.mark.slow + def test_scale_batch_size(self): + train_series, val_series = self.series[:-40], self.series[-40:] + model = RNNModel(12, "RNN", 10, 10, random_state=42, batch_size=1, **tfm_kwargs) + # find the batch size + init_batch_size = model.batch_size + batch_size = model.scale_batch_size( + series=train_series, + val_series=val_series, + epochs=50, + init_val=init_batch_size, + ) + assert isinstance(batch_size, int) + assert isinstance(batch_size, int) + assert batch_size != init_batch_size + def test_encoders(self, tmpdir_fn): series = tg.linear_timeseries(length=10) pc = tg.linear_timeseries(length=12) From 46d8a4238986807c3aad5af5e0be08057dde94d0 Mon Sep 17 00:00:00 2001 From: Bohdan Bilonoh Date: Thu, 11 Apr 2024 23:02:55 +0300 Subject: [PATCH 2/4] update CHANGELOG.md --- CHANGELOG.md | 2 ++ 1 file changed, 2 insertions(+) diff 
--git a/CHANGELOG.md b/CHANGELOG.md index 1cbe774182..f1c2854b5a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -117,6 +117,8 @@ but cannot always guarantee backwards compatibility. Changes that may **break co - 🔴 Moved around utils functions to clearly separate Darts-specific from non-Darts-specific logic, [#2284](https://github.com/unit8co/darts/pull/2284) by [Dennis Bader](https://github.com/dennisbader): - Moved function `generate_index()` from `darts.utils.timeseries_generation` to `darts.utils.utils` - Moved functions `retain_period_common_to_all()`, `series2seq()`, `seq2series()`, `get_single_series()` from `darts.utils.utils` to `darts.utils.ts_utils`. +- Improvements to `TorchForecastingModel`: + - New method `TorchForecastingModel.scale_batch_size()` to find the maximum batch size for fit and predict before memory would run out. [#2318](https://github.com/unit8co/darts/pull/2318) by [Bohdan Bilonoh](https://github.com/BohdanBilonoh) **Fixed** - Fixed the order of the features when using component-specific lags so that they are grouped by values, then by components (before, they were grouped by components, then by values). [#2272](https://github.com/unit8co/darts/pull/2272) by [Antoine Madrona](https://github.com/madtoinou). From 932514e49ae4c14d8026ea6414906df66589cdb3 Mon Sep 17 00:00:00 2001 From: Bohdan Bilonoh Date: Thu, 11 Apr 2024 23:05:14 +0300 Subject: [PATCH 3/4] remove code smell --- darts/tests/models/forecasting/test_torch_forecasting_model.py | 1 - 1 file changed, 1 deletion(-) diff --git a/darts/tests/models/forecasting/test_torch_forecasting_model.py b/darts/tests/models/forecasting/test_torch_forecasting_model.py index 81a9f2368c..835b0e8ce1 100644 --- a/darts/tests/models/forecasting/test_torch_forecasting_model.py +++ b/darts/tests/models/forecasting/test_torch_forecasting_model.py @@ -1415,7 +1415,6 @@ def test_scale_batch_size(self): init_val=init_batch_size, ) assert isinstance(batch_size, int) - assert isinstance(batch_size, int) assert batch_size != init_batch_size def test_encoders(self, tmpdir_fn): From ae7a128847ea446bfea4c2917e98a685df5e104f Mon Sep 17 00:00:00 2001 From: Bohdan Bilonoh Date: Tue, 28 May 2024 14:30:20 +0200 Subject: [PATCH 4/4] WIP: stuck with `model.set_predict_parameters` --- .../forecasting/torch_forecasting_model.py | 227 ++++++++++++------ .../test_torch_forecasting_model.py | 15 +- 2 files changed, 163 insertions(+), 79 deletions(-) diff --git a/darts/models/forecasting/torch_forecasting_model.py b/darts/models/forecasting/torch_forecasting_model.py index 08942600eb..65ddfe8a43 100644 --- a/darts/models/forecasting/torch_forecasting_model.py +++ b/darts/models/forecasting/torch_forecasting_model.py @@ -41,7 +41,7 @@ import torch from pytorch_lightning import loggers as pl_loggers from torch import Tensor -from torch.utils.data import DataLoader +from torch.utils.data import DataLoader, Dataset from darts.dataprocessing.encoders import SequentialEncoder from darts.logging import ( @@ -996,28 +996,20 @@ def _setup_for_train( # Setting drop_last to False makes the model see each sample at least once, and guarantee the presence of at # least one batch no matter the chosen batch size - train_loader = DataLoader( - train_dataset, - batch_size=self.batch_size, - shuffle=True, - num_workers=num_loader_workers, - pin_memory=True, - drop_last=False, - collate_fn=self._batch_collate_fn, + train_loader = self._build_dataloader( + split="train", + dataset=train_dataset, + num_loader_workers=num_loader_workers, ) # Prepare validation data 
val_loader = ( None if val_dataset is None - else DataLoader( - val_dataset, - batch_size=self.batch_size, - shuffle=False, - num_workers=num_loader_workers, - pin_memory=True, - drop_last=False, - collate_fn=self._batch_collate_fn, + else self._build_dataloader( + split="val", + dataset=val_dataset, + num_loader_workers=num_loader_workers, ) ) @@ -1210,17 +1202,17 @@ def lr_find( def scale_batch_size( self, series: Union[TimeSeries, Sequence[TimeSeries]], - val_series: Union[TimeSeries, Sequence[TimeSeries]], + n: int = 1, + n_jobs: int = 1, + roll_size: Optional[int] = None, + num_samples: int = 1, + mc_dropout: bool = False, + predict_likelihood_parameters: bool = False, past_covariates: Optional[Union[TimeSeries, Sequence[TimeSeries]]] = None, future_covariates: Optional[Union[TimeSeries, Sequence[TimeSeries]]] = None, - val_past_covariates: Optional[Union[TimeSeries, Sequence[TimeSeries]]] = None, - val_future_covariates: Optional[Union[TimeSeries, Sequence[TimeSeries]]] = None, trainer: Optional[pl.Trainer] = None, verbose: Optional[bool] = None, - epochs: int = 0, - max_samples_per_ts: Optional[int] = None, - num_loader_workers: int = 0, - method: Literal["fit", "validate", "test", "predict"] = "fit", + method: Literal["fit", "predict"] = "fit", mode: str = "power", steps_per_trial: int = 3, init_val: int = 2, @@ -1236,37 +1228,19 @@ def scale_batch_size( ---------- series A series or sequence of series serving as target (i.e. what the model will be trained to forecast) + n + The number of time steps after the end of the training time series for which to produce predictions. + Only for the `predict` method. past_covariates Optionally, a series or sequence of series specifying past-observed covariates future_covariates Optionally, a series or sequence of series specifying future-known covariates - val_series - Optionally, one or a sequence of validation target series, which will be used to compute the validation - loss throughout training and keep track of the best performing models. - val_past_covariates - Optionally, the past covariates corresponding to the validation series (must match ``covariates``) - val_future_covariates - Optionally, the future covariates corresponding to the validation series (must match ``covariates``) trainer Optionally, a custom PyTorch-Lightning Trainer object to perform training. Using a custom ``trainer`` will override Darts' default trainer. verbose Optionally, whether to print the progress. Ignored if there is a `ProgressBar` callback in `pl_trainer_kwargs`. - epochs - If specified, will train the model for ``epochs`` (additional) epochs, irrespective of what ``n_epochs`` - was provided to the model constructor. - max_samples_per_ts - Optionally, a maximum number of samples to use per time series. Models are trained in a supervised fashion - by constructing slices of (input, output) examples. On long time series, this can result in unnecessarily - large number of training samples. This parameter upper-bounds the number of training samples per time - series (taking only the most recent samples in each series). Leaving to None does not apply any - upper bound. - num_loader_workers - Optionally, an integer specifying the ``num_workers`` to use in PyTorch ``DataLoader`` instances, - both for the training and validation loaders (if any). - A larger number of workers can sometimes increase performance, but can also incur extra overheads - and increase memory usage, as more batches are loaded in parallel. method The method to use for scaling the batch size. 
Can be one of 'fit', 'validate', 'test', or 'predict'. mode @@ -1277,8 +1251,6 @@ def scale_batch_size( The initial value to start the search with. max_trials The maximum number of trials to run. - batch_arg_name - The name of the argument to scale in the model. Defaults to 'batch_size'. Returns ------- @@ -1289,43 +1261,93 @@ def scale_batch_size( series=series, past_covariates=past_covariates, future_covariates=future_covariates, - val_series=val_series, - val_past_covariates=val_past_covariates, - val_future_covariates=val_future_covariates, + val_series=series, + val_past_covariates=past_covariates, + val_future_covariates=future_covariates, trainer=trainer, verbose=verbose, - epochs=epochs, - max_samples_per_ts=max_samples_per_ts, - num_loader_workers=num_loader_workers, ) trainer, model, train_loader, val_loader = self._setup_for_train(*params) + if method == "predict": + if roll_size is None: + roll_size = self.output_chunk_length + else: + raise_if_not( + 0 < roll_size <= self.output_chunk_length, + "`roll_size` must be an integer between 1 and `self.output_chunk_length`.", + ) + predict_dataset = self._build_inference_dataset( + target=series, + n=n, + past_covariates=past_covariates, + future_covariates=future_covariates, + stride=0, + bounds=None, + ) + model.set_predict_parameters( + n=n, + num_samples=num_samples, + roll_size=roll_size, + batch_size=1, + n_jobs=n_jobs, + predict_likelihood_parameters=predict_likelihood_parameters, + mc_dropout=mc_dropout, + ) + + build_dataloader = self._build_dataloader + class DataModule(pl.LightningDataModule): def __init__(self, batch_size): super().__init__() self.save_hyperparameters() - self.batch_size = batch_size + self._batch_size = batch_size + + @property + def batch_size(self): + return self._batch_size + + @batch_size.setter + def batch_size(self, batch_size): + model.set_predict_parameters( + n=n, + num_samples=num_samples, + roll_size=roll_size, + batch_size=batch_size, + n_jobs=n_jobs, + predict_likelihood_parameters=predict_likelihood_parameters, + mc_dropout=mc_dropout, + ) + self._batch_size = batch_size def train_dataloader(self): - return DataLoader( - train_loader.dataset, + return build_dataloader( + split="train", + dataset=train_loader.dataset, batch_size=self.batch_size, - shuffle=True, - num_workers=train_loader.num_workers, - pin_memory=True, - drop_last=False, - collate_fn=train_loader.collate_fn, ) def val_dataloader(self): - return DataLoader( - val_loader.dataset, + return build_dataloader( + split="val", + dataset=val_loader.dataset, + batch_size=self.batch_size, + ) + + def predict_dataloader(self): + model.set_predict_parameters( + n=n, + num_samples=num_samples, + roll_size=roll_size, + batch_size=self._batch_size, + n_jobs=n_jobs, + predict_likelihood_parameters=predict_likelihood_parameters, + mc_dropout=mc_dropout, + ) + return build_dataloader( + split="predict", + dataset=predict_dataset, batch_size=self.batch_size, - shuffle=False, - num_workers=val_loader.num_workers, - pin_memory=True, - drop_last=False, - collate_fn=val_loader.collate_fn, ) return Tuner(trainer).scale_batch_size( @@ -1619,14 +1641,11 @@ def predict_from_dataset( mc_dropout=mc_dropout, ) - pred_loader = DataLoader( - input_series_dataset, + pred_loader = self._build_dataloader( + split="predict", + dataset=input_series_dataset, + num_loader_workers=num_loader_workers, batch_size=batch_size, - shuffle=False, - num_workers=num_loader_workers, - pin_memory=True, - drop_last=False, - collate_fn=self._batch_collate_fn, ) # set up 
trainer. use user supplied trainer or create a new trainer from scratch @@ -2377,6 +2396,64 @@ def _check_ckpt_parameters(self, tfm_save): raise_log(ValueError("\n".join(msg)), logger) + def _build_dataloader( + self, + split: Literal["train", "val", "predict"], + dataset: Dataset, + batch_size: Optional[int] = None, + num_loader_workers: int = 0, + ) -> DataLoader: + """ + Builds a PyTorch DataLoader from a given dataset. + + Parameters + ---------- + split + The split for which the DataLoader is built. Can be "train", "val" or "predict". + dataset + The dataset from which to build the DataLoader. + batch_size + The batch size for the DataLoader. If not specified, the model's default batch size is used. + num_loader_workers + The number of workers for the DataLoader. Default is 0. + """ + + if batch_size is None: + batch_size = self.batch_size + + if split == "train": + return DataLoader( + dataset=dataset, + batch_size=batch_size, + shuffle=True, + num_workers=num_loader_workers, + pin_memory=True, + drop_last=False, + collate_fn=self._batch_collate_fn, + ) + + if split == "val": + return DataLoader( + dataset=dataset, + batch_size=batch_size, + shuffle=False, + num_workers=num_loader_workers, + pin_memory=True, + drop_last=False, + collate_fn=self._batch_collate_fn, + ) + + if split == "predict": + return DataLoader( + dataset=dataset, + batch_size=batch_size, + shuffle=False, + num_workers=num_loader_workers, + pin_memory=True, + drop_last=False, + collate_fn=self._batch_collate_fn, + ) + def __getstate__(self): # do not pickle the PyTorch LightningModule, and Trainer return {k: v for k, v in self.__dict__.items() if k not in TFM_ATTRS_NO_PICKLE} diff --git a/darts/tests/models/forecasting/test_torch_forecasting_model.py b/darts/tests/models/forecasting/test_torch_forecasting_model.py index 835b0e8ce1..73b57a551f 100644 --- a/darts/tests/models/forecasting/test_torch_forecasting_model.py +++ b/darts/tests/models/forecasting/test_torch_forecasting_model.py @@ -1402,17 +1402,24 @@ def test_lr_find(self): ) assert scores["worst"] > scores["suggested"] - @pytest.mark.slow def test_scale_batch_size(self): - train_series, val_series = self.series[:-40], self.series[-40:] + train_series, predict_series = self.series[:-40], self.series[-40:] model = RNNModel(12, "RNN", 10, 10, random_state=42, batch_size=1, **tfm_kwargs) # find the batch size init_batch_size = model.batch_size batch_size = model.scale_batch_size( series=train_series, - val_series=val_series, - epochs=50, init_val=init_batch_size, + method="fit", + ) + assert isinstance(batch_size, int) + assert batch_size != init_batch_size + + batch_size = model.scale_batch_size( + series=predict_series, + init_val=init_batch_size, + method="predict", + n=10, ) assert isinstance(batch_size, int) assert batch_size != init_batch_size
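
For illustration, a usage sketch of the new `scale_batch_size()` API as it stands after PATCH 4/4. The synthetic series, model hyperparameters, and concrete argument values below are assumptions chosen for the example; only the method name and its parameters come from the diff above, and the `predict` branch may still change since the final commit is marked WIP.

# Usage sketch for the `scale_batch_size()` API added by this series (illustrative only).
# The series below is synthetic; any TimeSeries accepted by `fit()` would work.
from darts.models import RNNModel
from darts.utils import timeseries_generation as tg

series = tg.linear_timeseries(length=200)

model = RNNModel(
    input_chunk_length=12,
    model="RNN",
    hidden_dim=10,
    n_rnn_layers=1,
    batch_size=16,  # starting value; the tuner overrides it during the search
)

# Find the largest batch size that fits in memory for training.
fit_batch_size = model.scale_batch_size(
    series=series,
    method="fit",
    mode="power",       # double the batch size on each trial ("linear" steps instead)
    steps_per_trial=3,  # run a few steps per candidate batch size
    init_val=2,         # start the search at batch size 2
    max_trials=25,
)

# The same search for inference; `n` forecast steps are required when method="predict".
predict_batch_size = model.scale_batch_size(
    series=series,
    method="predict",
    n=10,
    init_val=2,
)

print(fit_batch_size, predict_batch_size)

The returned value is the largest batch size the tuner completed without an out-of-memory error; it can then be passed back to the model constructor or, for inference, to `predict(batch_size=...)`.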
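
The method delegates the actual search to PyTorch Lightning. As background, a minimal self-contained sketch of the `Tuner.scale_batch_size()` / `LightningDataModule` pattern the patch wraps: the toy module and random dataset are made up for the illustration, while the tuner arguments mirror the defaults used in the diff. Lightning scales whatever object exposes a `batch_size` attribute, which is why the patch's `DataModule` defines `batch_size` as a property whose setter also refreshes the model's predict parameters.

# Minimal sketch of the Lightning pattern wrapped by this patch: the tuner
# mutates `datamodule.batch_size` and reruns a few steps, growing the batch
# size until training fails with an out-of-memory error or max_trials is hit.
import pytorch_lightning as pl
import torch
import torch.nn.functional as F
from pytorch_lightning.tuner import Tuner
from torch.utils.data import DataLoader, TensorDataset


class ToyModule(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(8, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return F.mse_loss(self.layer(x), y)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters())


class ToyDataModule(pl.LightningDataModule):
    def __init__(self, batch_size: int = 2):
        super().__init__()
        # `Tuner.scale_batch_size` looks for a `batch_size` attribute to scale.
        self.batch_size = batch_size
        self.dataset = TensorDataset(torch.randn(1024, 8), torch.randn(1024, 1))

    def train_dataloader(self):
        # Rebuilt on every trial so the newly scaled batch size takes effect.
        return DataLoader(self.dataset, batch_size=self.batch_size, shuffle=True)


trainer = pl.Trainer(max_epochs=1, enable_progress_bar=False)
best = Tuner(trainer).scale_batch_size(
    model=ToyModule(),
    datamodule=ToyDataModule(),
    method="fit",
    mode="power",       # doubling search, as in the patch's default
    steps_per_trial=3,
    init_val=2,
    max_trials=25,
)
print(best)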