Merge pull request #2 from sachahu1/feature/framework-improvements
Improved Framework
sachahu1 authored Aug 2, 2024
2 parents e8ceeeb + 61741bc commit 2401f12
Showing 10 changed files with 450 additions and 385 deletions.
55 changes: 31 additions & 24 deletions diffusion_models/diffusion_trainer.py
@@ -4,21 +4,13 @@
from typing import Optional

import torch
from torch import amp
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from tqdm import tqdm

from diffusion_models.gaussian_diffusion.beta_schedulers import (
BaseBetaScheduler,
)
from diffusion_models.gaussian_diffusion.beta_schedulers import (
LinearBetaScheduler,
)
from diffusion_models.gaussian_diffusion.gaussian_diffuser import (
GaussianDiffuser,
)
from diffusion_models.models.base_diffusion_model import BaseDiffusionModel
from diffusion_models.utils.schemas import BetaSchedulerConfiguration
from diffusion_models.utils.schemas import Checkpoint
from diffusion_models.utils.schemas import LogConfiguration
from diffusion_models.utils.schemas import TrainingConfiguration
@@ -28,12 +20,11 @@
class DiffusionTrainer:
def __init__(
self,
model: torch.nn.Module,
model: BaseDiffusionModel,
dataset: Dataset,
optimizer: torch.optim.Optimizer,
training_configuration: TrainingConfiguration,
loss_function: Callable = F.l1_loss,
beta_scheduler: BaseBetaScheduler = LinearBetaScheduler(),
scheduler: Optional[torch.optim.lr_scheduler.LRScheduler] = None,
log_configuration: LogConfiguration = LogConfiguration(),
reverse_transforms: Callable = lambda x: x,
@@ -46,11 +37,6 @@ def __init__(
self.scheduler = scheduler
self.device = device

self.beta_scheduler = beta_scheduler
self.gaussian_diffuser = GaussianDiffuser(
beta_scheduler=beta_scheduler,
).to(device)

self.dataloader = DataLoader(
dataset=dataset,
batch_size=training_configuration.batch_size,
@@ -63,7 +49,10 @@ def __init__(

self._image_shape = dataset[0][0].shape

self.scaler = torch.cuda.amp.GradScaler()
self.scaler = torch.amp.GradScaler(
device=device
# init_scale=8192,
)

self.log_configuration = log_configuration

@@ -79,6 +68,8 @@ def __init__(

self.reverse_transforms = reverse_transforms

torch.backends.cudnn.benchmark = True

def save_checkpoint(self, epoch: int, checkpoint_name: str):
checkpoint = Checkpoint(
epoch=epoch,
@@ -87,6 +78,12 @@ def save_checkpoint(self, epoch: int, checkpoint_name: str):
scaler=self.scaler.state_dict()
if self.training_configuration.mixed_precision_training
else None,
image_channels=self._image_shape[0],
beta_scheduler_config=BetaSchedulerConfiguration(
steps=self.model.diffuser.beta_scheduler.steps,
betas=self.model.diffuser.beta_scheduler.betas,
alpha_bars=self.model.diffuser.beta_scheduler.alpha_bars,
),
tensorboard_run_name=self.tensorboard_manager.summary_writer.log_dir,
)
checkpoint.to_file(self.checkpoint_path / checkpoint_name)
@@ -102,20 +99,29 @@ def train(self):
images, _ = batch
images = images.to(self.device)

noisy_images, noise, timesteps = self.gaussian_diffuser.diffuse_batch(
images=images
)
noisy_images, noise, timesteps = self.model.diffuse(images=images)

with amp.autocast(
self.optimizer.zero_grad(set_to_none=True)

with torch.autocast(
device_type=self.device,
enabled=self.training_configuration.mixed_precision_training,
):
prediction = self.model(noisy_images, timesteps)
loss = self.loss_function(noise, prediction)

self.optimizer.zero_grad(set_to_none=True)
self.scaler.scale(loss).backward()

if self.training_configuration.gradient_clip is not None:
# Unscales the gradients of optimizer's assigned params in-place
self.scaler.unscale_(self.optimizer)

# Since the gradients of optimizer's assigned params are unscaled, clips as usual:
torch.nn.utils.clip_grad_norm_(
self.model.parameters(),
max_norm=self.training_configuration.gradient_clip,
)

self.scaler.step(self.optimizer)
self.scaler.update()

@@ -134,6 +140,7 @@ def train(self):

@torch.no_grad()
def log_to_tensorboard(self, metrics: Dict[str, float], global_step: int):
self.model.eval()
if global_step % self.log_configuration.log_rate == 0:
self.tensorboard_manager.log_metrics(
metrics=metrics, global_step=global_step
@@ -152,7 +159,7 @@ def log_to_tensorboard(self, metrics: Dict[str, float], global_step: int):
),
device=self.device,
)
images = self.gaussian_diffuser.denoise_batch(images, self.model)
images = self.model.denoise(images)
for step, images in enumerate(images[::-1]):
self.tensorboard_manager.log_images(
tag=f"Images at timestep {global_step}",
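The training-loop changes above follow PyTorch's standard mixed-precision recipe: backward on the scaled loss, unscale the gradients before clipping, then let the scaler step and update. Below is a minimal, self-contained sketch of that pattern; the model, optimizer, and max_norm are placeholders for illustration, not the trainer's actual configuration.

# Hypothetical sketch of the AMP + gradient-clipping step adopted by DiffusionTrainer.
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
model = torch.nn.Linear(16, 16).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
scaler = torch.amp.GradScaler(device, enabled=(device == "cuda"))

for _ in range(3):
    inputs = torch.randn(8, 16, device=device)
    targets = torch.randn(8, 16, device=device)

    optimizer.zero_grad(set_to_none=True)
    with torch.autocast(device_type=device, enabled=(device == "cuda")):
        loss = torch.nn.functional.l1_loss(model(inputs), targets)

    scaler.scale(loss).backward()  # backward on the scaled loss
    scaler.unscale_(optimizer)  # unscale in-place before clipping
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    scaler.step(optimizer)  # skips the step if gradients contain inf/NaN
    scaler.update()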
28 changes: 28 additions & 0 deletions diffusion_models/gaussian_diffusion/base_diffuser.py
@@ -0,0 +1,28 @@
import abc
from typing import List
from typing import Tuple

import torch

from diffusion_models.gaussian_diffusion.beta_schedulers import (
BaseBetaScheduler,
)


class BaseDiffuser(abc.ABC):
def __init__(self, beta_scheduler: BaseBetaScheduler):
self.beta_scheduler = beta_scheduler

@abc.abstractmethod
def diffuse_batch(
self, images: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
pass

@abc.abstractmethod
def denoise_batch(self, images: torch.Tensor, model) -> List[torch.Tensor]:
pass

@abc.abstractmethod
def to(self, device: str = "cpu"):
pass
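To make the new abstract interface concrete, here is a hypothetical minimal subclass; the no-op behaviour is purely illustrative and is not the repository's GaussianDiffuser.

# Hypothetical BaseDiffuser subclass, for illustration only.
from typing import List, Tuple

import torch

from diffusion_models.gaussian_diffusion.base_diffuser import BaseDiffuser
from diffusion_models.gaussian_diffusion.beta_schedulers import LinearBetaScheduler


class IdentityDiffuser(BaseDiffuser):
    def diffuse_batch(
        self, images: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        # A real diffuser would sample timesteps and add scheduled Gaussian noise.
        timesteps = torch.zeros(images.shape[0], dtype=torch.long, device=images.device)
        return images, torch.zeros_like(images), timesteps

    def denoise_batch(self, images: torch.Tensor, model) -> List[torch.Tensor]:
        # A real diffuser would step backwards through the schedule, calling the model.
        return [images]

    def to(self, device: str = "cpu"):
        self.beta_scheduler = self.beta_scheduler.to(device)
        return self


diffuser = IdentityDiffuser(beta_scheduler=LinearBetaScheduler())
noisy, noise, timesteps = diffuser.diffuse_batch(torch.randn(4, 3, 32, 32))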
97 changes: 95 additions & 2 deletions diffusion_models/gaussian_diffusion/beta_schedulers.py
@@ -1,14 +1,51 @@
import abc
import logging
from typing import Optional

import torch


class BaseBetaScheduler:
def __init__(self, steps: int):
def __init__(self, steps: int, enforce_zero_terminal_snr: bool = False):
super().__init__()
self.steps = steps
self.betas = self.sample_betas()
self.alpha_bars = self.compute_alpha_bar()

if enforce_zero_terminal_snr:
self.enforce_zero_terminal_snr()

def enforce_zero_terminal_snr(self):
alpha_bar_length = len(self.alpha_bars)

# Convert betas to alphas_bar_sqrt
alphas = 1 - self.betas
alphas_bar = alphas.cumprod(0)
alphas_bar_sqrt = alphas_bar.sqrt()

# Store old values.
alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
# Shift so last timestep is zero.
alphas_bar_sqrt -= alphas_bar_sqrt_T
# Scale so first timestep is back to old value.
alphas_bar_sqrt *= alphas_bar_sqrt_0 / (
alphas_bar_sqrt_0 - alphas_bar_sqrt_T
)

# Convert alphas_bar_sqrt to betas
alphas_bar = alphas_bar_sqrt**2
alphas = alphas_bar[1:] / alphas_bar[:-1]
alphas = torch.cat([alphas_bar[0:1], alphas])
betas = 1 - alphas
if len(alphas) == alpha_bar_length:
self.betas = betas
self.alpha_bars = alphas_bar
else:
logging.warning(
"Got different alpha_bar length after enforcing zero SNR. Please check your beta scheduler"
)

@abc.abstractmethod
def sample_betas(self):
pass
@@ -22,17 +59,31 @@ def to(self, device: str):
self.alpha_bars = self.alpha_bars.to(device)
return self

@classmethod
def from_tensors(
cls, steps: int, betas: torch.Tensor, alpha_bars: torch.Tensor
):
generic_beta_scheduler = cls(0)
generic_beta_scheduler.steps = steps
generic_beta_scheduler.betas = betas
generic_beta_scheduler.alpha_bars = alpha_bars
return generic_beta_scheduler


class LinearBetaScheduler(BaseBetaScheduler):
def __init__(
self,
beta_start: float = 0.0001,
beta_end: float = 0.02,
steps: int = 1000,
enforce_zero_terminal_snr: bool = True,
):
self.beta_start = beta_start
self.beta_end = beta_end
super().__init__(steps)
super().__init__(
steps=steps,
enforce_zero_terminal_snr=enforce_zero_terminal_snr,
)

def sample_betas(self):
return torch.linspace(self.beta_start, self.beta_end, self.steps)
@@ -41,3 +92,45 @@ def compute_alpha_bar(self):
alphas = 1 - self.betas
alpha_bar = torch.cumprod(alphas, dim=0)
return alpha_bar


class CosineBetaScheduler(BaseBetaScheduler):
def __init__(
self,
offset: float = 0.008,
steps: int = 1000,
max_beta: Optional[float] = 0.999,
):
self.offset = offset
self.max_beta = max_beta
self.steps = steps
self._alpha_bars = self._compute_alpha_bar()
self._betas = self._compute_betas()

super().__init__(
steps=steps,
)

def f(self, t: torch.Tensor):
return (
torch.cos(
(((t / self.steps) + self.offset) / (1 + self.offset)) * (torch.pi / 2)
)
** 2
)

def _compute_betas(self):
betas = 1 - self._alpha_bars[1:] / self._alpha_bars[:-1]
if self.max_beta:
betas = torch.clip(betas, max=self.max_beta)
return betas

def _compute_alpha_bar(self):
t = torch.linspace(0, self.steps, self.steps, dtype=torch.float32)
return self.f(t) / self.f(torch.tensor([0], dtype=torch.float32))

def sample_betas(self):
return self._betas

def compute_alpha_bar(self):
return self._alpha_bars
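A quick, hypothetical usage sketch for the two schedulers above (assuming the package is importable as in the diff): with zero terminal SNR enforced, the last cumulative alpha should come out exactly zero, and the cosine schedule implements alpha_bar(t) = f(t) / f(0) with f(t) = cos^2(((t / T + s) / (1 + s)) * pi / 2).

# Hypothetical usage sketch; scheduler names and defaults come from the diff above.
from diffusion_models.gaussian_diffusion.beta_schedulers import (
    CosineBetaScheduler,
    LinearBetaScheduler,
)

# Linear schedule with zero-terminal-SNR rescaling (the default in this commit):
# the rescaling shifts and scales sqrt(alpha_bar) so the final value is exactly 0,
# i.e. the last timestep is pure noise.
linear = LinearBetaScheduler(steps=1000)
print(linear.alpha_bars[-1])  # expected: tensor(0.)

# Cosine schedule (Nichol & Dhariwal-style): alpha_bar(t) = f(t) / f(0),
# f(t) = cos^2(((t / T + s) / (1 + s)) * pi / 2), with betas clipped at max_beta.
cosine = CosineBetaScheduler(steps=1000, offset=0.008)
print(cosine.alpha_bars[0], cosine.alpha_bars[-1], cosine.betas.max())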