Make deterministic sampler work with distributed training & microbatches
This commit is contained in:
parent
77c18b53b3
commit
f87e10ffef
|
@ -3,6 +3,7 @@ from abc import ABC, abstractmethod
|
|||
import numpy as np
|
||||
import torch as th
|
||||
import torch.distributed as dist
|
||||
from torch import distributed
|
||||
|
||||
|
||||
def create_named_schedule_sampler(name, diffusion):
|
||||
|
@ -69,18 +70,37 @@ class UniformSampler(ScheduleSampler):
|
|||
|
||||
class DeterministicSampler:
    """
    Returns the same equally spread-out sampling schedule every time it is called.

    The schedule is `sampling_range` timesteps spread evenly over
    [0, diffusion.num_timesteps). In distributed runs the schedule is
    interleaved across processes: rank r takes entries r, r + world_size,
    r + 2 * world_size, ... so that all ranks together cover the full schedule.

    sample() hands out consecutive slices of this rank's share (so the share
    can be consumed across several micro-batches); reset() must be called once
    a full batch is completed to rewind to the start of the schedule.
    """

    def __init__(self, diffusion, sampling_range, env):
        """
        :param diffusion: object exposing `num_timesteps`.
        :param sampling_range: total number of schedule entries across all
                               ranks (normally the full, global batch size).
        :param env: mapping with a 'rank' entry; negative ranks are treated as 0.
        """
        super().__init__()
        self.timesteps = diffusion.num_timesteps
        self.rank = max(env['rank'], 0)
        # is_initialized() only exists on builds with distributed support,
        # so check is_available() first.
        if distributed.is_available() and distributed.is_initialized():
            self.world_size = distributed.get_world_size()
        else:
            self.world_size = 1
        # The sampling range gets spread out across multiple distributed entities.
        # Starting at self.rank (rather than 0) gives every rank a distinct,
        # interleaved slice instead of all ranks repeating the same timesteps.
        rnge = th.arange(self.rank, sampling_range, step=self.world_size).float() / sampling_range
        self.indices = (rnge * self.timesteps).long()
        # Position of the next unread schedule entry. Initialised here so that
        # sample() works even if reset() has not been called yet.
        self.counter = 0

    def sample(self, batch_size, device):
        """
        Iteratively samples across the deterministic range specified by the initialization params.

        :param batch_size: number of timesteps to return; must not exceed this
                           rank's share of the schedule.
        :param device: torch device the returned tensors are moved to.
        :return: (timesteps, weights) where weights is all ones.
        :raises ValueError: if batch_size is larger than this rank's share.
        """
        available = self.indices.shape[0]
        # A batch that exactly covers the share is valid (no micro-batching).
        if batch_size > available:
            raise ValueError(
                f"DeterministicSampler: batch_size ({batch_size}) exceeds the number of "
                f"timesteps assigned to this process ({available})."
            )
        if self.counter + batch_size > available:
            print(f"Diffusion DeterministicSampler; Likely error. {self.counter}, {batch_size}, {self.indices.shape[0]}. Did you forget to set the sampling range to your batch size for the deterministic sampler?")
            self.counter = 0  # Recover by setting to 0.
        indices = self.indices[self.counter:self.counter + batch_size].to(device)
        self.counter = self.counter + batch_size
        weights = th.ones_like(indices).float()
        return indices, weights

    def reset(self):
        """Rewind to the start of the schedule; call after each full batch."""
        self.counter = 0
|
||||
|
||||
|
||||
class LossAwareSampler(ScheduleSampler):
|
||||
def update_with_local_losses(self, local_ts, local_losses):
|
||||
|
|
|
@ -27,6 +27,7 @@ class GaussianDiffusionInjector(Injector):
|
|||
self.model_input_keys = opt_get(opt, ['model_input_keys'], [])
|
||||
self.extra_model_output_keys = opt_get(opt, ['extra_model_output_keys'], [])
|
||||
self.deterministic_timesteps_every = opt_get(opt, ['deterministic_timesteps_every'], 0)
|
||||
self.deterministic_sampler = DeterministicSampler(self.diffusion, opt_get(opt, ['deterministic_sampler_expected_batch_size'], 2048), env)
|
||||
|
||||
def forward(self, state):
|
||||
gen = self.env['generators'][self.opt['generator']]
|
||||
|
@ -34,15 +35,15 @@ class GaussianDiffusionInjector(Injector):
|
|||
|
||||
with autocast(enabled=self.env['opt']['fp16']):
|
||||
if not gen.training or (self.deterministic_timesteps_every != 0 and self.env['step'] % self.deterministic_timesteps_every == 0):
|
||||
sampler = DeterministicSampler(self.diffusion)
|
||||
sampler = self.deterministic_sampler
|
||||
else:
|
||||
sampler = self.schedule_sampler
|
||||
self.deterministic_sampler.reset() # Keep this reset whenever it is not being used, so it is ready to use automatically.
|
||||
model_inputs = {k: state[v] for k, v in self.model_input_keys.items()}
|
||||
t, weights = sampler.sample(hq.shape[0], hq.device)
|
||||
diffusion_outputs = self.diffusion.training_losses(gen, hq, t, model_kwargs=model_inputs)
|
||||
if isinstance(sampler, LossAwareSampler):
|
||||
sampler.update_with_local_losses(t, diffusion_outputs['losses'])
|
||||
|
||||
if len(self.extra_model_output_keys) > 0:
|
||||
assert(len(self.extra_model_output_keys) == len(diffusion_outputs['extra_outputs']))
|
||||
out = {k: v for k, v in zip(self.extra_model_output_keys, diffusion_outputs['extra_outputs'])}
|
||||
|
|
Loading…
Reference in New Issue
Block a user