Make deterministic sampler work with distributed training & microbatches
parent 77c18b53b3
commit f87e10ffef
@@ -3,6 +3,7 @@ from abc import ABC, abstractmethod
 import numpy as np
 import torch as th
 import torch.distributed as dist
+from torch import distributed


 def create_named_schedule_sampler(name, diffusion):
@@ -69,18 +70,37 @@ class UniformSampler(ScheduleSampler):

 class DeterministicSampler:
     """
-    Returns the same equally spread-out sampling schedule every time it is called.
+    Returns the same equally spread-out sampling schedule every time it is called. Automatically handles distributed
+    cases by sharing the load across all entities. reset() must be called once a full batch is completed.
     """
-    def __init__(self, diffusion):
+    def __init__(self, diffusion, sampling_range, env):
         super().__init__()
         self.timesteps = diffusion.num_timesteps
+        self.rank = max(env['rank'], 0)
+        if distributed.is_initialized():
+            self.world_size = distributed.get_world_size()
+        else:
+            self.world_size = 1
+        # The sampling range gets spread out across multiple distributed entities.
+        rnge = th.arange(0, sampling_range, step=self.world_size).float() / sampling_range
+        self.indices = (rnge * self.timesteps).long()

     def sample(self, batch_size, device):
-        rnge = th.arange(0, batch_size, device=device).float() / batch_size
-        indices = (rnge * self.timesteps).long()
+        """
+        Iteratively samples across the deterministic range specified by the initialization params.
+        """
+        assert batch_size < self.indices.shape[0]
+        if self.counter+batch_size > self.indices.shape[0]:
+            print(f"Diffusion DeterministicSampler; Likely error. {self.counter}, {batch_size}, {self.indices.shape[0]}. Did you forget to set the sampling range to your batch size for the deterministic sampler?")
+            self.counter = 0  # Recover by setting to 0.
+        indices = self.indices[self.counter:self.counter+batch_size].to(device)
+        self.counter = self.counter + batch_size
         weights = th.ones_like(indices).float()
         return indices, weights

+    def reset(self):
+        self.counter = 0
+

 class LossAwareSampler(ScheduleSampler):
     def update_with_local_losses(self, local_ts, local_losses):
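As a sketch of what the new sampler does, the standalone snippet below mirrors the index construction and microbatch counter introduced above, showing how the sampling range is strided across distributed ranks and then consumed slice by slice. FakeDiffusion, the world_size value, and the batch sizes are illustrative stand-ins, not part of the commit.

import torch as th

# Illustrative stand-in for the diffusion object; only num_timesteps is used here.
class FakeDiffusion:
    num_timesteps = 1000

sampling_range = 2048   # matches the injector's default 'deterministic_sampler_expected_batch_size'
world_size = 4          # pretend there are 4 distributed ranks

# Same construction as __init__ above: each rank strides across the sampling
# range, so every rank holds sampling_range / world_size evenly spaced timesteps.
rnge = th.arange(0, sampling_range, step=world_size).float() / sampling_range
indices = (rnge * FakeDiffusion.num_timesteps).long()
print(indices.shape)    # torch.Size([512])

# Microbatching: sample() hands out consecutive slices of this list and advances
# a counter; reset() rewinds the counter once the full batch has been consumed.
counter, batch_size = 0, 8
first = indices[counter:counter + batch_size]    # timesteps for microbatch 1
counter += batch_size
second = indices[counter:counter + batch_size]   # timesteps for microbatch 2
print(first.tolist(), second.tolist())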
@@ -27,6 +27,7 @@ class GaussianDiffusionInjector(Injector):
         self.model_input_keys = opt_get(opt, ['model_input_keys'], [])
         self.extra_model_output_keys = opt_get(opt, ['extra_model_output_keys'], [])
         self.deterministic_timesteps_every = opt_get(opt, ['deterministic_timesteps_every'], 0)
+        self.deterministic_sampler = DeterministicSampler(self.diffusion, opt_get(opt, ['deterministic_sampler_expected_batch_size'], 2048), env)

     def forward(self, state):
         gen = self.env['generators'][self.opt['generator']]
@@ -34,15 +35,15 @@ class GaussianDiffusionInjector(Injector):

         with autocast(enabled=self.env['opt']['fp16']):
             if not gen.training or (self.deterministic_timesteps_every != 0 and self.env['step'] % self.deterministic_timesteps_every == 0):
-                sampler = DeterministicSampler(self.diffusion)
+                sampler = self.deterministic_sampler
             else:
                 sampler = self.schedule_sampler
+                self.deterministic_sampler.reset()  # Keep this reset whenever it is not being used, so it is ready to use automatically.
             model_inputs = {k: state[v] for k, v in self.model_input_keys.items()}
             t, weights = sampler.sample(hq.shape[0], hq.device)
             diffusion_outputs = self.diffusion.training_losses(gen, hq, t, model_kwargs=model_inputs)
             if isinstance(sampler, LossAwareSampler):
                 sampler.update_with_local_losses(t, diffusion_outputs['losses'])

             if len(self.extra_model_output_keys) > 0:
                 assert(len(self.extra_model_output_keys) == len(diffusion_outputs['extra_outputs']))
                 out = {k: v for k, v in zip(self.extra_model_output_keys, diffusion_outputs['extra_outputs'])}
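The injector above reads two options related to the deterministic sampler. The snippet below is a hedged illustration of how they interact, using key names taken from this diff but made-up values; the project's actual configuration format may differ.

# Key names below come from this diff; the values are examples only.
opt = {
    'deterministic_timesteps_every': 50,                 # every 50th training step uses the deterministic sampler
    'deterministic_sampler_expected_batch_size': 2048,   # should match the full (pre-microbatch) batch size
}

# With these values, steps 0, 50, 100, ... take the deterministic path; every
# other step falls back to the regular schedule sampler and calls reset() so the
# deterministic sampler's counter starts from 0 the next time it is used.
every = opt['deterministic_timesteps_every']
deterministic_steps = [s for s in range(200) if every != 0 and s % every == 0]
print(deterministic_steps)   # [0, 50, 100, 150]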