Auto batch size for torch model #2318

Status: Open · wants to merge 4 commits into base: master
CHANGELOG.md (2 changes: 2 additions & 0 deletions)
@@ -117,6 +117,8 @@ …but cannot always guarantee backwards compatibility. Changes that may **break co…
- 🔴 Moved around utils functions to clearly separate Darts-specific from non-Darts-specific logic, [#2284](https://github.com/unit8co/darts/pull/2284) by [Dennis Bader](https://github.com/dennisbader):
    - Moved function `generate_index()` from `darts.utils.timeseries_generation` to `darts.utils.utils`.
    - Moved functions `retain_period_common_to_all()`, `series2seq()`, `seq2series()`, `get_single_series()` from `darts.utils.utils` to `darts.utils.ts_utils`.
- Improvements to `TorchForecastingModel`:
    - New method `TorchForecastingModel.scale_batch_size()` to find the largest batch size that fits into memory for fitting and prediction. [#2318](https://github.com/unit8co/darts/pull/2318) by [Bohdan Bilonoh](https://github.com/BohdanBilonoh)

**Fixed**
- Fixed the order of the features when using component-specific lags so that they are grouped by values, then by components (before, they were grouped by components, then by values). [#2272](https://github.com/unit8co/darts/pull/2272) by [Antoine Madrona](https://github.com/madtoinou).
darts/models/forecasting/torch_forecasting_model.py (257 changes: 233 additions & 24 deletions)
@@ -41,7 +41,7 @@
import torch
from pytorch_lightning import loggers as pl_loggers
from torch import Tensor
-from torch.utils.data import DataLoader
+from torch.utils.data import DataLoader, Dataset

from darts.dataprocessing.encoders import SequentialEncoder
from darts.logging import (
@@ -996,28 +996,20 @@ def _setup_for_train(

# Setting drop_last to False makes the model see each sample at least once, and guarantee the presence of at
# least one batch no matter the chosen batch size
-train_loader = DataLoader(
-    train_dataset,
-    batch_size=self.batch_size,
-    shuffle=True,
-    num_workers=num_loader_workers,
-    pin_memory=True,
-    drop_last=False,
-    collate_fn=self._batch_collate_fn,
+train_loader = self._build_dataloader(
+    split="train",
+    dataset=train_dataset,
+    num_loader_workers=num_loader_workers,
 )

# Prepare validation data
val_loader = (
None
if val_dataset is None
-    else DataLoader(
-        val_dataset,
-        batch_size=self.batch_size,
-        shuffle=False,
-        num_workers=num_loader_workers,
-        pin_memory=True,
-        drop_last=False,
-        collate_fn=self._batch_collate_fn,
+    else self._build_dataloader(
+        split="val",
+        dataset=val_dataset,
+        num_loader_workers=num_loader_workers,
     )
 )

@@ -1206,6 +1198,168 @@ def lr_find(
update_attr=False,
)

@random_method
def scale_batch_size(
self,
series: Union[TimeSeries, Sequence[TimeSeries]],
n: int = 1,
n_jobs: int = 1,
roll_size: Optional[int] = None,
num_samples: int = 1,
mc_dropout: bool = False,
predict_likelihood_parameters: bool = False,
past_covariates: Optional[Union[TimeSeries, Sequence[TimeSeries]]] = None,
future_covariates: Optional[Union[TimeSeries, Sequence[TimeSeries]]] = None,
trainer: Optional[pl.Trainer] = None,
verbose: Optional[bool] = None,
method: Literal["fit", "predict"] = "fit",
mode: str = "power",

[Review comment] Reviewer: Why are you not using `Literal` here? This variable can just be "power" or "linear", right?

[Reply] @BohdanBilonoh (Contributor, Author), Apr 12, 2024: I took it from Lightning; I think this is motivated by the fact that there could potentially be many more modes.
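For reference, the suggested tightening would look like this (a sketch only; it assumes Lightning's tuner accepts exactly the two documented modes):

```python
from typing import Literal

# Hypothetical stricter annotation for the `mode` parameter:
mode: Literal["power", "linear"] = "power"
```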

steps_per_trial: int = 3,
init_val: int = 2,
max_trials: int = 25,
) -> Optional[int]:
"""
A wrapper around PyTorch Lightning's `Tuner.scale_batch_size()`. Scales the batch size of the model to find the
largest batch size that can be used without running out of memory. For more information on PyTorch Lightning's
Tuner check out
`this link <https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.tuner.tuning.Tuner.html>`_.

Parameters
----------
series
A series or sequence of series serving as target (i.e. what the model will be trained to forecast)
n
The number of time steps after the end of the training time series for which to produce predictions.
Only for the `predict` method.
past_covariates
Optionally, a series or sequence of series specifying past-observed covariates
future_covariates
Optionally, a series or sequence of series specifying future-known covariates
trainer
Optionally, a custom PyTorch-Lightning Trainer object to perform training. Using a custom ``trainer`` will
override Darts' default trainer.
verbose
Optionally, whether to print the progress. Ignored if there is a `ProgressBar` callback in
`pl_trainer_kwargs`.
method
The method to use for scaling the batch size. Can be one of 'fit', 'validate', 'test', or 'predict'.
mode
The mode to use for scaling the batch size. Can be one of 'power' or 'linear'.
steps_per_trial
The number of steps to try for each trial.
init_val
The initial value to start the search with.
max_trials
The maximum number of trials to run.

Returns
-------
int
The largest batch size that can be used without running out of memory.
"""
_, params = self._setup_for_fit_from_dataset(
series=series,
past_covariates=past_covariates,
future_covariates=future_covariates,
val_series=series,
val_past_covariates=past_covariates,
val_future_covariates=future_covariates,
trainer=trainer,
verbose=verbose,
)
trainer, model, train_loader, val_loader = self._setup_for_train(*params)

if method == "predict":
if roll_size is None:
roll_size = self.output_chunk_length
else:
raise_if_not(
0 < roll_size <= self.output_chunk_length,
"`roll_size` must be an integer between 1 and `self.output_chunk_length`.",
)
predict_dataset = self._build_inference_dataset(
target=series,
n=n,
past_covariates=past_covariates,
future_covariates=future_covariates,
stride=0,
bounds=None,
)
model.set_predict_parameters(
n=n,
num_samples=num_samples,
roll_size=roll_size,
batch_size=1,
n_jobs=n_jobs,
predict_likelihood_parameters=predict_likelihood_parameters,
mc_dropout=mc_dropout,
)

build_dataloader = self._build_dataloader

# Lightning's Tuner mutates `datamodule.batch_size` between trials; the property setter below propagates
# each new value to the model's prediction parameters so that `predict` trials run with the updated size.
class DataModule(pl.LightningDataModule):
def __init__(self, batch_size):
super().__init__()
self.save_hyperparameters()
self._batch_size = batch_size

@property
def batch_size(self):
return self._batch_size

@batch_size.setter
def batch_size(self, batch_size):
model.set_predict_parameters(
n=n,
num_samples=num_samples,
roll_size=roll_size,
batch_size=batch_size,
n_jobs=n_jobs,
predict_likelihood_parameters=predict_likelihood_parameters,
mc_dropout=mc_dropout,
)
self._batch_size = batch_size

def train_dataloader(self):
return build_dataloader(
split="train",
dataset=train_loader.dataset,
batch_size=self.batch_size,
)

def val_dataloader(self):
return build_dataloader(
split="val",
dataset=val_loader.dataset,
batch_size=self.batch_size,
)

def predict_dataloader(self):
model.set_predict_parameters(
n=n,
num_samples=num_samples,
roll_size=roll_size,
batch_size=self._batch_size,
n_jobs=n_jobs,
predict_likelihood_parameters=predict_likelihood_parameters,
mc_dropout=mc_dropout,
)
return build_dataloader(
split="predict",
dataset=predict_dataset,
batch_size=self.batch_size,
)

return Tuner(trainer).scale_batch_size(
model=model,
datamodule=DataModule(batch_size=init_val),
method=method,
mode=mode,
steps_per_trial=steps_per_trial,
init_val=init_val,
max_trials=max_trials,
)
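
For context, a minimal usage sketch of the new method (the toy series, hyperparameters, and `n_epochs=1` below are illustrative assumptions, not part of this PR):

```python
import numpy as np
import pandas as pd

from darts import TimeSeries
from darts.models import RNNModel

# Toy target series; any TimeSeries works here.
times = pd.date_range("2000-01-01", periods=200, freq="D")
series = TimeSeries.from_times_and_values(times, np.sin(np.arange(200) / 10.0))

model = RNNModel(input_chunk_length=12, model="RNN", n_epochs=1)

# Largest batch size that fits into memory for training ...
train_bs = model.scale_batch_size(series=series, method="fit")

# ... and for prediction (`n` forecast steps must be given here).
predict_bs = model.scale_batch_size(series=series, method="predict", n=10)
```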

@random_method
def predict(
self,
@@ -1487,14 +1641,11 @@ def predict_from_dataset(
mc_dropout=mc_dropout,
)

-pred_loader = DataLoader(
-    input_series_dataset,
-    batch_size=batch_size,
-    shuffle=False,
-    num_workers=num_loader_workers,
-    pin_memory=True,
-    drop_last=False,
-    collate_fn=self._batch_collate_fn,
+pred_loader = self._build_dataloader(
+    split="predict",
+    dataset=input_series_dataset,
+    num_loader_workers=num_loader_workers,
+    batch_size=batch_size,
 )

# set up trainer. use user supplied trainer or create a new trainer from scratch
@@ -2245,6 +2396,64 @@ def _check_ckpt_parameters(self, tfm_save):

raise_log(ValueError("\n".join(msg)), logger)

def _build_dataloader(
    self,
    split: Literal["train", "val", "predict"],
    dataset: Dataset,
    batch_size: Optional[int] = None,
    num_loader_workers: int = 0,
) -> DataLoader:
    """
    Builds a PyTorch DataLoader from a given dataset.

    Parameters
    ----------
    split
        The split for which the DataLoader is built. Can be "train", "val" or "predict"; any other value
        raises a `ValueError`.
    dataset
        The dataset from which to build the DataLoader.
    batch_size
        The batch size for the DataLoader. If not specified, the model's default batch size is used.
    num_loader_workers
        The number of workers for the DataLoader. Default is 0.
    """
    if split not in ("train", "val", "predict"):
        raise_log(ValueError(f"Unknown split: {split}"), logger)

    if batch_size is None:
        batch_size = self.batch_size

    # Only the training split is shuffled; validation and prediction must preserve sample order.
    return DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        shuffle=(split == "train"),
        num_workers=num_loader_workers,
        pin_memory=True,
        drop_last=False,
        collate_fn=self._batch_collate_fn,
    )
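
As an aside, the split-to-shuffle mapping implemented above can be illustrated with plain PyTorch (a self-contained sketch; `TensorDataset` merely stands in for Darts' own training/inference datasets):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Stand-in dataset; Darts' real datasets yield tuples of past/future tensors.
ds = TensorDataset(torch.arange(10, dtype=torch.float32).unsqueeze(-1))

# What `_build_dataloader` does per split: only "train" shuffles.
train_loader = DataLoader(ds, batch_size=4, shuffle=True, drop_last=False)  # split="train"
eval_loader = DataLoader(ds, batch_size=4, shuffle=False, drop_last=False)  # split="val" / "predict"
```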

def __getstate__(self):
# do not pickle the PyTorch LightningModule, and Trainer
return {k: v for k, v in self.__dict__.items() if k not in TFM_ATTRS_NO_PICKLE}
darts/tests/models/forecasting/test_torch_forecasting_model.py (22 changes: 22 additions & 0 deletions)
@@ -1402,6 +1402,28 @@ def test_lr_find(self):
)
assert scores["worst"] > scores["suggested"]

def test_scale_batch_size(self):
[Review comment] Collaborator: test this for method="fit" and "predict"
train_series, predict_series = self.series[:-40], self.series[-40:]
model = RNNModel(12, "RNN", 10, 10, random_state=42, batch_size=1, **tfm_kwargs)
# find the batch size
init_batch_size = model.batch_size
batch_size = model.scale_batch_size(
series=train_series,
init_val=init_batch_size,
method="fit",
)
assert isinstance(batch_size, int)
assert batch_size != init_batch_size

batch_size = model.scale_batch_size(
series=predict_series,
init_val=init_batch_size,
method="predict",
n=10,
)
assert isinstance(batch_size, int)
assert batch_size != init_batch_size

def test_encoders(self, tmpdir_fn):
series = tg.linear_timeseries(length=10)
pc = tg.linear_timeseries(length=12)