Make deterministic sampler work with distributed training & microbatches

James Betker 2022-03-04 11:50:50 -07:00
parent 77c18b53b3
commit f87e10ffef
2 changed files with 27 additions and 6 deletions


@@ -3,6 +3,7 @@ from abc import ABC, abstractmethod
 import numpy as np
 import torch as th
 import torch.distributed as dist
+from torch import distributed
 
 
 def create_named_schedule_sampler(name, diffusion):
@@ -69,18 +70,37 @@ class UniformSampler(ScheduleSampler):
 
 class DeterministicSampler:
     """
-    Returns the same equally spread-out sampling schedule every time it is called.
+    Returns the same equally spread-out sampling schedule every time it is called. Automatically handles distributed
+    cases by sharing the load across all entities. reset() must be called once a full batch is completed.
     """
-    def __init__(self, diffusion):
+    def __init__(self, diffusion, sampling_range, env):
         super().__init__()
         self.timesteps = diffusion.num_timesteps
+        self.rank = max(env['rank'], 0)
+        if distributed.is_initialized():
+            self.world_size = distributed.get_world_size()
+        else:
+            self.world_size = 1
+        # The sampling range gets spread out across multiple distributed entities.
+        rnge = th.arange(0, sampling_range, step=self.world_size).float() / sampling_range
+        self.indices = (rnge * self.timesteps).long()
 
     def sample(self, batch_size, device):
-        rnge = th.arange(0, batch_size, device=device).float() / batch_size
-        indices = (rnge * self.timesteps).long()
+        """
+        Iteratively samples across the deterministic range specified by the initialization params.
+        """
+        assert batch_size < self.indices.shape[0]
+        if self.counter+batch_size > self.indices.shape[0]:
+            print(f"Diffusion DeterministicSampler; Likely error. {self.counter}, {batch_size}, {self.indices.shape[0]}. Did you forget to set the sampling range to your batch size for the deterministic sampler?")
+            self.counter = 0  # Recover by setting to 0.
+        indices = self.indices[self.counter:self.counter+batch_size].to(device)
+        self.counter = self.counter + batch_size
         weights = th.ones_like(indices).float()
         return indices, weights
 
+    def reset(self):
+        self.counter = 0
 
 
 class LossAwareSampler(ScheduleSampler):
     def update_with_local_losses(self, local_ts, local_losses):
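
A minimal usage sketch of the reworked sampler (hypothetical values throughout: diffusion stands in for a real GaussianDiffusion instance, and the env dict only carries the 'rank' key the constructor reads; note that reset() is what initializes the internal counter, so it should be called before the first sample()):

    # Hypothetical, non-distributed wiring: rank 0, world_size falls back to 1.
    env = {'rank': 0}
    sampler = DeterministicSampler(diffusion, sampling_range=2048, env=env)
    sampler.reset()  # initializes counter = 0; call before the first sample()

    # A full batch of 2048 timestep draws processed as 8 microbatches of 256:
    for _ in range(8):
        t, weights = sampler.sample(256, device='cpu')  # next 256 evenly spread timesteps
    sampler.reset()  # rewind once the full batch is done, per the docstring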


@@ -27,6 +27,7 @@ class GaussianDiffusionInjector(Injector):
         self.model_input_keys = opt_get(opt, ['model_input_keys'], [])
         self.extra_model_output_keys = opt_get(opt, ['extra_model_output_keys'], [])
         self.deterministic_timesteps_every = opt_get(opt, ['deterministic_timesteps_every'], 0)
+        self.deterministic_sampler = DeterministicSampler(self.diffusion, opt_get(opt, ['deterministic_sampler_expected_batch_size'], 2048), env)
 
     def forward(self, state):
         gen = self.env['generators'][self.opt['generator']]
@@ -34,15 +35,15 @@ class GaussianDiffusionInjector(Injector):
         with autocast(enabled=self.env['opt']['fp16']):
             if not gen.training or (self.deterministic_timesteps_every != 0 and self.env['step'] % self.deterministic_timesteps_every == 0):
-                sampler = DeterministicSampler(self.diffusion)
+                sampler = self.deterministic_sampler
             else:
                 sampler = self.schedule_sampler
+                self.deterministic_sampler.reset()  # Keep this reset whenever it is not being used, so it is ready to use automatically.
             model_inputs = {k: state[v] for k, v in self.model_input_keys.items()}
             t, weights = sampler.sample(hq.shape[0], hq.device)
             diffusion_outputs = self.diffusion.training_losses(gen, hq, t, model_kwargs=model_inputs)
             if isinstance(sampler, LossAwareSampler):
                 sampler.update_with_local_losses(t, diffusion_outputs['losses'])
             if len(self.extra_model_output_keys) > 0:
                 assert(len(self.extra_model_output_keys) == len(diffusion_outputs['extra_outputs']))
                 out = {k: v for k, v in zip(self.extra_model_output_keys, diffusion_outputs['extra_outputs'])}
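
For reference, a sketch of the two injector options this change depends on; the dict below is a hypothetical opt fragment, not a complete injector config, but the key names match the opt_get() calls above:

    # Hypothetical opt fragment for GaussianDiffusionInjector.
    opt = {
        # Every 500 steps (hypothetical cadence), swap in the deterministic sampler.
        'deterministic_timesteps_every': 500,
        # Should match the full batch size (summed over microbatches) so sample()
        # never runs past the precomputed index range.
        'deterministic_sampler_expected_batch_size': 2048,
    }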