forked from mrq/DL-Art-School
Add a deterministic timestep sampler, with provisions to employ it every n steps
parent f490eaeba7
commit 2d1cb83c1d
codes/models/diffusion/resample.py

@@ -67,6 +67,21 @@ class UniformSampler(ScheduleSampler):
         return self._weights
 
 
+class DeterministicSampler:
+    """
+    Returns the same equally spread-out sampling schedule every time it is called.
+    """
+    def __init__(self, diffusion):
+        super().__init__()
+        self.timesteps = diffusion.num_timesteps
+
+    def sample(self, batch_size, device):
+        rnge = th.arange(0, batch_size, device=device).float() / batch_size
+        indices = (rnge * self.timesteps).long()
+        weights = th.ones_like(indices).float()
+        return indices, weights
+
+
 class LossAwareSampler(ScheduleSampler):
     def update_with_local_losses(self, local_ts, local_losses):
         """
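To make the schedule concrete, here is a standalone sketch of what sample() returns (plain PyTorch; th in the diff is the module's alias for torch, and the timestep count below is an assumed example value):

    import torch

    # Re-implementation of DeterministicSampler.sample, for illustration only.
    num_timesteps = 1000  # stand-in for diffusion.num_timesteps
    batch_size = 8

    rnge = torch.arange(0, batch_size).float() / batch_size
    indices = (rnge * num_timesteps).long()
    weights = torch.ones_like(indices).float()

    print(indices)  # tensor([  0, 125, 250, 375, 500, 625, 750, 875])
    print(weights)  # tensor([1., 1., 1., 1., 1., 1., 1., 1.])

Every call yields these same indices, so an eval pass always measures loss at the same timesteps. Note that the spread depends on the batch size, and the final timestep (num_timesteps - 1) is never selected.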
The second changed file wires the new sampler into the GaussianDiffusionInjector:

@@ -5,7 +5,7 @@ import torch
 from torch.cuda.amp import autocast
 
 from models.diffusion.gaussian_diffusion import GaussianDiffusion, get_named_beta_schedule
-from models.diffusion.resample import create_named_schedule_sampler, LossAwareSampler
+from models.diffusion.resample import create_named_schedule_sampler, LossAwareSampler, DeterministicSampler
 from models.diffusion.respace import space_timesteps, SpacedDiffusion
 from trainer.inject import Injector
 from utils.util import opt_get
@@ -26,22 +26,22 @@ class GaussianDiffusionInjector(Injector):
         self.schedule_sampler = create_named_schedule_sampler(opt['sampler_type'], self.diffusion)
         self.model_input_keys = opt_get(opt, ['model_input_keys'], [])
         self.extra_model_output_keys = opt_get(opt, ['extra_model_output_keys'], [])
+        self.deterministic_timesteps_every = opt_get(opt, ['deterministic_timesteps_every'], 0)
 
     def forward(self, state):
         gen = self.env['generators'][self.opt['generator']]
         hq = state[self.input]
 
-        # In eval mode, seed torch with a deterministic seed for reproducibility.
-        if not gen.training:
-            torch.manual_seed(0)
-            random.seed(0)
-
         with autocast(enabled=self.env['opt']['fp16']):
+            if not gen.training or (self.deterministic_timesteps_every != 0 and self.env['step'] % self.deterministic_timesteps_every == 0):
+                sampler = DeterministicSampler(self.diffusion)
+            else:
+                sampler = self.schedule_sampler
             model_inputs = {k: state[v] for k, v in self.model_input_keys.items()}
-            t, weights = self.schedule_sampler.sample(hq.shape[0], hq.device)
+            t, weights = sampler.sample(hq.shape[0], hq.device)
             diffusion_outputs = self.diffusion.training_losses(gen, hq, t, model_kwargs=model_inputs)
-            if isinstance(self.schedule_sampler, LossAwareSampler):
-                self.schedule_sampler.update_with_local_losses(t, diffusion_outputs['losses'])
+            if isinstance(sampler, LossAwareSampler):
+                sampler.update_with_local_losses(t, diffusion_outputs['losses'])
 
             if len(self.extra_model_output_keys) > 0:
                 assert(len(self.extra_model_output_keys) == len(diffusion_outputs['extra_outputs']))
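The selection rule above means evaluation always uses the fixed schedule, while training uses it only on every n-th step, and never when the option is left at its default of 0. Because the isinstance check now tests the active sampler, loss-aware statistics are also not updated on deterministic steps. A minimal sketch of the gating, with the option and step semantics taken from the diff and the function scaffolding purely illustrative:

    def use_deterministic(training: bool, step: int, every: int) -> bool:
        # Mirrors the condition in forward(): always deterministic in eval;
        # periodically deterministic in training when `every` is non-zero.
        return (not training) or (every != 0 and step % every == 0)

    # e.g. with deterministic_timesteps_every = 50:
    assert use_deterministic(training=False, step=17,  every=0)   # eval: always
    assert use_deterministic(training=True,  step=100, every=50)  # every 50th step
    assert not use_deterministic(training=True, step=101, every=50)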
@@ -52,11 +52,6 @@ class GaussianDiffusionInjector(Injector):
                 self.output_variational_bounds_key: diffusion_outputs['vb'],
                 self.output_x_start_key: diffusion_outputs['x_start_predicted']})
 
-        # Absolutely critical to undo the above seed.
-        if not gen.training:
-            torch.manual_seed(int(time.time()))
-            random.seed(int(time.time()))
-
         return out
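The two deleted blocks were the old reproducibility mechanism: eval runs pinned the global torch and python RNGs to seed 0, then re-seeded them from the wall clock afterwards, an approximate "undo" that cannot restore the prior RNG state and that perturbs every other consumer of those generators. The deterministic sampler makes this unnecessary for timestep selection because it draws no random numbers at all; a quick check of that property (example values assumed):

    import torch

    def det_schedule(batch_size: int, num_timesteps: int) -> torch.Tensor:
        # Same arithmetic as DeterministicSampler.sample; no RNG is consumed.
        return (torch.arange(0, batch_size).float() / batch_size * num_timesteps).long()

    torch.manual_seed(1234)
    s1 = det_schedule(4, 1000)
    torch.manual_seed(5678)      # a completely different global seed...
    s2 = det_schedule(4, 1000)
    assert torch.equal(s1, s2)   # ...yields the identical timestep schedule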