Make deterministic sampler work with distributed training & microbatches

James Betker 2022-03-04 11:50:50 -07:00
parent 77c18b53b3
commit f87e10ffef
2 changed files with 27 additions and 6 deletions


@@ -3,6 +3,7 @@ from abc import ABC, abstractmethod
 import numpy as np
 import torch as th
 import torch.distributed as dist
+from torch import distributed
 
 
 def create_named_schedule_sampler(name, diffusion):
@@ -69,18 +70,37 @@ class UniformSampler(ScheduleSampler):
 class DeterministicSampler:
     """
-    Returns the same equally spread-out sampling schedule every time it is called.
+    Returns the same equally spread-out sampling schedule every time it is called. Automatically handles distributed
+    cases by sharing the load across all entities. reset() must be called once a full batch is completed.
     """
-    def __init__(self, diffusion):
+    def __init__(self, diffusion, sampling_range, env):
         super().__init__()
         self.timesteps = diffusion.num_timesteps
+        self.rank = max(env['rank'], 0)
+        if distributed.is_initialized():
+            self.world_size = distributed.get_world_size()
+        else:
+            self.world_size = 1
+        # The sampling range gets spread out across multiple distributed entities.
+        rnge = th.arange(0, sampling_range, step=self.world_size).float() / sampling_range
+        self.indices = (rnge * self.timesteps).long()
 
     def sample(self, batch_size, device):
-        rnge = th.arange(0, batch_size, device=device).float() / batch_size
-        indices = (rnge * self.timesteps).long()
+        """
+        Iteratively samples across the deterministic range specified by the initialization params.
+        """
+        assert batch_size < self.indices.shape[0]
+        if self.counter+batch_size > self.indices.shape[0]:
+            print(f"Diffusion DeterministicSampler; Likely error. {self.counter}, {batch_size}, {self.indices.shape[0]}. Did you forget to set the sampling range to your batch size for the deterministic sampler?")
+            self.counter = 0  # Recover by setting to 0.
+        indices = self.indices[self.counter:self.counter+batch_size].to(device)
+        self.counter = self.counter + batch_size
         weights = th.ones_like(indices).float()
         return indices, weights
 
+    def reset(self):
+        self.counter = 0
+
 
 class LossAwareSampler(ScheduleSampler):
     def update_with_local_losses(self, local_ts, local_losses):
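
For reference, a minimal usage sketch of the new sampler (not part of the commit). DummyDiffusion, the env dict, and all numbers are hypothetical; only the sampler's own API is taken from the hunk above. reset() is called before the first sample() because, in the diff as shown, counter is only initialized by reset().

# Minimal usage sketch (not part of the commit). DummyDiffusion, env and the
# numbers below are hypothetical; only the sampler's API comes from the diff above.
import torch as th

class DummyDiffusion:
    num_timesteps = 1000

sampler = DeterministicSampler(DummyDiffusion(), sampling_range=8, env={'rank': -1})
sampler.reset()  # counter is only set by reset() in the code as shown

# Single-process case (world_size == 1), so the schedule is
# indices == [0, 125, 250, 375, 500, 625, 750, 875].
t1, w1 = sampler.sample(4, th.device('cpu'))  # first microbatch  -> [0, 125, 250, 375]
t2, w2 = sampler.sample(4, th.device('cpu'))  # second microbatch -> [500, 625, 750, 875]
sampler.reset()  # the full batch of 8 is done; the next pass repeats the same schedule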


@@ -27,6 +27,7 @@ class GaussianDiffusionInjector(Injector):
         self.model_input_keys = opt_get(opt, ['model_input_keys'], [])
         self.extra_model_output_keys = opt_get(opt, ['extra_model_output_keys'], [])
         self.deterministic_timesteps_every = opt_get(opt, ['deterministic_timesteps_every'], 0)
+        self.deterministic_sampler = DeterministicSampler(self.diffusion, opt_get(opt, ['deterministic_sampler_expected_batch_size'], 2048), env)
 
     def forward(self, state):
         gen = self.env['generators'][self.opt['generator']]
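
For context, an illustrative options dict for this injector (not from the commit): only the key names are taken from the opt_get() calls above, the values and surrounding structure are hypothetical.

# Illustrative injector options (hypothetical values; only the key names are
# taken from the opt_get() calls above).
injector_opt = {
    'generator': 'generator',
    'deterministic_timesteps_every': 50,  # every 50th training step uses the deterministic sampler
    'deterministic_sampler_expected_batch_size': 256,  # should match the batch size, per the warning in sample()
}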
@@ -34,15 +35,15 @@ class GaussianDiffusionInjector(Injector):
         with autocast(enabled=self.env['opt']['fp16']):
             if not gen.training or (self.deterministic_timesteps_every != 0 and self.env['step'] % self.deterministic_timesteps_every == 0):
-                sampler = DeterministicSampler(self.diffusion)
+                sampler = self.deterministic_sampler
             else:
                 sampler = self.schedule_sampler
+                self.deterministic_sampler.reset()  # Keep this reset whenever it is not being used, so it is ready to use automatically.
             model_inputs = {k: state[v] for k, v in self.model_input_keys.items()}
             t, weights = sampler.sample(hq.shape[0], hq.device)
             diffusion_outputs = self.diffusion.training_losses(gen, hq, t, model_kwargs=model_inputs)
             if isinstance(sampler, LossAwareSampler):
                 sampler.update_with_local_losses(t, diffusion_outputs['losses'])
             if len(self.extra_model_output_keys) > 0:
                 assert(len(self.extra_model_output_keys) == len(diffusion_outputs['extra_outputs']))
                 out = {k: v for k, v in zip(self.extra_model_output_keys, diffusion_outputs['extra_outputs'])}
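
The sampler-selection rule in forward() above can be restated as a standalone predicate; the helper below is hypothetical and only mirrors the condition from the diff.

# Hypothetical helper restating the sampler-selection condition from forward() above.
def use_deterministic_sampler(training, step, every):
    return (not training) or (every != 0 and step % every == 0)

assert use_deterministic_sampler(training=False, step=7, every=0)       # eval always uses it
assert use_deterministic_sampler(training=True, step=100, every=50)     # periodic deterministic training step
assert not use_deterministic_sampler(training=True, step=101, every=50)
assert not use_deterministic_sampler(training=True, step=101, every=0)  # feature disabled by default (every == 0)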