diff --git a/codes/models/diffusion/resample.py b/codes/models/diffusion/resample.py
index c59a1cc5..694aa3f0 100644
--- a/codes/models/diffusion/resample.py
+++ b/codes/models/diffusion/resample.py
@@ -3,6 +3,7 @@ from abc import ABC, abstractmethod
 import numpy as np
 import torch as th
 import torch.distributed as dist
+from torch import distributed
 
 
 def create_named_schedule_sampler(name, diffusion):
@@ -69,18 +70,38 @@ class UniformSampler(ScheduleSampler):
 
 class DeterministicSampler:
     """
-    Returns the same equally spread-out sampling schedule every time it is called.
+    Returns the same equally spread-out sampling schedule every time it is called. Automatically handles distributed
+    cases by splitting the sampling range across all ranks. reset() must be called once a full batch is completed.
     """
-    def __init__(self, diffusion):
+    def __init__(self, diffusion, sampling_range, env):
         super().__init__()
         self.timesteps = diffusion.num_timesteps
+        self.rank = max(env['rank'], 0)  # env['rank'] is -1 in the non-distributed case.
+        if distributed.is_initialized():
+            self.world_size = distributed.get_world_size()
+        else:
+            self.world_size = 1
+        # The sampling range gets spread out across the distributed ranks; each rank draws an interleaved slice.
+        rnge = th.arange(self.rank, sampling_range, step=self.world_size).float() / sampling_range
+        self.indices = (rnge * self.timesteps).long()
+        self.counter = 0
 
     def sample(self, batch_size, device):
-        rnge = th.arange(0, batch_size, device=device).float() / batch_size
-        indices = (rnge * self.timesteps).long()
+        """
+        Iteratively samples across the deterministic range specified by the initialization params.
+        """
+        assert batch_size <= self.indices.shape[0]
+        if self.counter + batch_size > self.indices.shape[0]:
+            print(f"DeterministicSampler: likely error; counter={self.counter}, batch_size={batch_size}, range={self.indices.shape[0]}. Did you forget to set the sampling range to your batch size?")
+            self.counter = 0  # Recover by resetting to 0.
+        indices = self.indices[self.counter:self.counter + batch_size].to(device)
+        self.counter += batch_size
         weights = th.ones_like(indices).float()
         return indices, weights
 
+    def reset(self):
+        self.counter = 0
+
 
 class LossAwareSampler(ScheduleSampler):
     def update_with_local_losses(self, local_ts, local_losses):
diff --git a/codes/trainer/injectors/gaussian_diffusion_injector.py b/codes/trainer/injectors/gaussian_diffusion_injector.py
index f93c363e..45988323 100644
--- a/codes/trainer/injectors/gaussian_diffusion_injector.py
+++ b/codes/trainer/injectors/gaussian_diffusion_injector.py
@@ -27,6 +27,7 @@ class GaussianDiffusionInjector(Injector):
         self.model_input_keys = opt_get(opt, ['model_input_keys'], [])
         self.extra_model_output_keys = opt_get(opt, ['extra_model_output_keys'], [])
         self.deterministic_timesteps_every = opt_get(opt, ['deterministic_timesteps_every'], 0)
+        self.deterministic_sampler = DeterministicSampler(self.diffusion, opt_get(opt, ['deterministic_sampler_expected_batch_size'], 2048), env)
 
     def forward(self, state):
         gen = self.env['generators'][self.opt['generator']]
@@ -34,15 +35,15 @@
 
         with autocast(enabled=self.env['opt']['fp16']):
             if not gen.training or (self.deterministic_timesteps_every != 0 and self.env['step'] % self.deterministic_timesteps_every == 0):
-                sampler = DeterministicSampler(self.diffusion)
+                sampler = self.deterministic_sampler
            else:
                 sampler = self.schedule_sampler
+                self.deterministic_sampler.reset()  # Reset whenever the deterministic sampler is not in use so it is ready the next time it is needed.
             model_inputs = {k: state[v] for k, v in self.model_input_keys.items()}
             t, weights = sampler.sample(hq.shape[0], hq.device)
             diffusion_outputs = self.diffusion.training_losses(gen, hq, t, model_kwargs=model_inputs)
             if isinstance(sampler, LossAwareSampler):
                 sampler.update_with_local_losses(t, diffusion_outputs['losses'])
-
             if len(self.extra_model_output_keys) > 0:
                 assert(len(self.extra_model_output_keys) == len(diffusion_outputs['extra_outputs']))
                 out = {k: v for k, v in zip(self.extra_model_output_keys, diffusion_outputs['extra_outputs'])}
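
A minimal usage sketch of the patched sampler, assuming resample.py is importable as models.diffusion.resample and using a hypothetical DummyDiffusion stand-in; the sampler only reads num_timesteps, and env={'rank': -1} mirrors the non-distributed case:

    import torch as th
    from models.diffusion.resample import DeterministicSampler  # import path is an assumption

    class DummyDiffusion:
        # Hypothetical stand-in; the sampler only consults num_timesteps.
        num_timesteps = 1000

    sampler = DeterministicSampler(DummyDiffusion(), sampling_range=8, env={'rank': -1})
    t1, w1 = sampler.sample(4, 'cpu')  # tensor([  0, 125, 250, 375]), uniform weights
    t2, w2 = sampler.sample(4, 'cpu')  # tensor([500, 625, 750, 875]), the next slice
    sampler.reset()                    # rewind once the full range is consumed
    t3, _ = sampler.sample(4, 'cpu')   # equals t1: the schedule is deterministic

This is why the injector resets the sampler on every non-deterministic step: the counter is then guaranteed to sit at zero whenever the deterministic path is next taken.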
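
To illustrate the distributed split, the sketch below replays the index computation from __init__ for two hypothetical ranks without initializing torch.distributed; each rank receives an interleaved slice of the same schedule:

    import torch as th

    sampling_range, world_size, timesteps = 8, 2, 1000
    for rank in range(world_size):
        rnge = th.arange(rank, sampling_range, step=world_size).float() / sampling_range
        print(rank, (rnge * timesteps).long())
    # 0 tensor([  0, 250, 500, 750])
    # 1 tensor([125, 375, 625, 875])

Together the ranks cover exactly the timesteps a single process would draw, which is what lets deterministic_sampler_expected_batch_size refer to the global batch size.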