From 96e90e7047a2f56e75fad9392424b1963438f801 Mon Sep 17 00:00:00 2001
From: James Betker
Date: Mon, 26 Jul 2021 16:27:31 -0600
Subject: [PATCH] Add support for a gaussian-diffusion-based wave tacotron

---
 codes/data/audio/nv_tacotron_dataset.py       |  22 +-
 codes/models/diffusion/gaussian_diffusion.py  |  75 ++++++
 codes/models/diffusion/respace.py             |  27 +-
 codes/models/diffusion/unet_diffusion.py      |   9 +-
 codes/models/tacotron2/tacotron2.py           |   9 +-
 codes/models/tacotron2/wave_tacotron.py       | 230 ++++++++++++++++++
 codes/train.py                                |   2 +-
 codes/trainer/injectors/base_injectors.py     |   9 +
 .../injectors/gaussian_diffusion_injector.py  |  33 +++
 codes/trainer/losses.py                       |  18 +-
 10 files changed, 414 insertions(+), 20 deletions(-)
 create mode 100644 codes/models/tacotron2/wave_tacotron.py

diff --git a/codes/data/audio/nv_tacotron_dataset.py b/codes/data/audio/nv_tacotron_dataset.py
index dbd010e7..b4060f2e 100644
--- a/codes/data/audio/nv_tacotron_dataset.py
+++ b/codes/data/audio/nv_tacotron_dataset.py
@@ -23,6 +23,8 @@ class TextMelLoader(torch.utils.data.Dataset):
         self.max_wav_value = hparams.max_wav_value
         self.sampling_rate = hparams.sampling_rate
         self.load_mel_from_disk = hparams.load_mel_from_disk
+        self.return_wavs = hparams.return_wavs
+        assert not (self.load_mel_from_disk and self.return_wavs)
         self.stft = layers.TacotronSTFT(
             hparams.filter_length, hparams.hop_length, hparams.win_length,
             hparams.n_mel_channels, hparams.sampling_rate, hparams.mel_fmin,
@@ -47,8 +49,11 @@ class TextMelLoader(torch.utils.data.Dataset):
             audio_norm = audio / self.max_wav_value
             audio_norm = audio_norm.unsqueeze(0)
             audio_norm = torch.autograd.Variable(audio_norm, requires_grad=False)
-            melspec = self.stft.mel_spectrogram(audio_norm)
-            melspec = torch.squeeze(melspec, 0)
+            if self.return_wavs:
+                melspec = audio_norm
+            else:
+                melspec = self.stft.mel_spectrogram(audio_norm)
+                melspec = torch.squeeze(melspec, 0)
         else:
             melspec = torch.from_numpy(np.load(filename))
             assert melspec.size(0) == self.stft.n_mel_channels, (
@@ -124,13 +129,18 @@ if __name__ == '__main__':
     params = {
         'mode': 'nv_tacotron',
         'path': 'E:\\audio\\LJSpeech-1.1\\ljs_audio_text_train_filelist.txt',
-
+        'phase': 'train',
+        'n_workers': 0,
+        'batch_size': 2,
+        'return_wavs': True,
     }
-    from data import create_dataset
-    ds = create_dataset(params)
+    from data import create_dataset, create_dataloader
+
+    ds, c = create_dataset(params, return_collate=True)
+    dl = create_dataloader(ds, params, collate_fn=c)
     i = 0
     m = []
-    for b in ds:
+    for b in dl:
         m.append(b)
         i += 1
         if i > 9999:
diff --git a/codes/models/diffusion/gaussian_diffusion.py b/codes/models/diffusion/gaussian_diffusion.py
index 2b57f34e..0f2cb8c5 100644
--- a/codes/models/diffusion/gaussian_diffusion.py
+++ b/codes/models/diffusion/gaussian_diffusion.py
@@ -817,6 +817,81 @@ class GaussianDiffusion:
 
         return terms
 
+    def autoregressive_training_losses(self, model, x_start, t, model_output_keys, gd_out_key, model_kwargs=None, noise=None):
+        """
+        Compute training losses for a single timestep.
+
+        :param model: the model to evaluate loss on.
+        :param x_start: the [N x C x ...] tensor of inputs.
+        :param t: a batch of timestep indices.
+        :param model_kwargs: if not None, a dict of extra keyword arguments to
+            pass to the model. This can be used for conditioning.
+        :param noise: if specified, the specific Gaussian noise to try to remove.
+        :return: a dict with the key "loss" containing a tensor of shape [N].
+                 Some mean or variance settings may also have other keys.
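+                 Each name in model_output_keys is additionally populated with the
+                 corresponding raw model output, and "x_start_predicted" holds the
+                 predicted x_0.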
+ """ + if model_kwargs is None: + model_kwargs = {} + if noise is None: + noise = th.randn_like(x_start) + x_t = self.q_sample(x_start, t, noise=noise) + x_tn1 = self.q_sample(x_start, t-1, noise=noise) + + terms = {} + + if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL: + assert False # not currently supported for this type of diffusion. + elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE: + model_outputs = model(x_t, x_tn1, self._scale_timesteps(t), **model_kwargs) + terms.update({k: o for k, o in zip(model_output_keys, model_outputs)}) + model_output = terms[gd_out_key] + if self.model_var_type in [ + ModelVarType.LEARNED, + ModelVarType.LEARNED_RANGE, + ]: + B, C = x_t.shape[:2] + assert model_output.shape == (B, C, 2, *x_t.shape[2:]) + model_output, model_var_values = model_output[:, :, 0], model_output[:, :, 1] + # Learn the variance using the variational bound, but don't let + # it affect our mean prediction. + frozen_out = th.cat([model_output.detach(), model_var_values], dim=1) + terms["vb"] = self._vb_terms_bpd( + model=lambda *args, r=frozen_out: r, + x_start=x_start, + x_t=x_t, + t=t, + clip_denoised=False, + )["output"] + if self.loss_type == LossType.RESCALED_MSE: + # Divide by 1000 for equivalence with initial implementation. + # Without a factor of 1/1000, the VB term hurts the MSE term. + terms["vb"] *= self.num_timesteps / 1000.0 + + if self.model_mean_type == ModelMeanType.PREVIOUS_X: + target = self.q_posterior_mean_variance( + x_start=x_start, x_t=x_t, t=t + )[0] + x_start_pred = torch.zeros(x_start) # Not supported. + elif self.model_mean_type == ModelMeanType.START_X: + target = x_start + x_start_pred = model_output + elif self.model_mean_type == ModelMeanType.EPSILON: + target = noise + x_start_pred = self._predict_xstart_from_eps(x_t, t, model_output) + else: + raise NotImplementedError(self.model_mean_type) + assert model_output.shape == target.shape == x_start.shape + terms["mse"] = mean_flat((target - model_output) ** 2) + terms["x_start_predicted"] = x_start_pred + if "vb" in terms: + terms["loss"] = terms["mse"] + terms["vb"] + else: + terms["loss"] = terms["mse"] + else: + raise NotImplementedError(self.loss_type) + + return terms + def _prior_bpd(self, x_start): """ Get the prior KL term for the variational lower-bound, measured in diff --git a/codes/models/diffusion/respace.py b/codes/models/diffusion/respace.py index b568817e..4fad2f8b 100644 --- a/codes/models/diffusion/respace.py +++ b/codes/models/diffusion/respace.py @@ -95,16 +95,22 @@ class SpacedDiffusion(GaussianDiffusion): ): # pylint: disable=signature-differs return super().training_losses(self._wrap_model(model), *args, **kwargs) + def autoregressive_training_losses( + self, model, *args, **kwargs + ): # pylint: disable=signature-differs + return super().autoregressive_training_losses(self._wrap_model(model, True), *args, **kwargs) + def condition_mean(self, cond_fn, *args, **kwargs): return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs) def condition_score(self, cond_fn, *args, **kwargs): return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs) - def _wrap_model(self, model): - if isinstance(model, _WrappedModel): + def _wrap_model(self, model, autoregressive=False): + if isinstance(model, _WrappedModel) or isinstance(model, _WrappedAutoregressiveModel): return model - return _WrappedModel( + mod = _WrappedAutoregressiveModel if autoregressive else _WrappedModel + return mod( model, self.timestep_map, 
         )
 
@@ -126,3 +132,18 @@ class _WrappedModel:
         if self.rescale_timesteps:
             new_ts = new_ts.float() * (1000.0 / self.original_num_steps)
         return self.model(x, new_ts, **kwargs)
+
+
+class _WrappedAutoregressiveModel:
+    def __init__(self, model, timestep_map, rescale_timesteps, original_num_steps):
+        self.model = model
+        self.timestep_map = timestep_map
+        self.rescale_timesteps = rescale_timesteps
+        self.original_num_steps = original_num_steps
+
+    def __call__(self, x, x0, ts, **kwargs):
+        map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype)
+        new_ts = map_tensor[ts]
+        if self.rescale_timesteps:
+            new_ts = new_ts.float() * (1000.0 / self.original_num_steps)
+        return self.model(x, x0, new_ts, **kwargs)
\ No newline at end of file
diff --git a/codes/models/diffusion/unet_diffusion.py b/codes/models/diffusion/unet_diffusion.py
index cce553e9..55c5dca8 100644
--- a/codes/models/diffusion/unet_diffusion.py
+++ b/codes/models/diffusion/unet_diffusion.py
@@ -441,6 +441,7 @@ class UNetModel(nn.Module):
         use_scale_shift_norm=False,
         resblock_updown=False,
         use_new_attention_order=False,
+        use_raw_y_as_embedding=False,
     ):
         super().__init__()
 
@@ -471,6 +472,8 @@ class UNetModel(nn.Module):
 
         if self.num_classes is not None:
             self.label_emb = nn.Embedding(num_classes, time_embed_dim)
+        self.use_raw_y_as_embedding = use_raw_y_as_embedding
+        assert (self.num_classes is not None) != use_raw_y_as_embedding  # These are mutually-exclusive.
 
         self.input_blocks = nn.ModuleList(
             [
@@ -630,16 +633,14 @@ class UNetModel(nn.Module):
         :param y: an [N] Tensor of labels, if class-conditional.
         :return: an [N x C x ...] Tensor of outputs.
         """
-        assert (y is not None) == (
-            self.num_classes is not None
-        ), "must specify y if and only if the model is class-conditional"
-
         hs = []
         emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
 
         if self.num_classes is not None:
             assert y.shape == (x.shape[0],)
             emb = emb + self.label_emb(y)
+        if self.use_raw_y_as_embedding:
+            emb = emb + y
 
         h = x.type(self.dtype)
         for module in self.input_blocks:
diff --git a/codes/models/tacotron2/tacotron2.py b/codes/models/tacotron2/tacotron2.py
index bed404b2..2ea8b06e 100644
--- a/codes/models/tacotron2/tacotron2.py
+++ b/codes/models/tacotron2/tacotron2.py
@@ -32,7 +32,7 @@ class LocationLayer(nn.Module):
 
 class Attention(nn.Module):
     def __init__(self, attention_rnn_dim, embedding_dim, attention_dim,
-                 attention_location_n_filters, attention_location_kernel_size):
+                 attention_location_n_filters=32, attention_location_kernel_size=31):
         super(Attention, self).__init__()
         self.query_layer = LinearNorm(attention_rnn_dim, attention_dim,
                                       bias=False, w_init_gain='tanh')
@@ -528,6 +528,9 @@ def register_nv_tacotron2(opt_net, opt):
 
 if __name__ == '__main__':
     tron = register_nv_tacotron2({}, {})
-    inputs = torch.randint(high=24, size=(1,12)), torch.tensor([12]), torch.randn((1,80,749)), 800, torch.tensor([749])
-    out = tron(inputs)
+    inputs = torch.randint(high=24, size=(1,12)), \
+             torch.tensor([12]), \
+             torch.randn((1,80,749)), \
+             torch.tensor([749])
+    out = tron(*inputs)
     print(out)
\ No newline at end of file
diff --git a/codes/models/tacotron2/wave_tacotron.py b/codes/models/tacotron2/wave_tacotron.py
new file mode 100644
index 00000000..e4d802eb
--- /dev/null
+++ b/codes/models/tacotron2/wave_tacotron.py
@@ -0,0 +1,230 @@
+from math import sqrt
+
+import torch
+from munch import munchify
+from torch.autograd import Variable
+from torch import nn
+from torch.nn import functional as F
+
+from models.diffusion.unet_diffusion import UNetModel
+from models.tacotron2.layers import ConvNorm, LinearNorm
+from models.tacotron2.hparams import create_hparams
+from models.tacotron2.tacotron2 import Prenet, Attention, Encoder
+from trainer.networks import register_model
+from models.tacotron2.taco_utils import get_mask_from_lengths
+from utils.util import opt_get, checkpoint
+
+
+class WavDecoder(nn.Module):
+    def __init__(self, dec_channels, K_ms=40, sample_rate=24000, dropout_probability=.1):
+        super().__init__()
+        self.dec_channels = dec_channels
+        self.K = int(sample_rate * (K_ms/1000))  # 960 with the defaults
+        self.clarifier = UNetModel(image_size=self.K,
+                                   in_channels=1,
+                                   model_channels=dec_channels // 4,  # Required so the embedding produced by the decoder can be fed into the unet model.
+                                   out_channels=2,  # 2 channels: eps_pred and variance_pred
+                                   num_res_blocks=2,
+                                   attention_resolutions=(16,32),
+                                   dims=1,
+                                   dropout=.1,
+                                   channel_mult=(1,1,1,2,4,8),
+                                   use_raw_y_as_embedding=True)
+        assert self.K % 64 == 0  # Otherwise the UNetModel breaks.
+        self.pre_rnn = Prenet(self.K, [dec_channels, dec_channels])
+        self.attention_rnn = nn.LSTMCell(dec_channels*2, dec_channels)
+        self.attention_layer = Attention(dec_channels, dec_channels, dec_channels)
+        self.decoder_rnn = nn.LSTMCell(dec_channels*2, dec_channels, 1)
+        self.linear_projection = LinearNorm(dec_channels*2, self.dec_channels)
+        self.gate_layer = LinearNorm(self.dec_channels*2, 1, bias=True, w_init_gain='sigmoid')
+        self.dropout_probability = dropout_probability
+
+    def chunk_wav(self, wav):
+        wavs = list(torch.split(wav, self.K, dim=-1))
+        # Pad the last chunk as needed.
+        padding_needed = self.K - wavs[-1].shape[-1]
+        if padding_needed > 0:
+            wavs[-1] = F.pad(wavs[-1], (0,padding_needed))
+
+        wavs = torch.stack(wavs, dim=1)  # wavs.shape = (b,s,K) where s=decoder sequence length
+        return wavs, padding_needed
+
+    def prepare_decoder_inputs(self, inp):
+        # inp.shape = (b,s,K) chunked waveform.
+        b,s,K = inp.shape
+        first_frame = torch.zeros(b,1,K).to(inp.device)
+        x = torch.cat([first_frame, inp[:,:-1]], dim=1)  # It is now aligned for teacher forcing.
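+        # i.e. decoder step i is conditioned on chunk i-1 of the input; step 0 sees an all-zero chunk.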
+        return x
+
+    def initialize_decoder_states(self, memory, mask):
+        """ Initializes attention rnn states, decoder rnn states, attention
+        weights, attention cumulative weights, attention context, stores memory
+        and stores processed memory
+        PARAMS
+        ------
+        memory: Encoder outputs
+        mask: Mask for padded data if training, expects None for inference
+        """
+        B = memory.size(0)
+        MAX_TIME = memory.size(1)
+
+        self.attention_hidden = Variable(memory.data.new(B, self.dec_channels).zero_())
+        self.attention_cell = Variable(memory.data.new(B, self.dec_channels).zero_())
+
+        self.decoder_hidden = Variable(memory.data.new(B, self.dec_channels).zero_())
+        self.decoder_cell = Variable(memory.data.new(B, self.dec_channels).zero_())
+
+        self.attention_weights = Variable(memory.data.new(B, MAX_TIME).zero_())
+        self.attention_weights_cum = Variable(memory.data.new(B, MAX_TIME).zero_())
+        self.attention_context = Variable(memory.data.new(B, self.dec_channels).zero_())
+
+        self.memory = memory
+        self.processed_memory = checkpoint(self.attention_layer.memory_layer, memory)
+        self.mask = mask
+
+    def teardown_states(self):
+        self.attention_hidden = None
+        self.attention_cell = None
+        self.decoder_hidden = None
+        self.decoder_cell = None
+        self.attention_weights = None
+        self.attention_weights_cum = None
+        self.attention_context = None
+        self.memory = None
+        self.processed_memory = None
+
+    def produce_context(self, decoder_input):
+        """ Produces a context and a stop token prediction using the built-in RNN.
+        PARAMS
+        ------
+        decoder_input: prior diffusion step that has been resolved.
+
+        RETURNS
+        -------
+        decoder_output: decoder context, used as the conditioning embedding for the diffusion unet
+        gate_output: gate output energies
+        attention_weights: attention weights over the encoder outputs
+        """
+        cell_input = torch.cat((decoder_input, self.attention_context), -1)
+        self.attention_hidden, self.attention_cell = self.attention_rnn(cell_input, (self.attention_hidden, self.attention_cell))
+        self.attention_hidden = F.dropout(self.attention_hidden, self.dropout_probability, self.training)
+
+        attention_weights_cat = torch.cat((self.attention_weights.unsqueeze(1), self.attention_weights_cum.unsqueeze(1)), dim=1)
+        self.attention_context, self.attention_weights = checkpoint(self.attention_layer, self.attention_hidden, self.memory,
+                                                                    self.processed_memory, attention_weights_cat, self.mask)
+
+        self.attention_weights_cum += self.attention_weights
+        decoder_input = torch.cat((self.attention_hidden, self.attention_context), -1)
+        self.decoder_hidden, self.decoder_cell = self.decoder_rnn(decoder_input, (self.decoder_hidden, self.decoder_cell))
+        self.decoder_hidden = F.dropout(self.decoder_hidden, self.dropout_probability, self.training)
+
+        decoder_hidden_attention_context = torch.cat((self.decoder_hidden, self.attention_context), dim=1)
+        decoder_output = checkpoint(self.linear_projection, decoder_hidden_attention_context)
+
+        gate_prediction = self.gate_layer(decoder_hidden_attention_context)
+        return decoder_output, gate_prediction, self.attention_weights
+
+    def recombine(self, diffusion_eps, gate_outputs, alignments, padding_added):
+        # (T_out, B) -> (B, T_out)
+        alignments = torch.stack(alignments, dim=1).repeat(1, self.K, 1)
+        # (T_out, B) -> (B, T_out)
+        gate_outputs = torch.stack(gate_outputs, dim=1).repeat(1, self.K)
+
+        b,s,_,K = diffusion_eps.shape
+        # (B, S, 2, K) -> (B, 2, S*K)
+        diffusion_eps = diffusion_eps.permute(0,2,1,3).reshape(b, 2, s*K)
+
+        return diffusion_eps[:,:,:-padding_added], gate_outputs[:,:-padding_added], alignments[:,:-padding_added]
+
+    def forward(self, wav, wav_corrected, timesteps, text_enc, memory_lengths):
+        '''
+        Performs a training forward pass with the given data.
+        :param wav: (b,n) diffused waveform tensor on the interval [-1,1]
+        :param wav_corrected: (b,n) waveform tensor that has had one step of diffusion correction applied to it
+        :param text_enc: (b,e) embedding post-encoder with e=self.dec_channels
+        '''
+
+        # Start by splitting up the provided waveforms into discrete segments.
+        wavs, padding_added = self.chunk_wav(wav)
+        wavs_corrected, _ = self.chunk_wav(wav_corrected)
+        wavs_corrected = self.prepare_decoder_inputs(wavs_corrected)
+        wavs_corrected = checkpoint(self.pre_rnn, wavs_corrected)
+
+        self.initialize_decoder_states(text_enc, mask=~get_mask_from_lengths(memory_lengths))
+        decoder_contexts, gate_outputs, alignments = [], [], []
+        while len(decoder_contexts) < wavs_corrected.size(1):
+            decoder_input = wavs_corrected[:, len(decoder_contexts)]
+            dec_context, gate_output, attention_weights = self.produce_context(decoder_input)
+            decoder_contexts += [dec_context.squeeze(1)]
+            gate_outputs += [gate_output.squeeze(1)]
+            alignments += [attention_weights]
+        self.teardown_states()
+
+        # diffusion_emb and wavs need their sequence and batch dimensions combined; wavs also needs a channel dimension.
+        diffusion_emb = torch.stack(decoder_contexts, dim=1)
+        b,s,c = diffusion_emb.shape
+        diffusion_emb = diffusion_emb.reshape(b*s,c)
+        wavs = wavs.reshape(b*s,1,self.K)
+        diffusion_eps = self.clarifier(wavs, timesteps.repeat(s), diffusion_emb).reshape(b,s,2,self.K)
+        # Recombine diffusion outputs across the sequence into a single prediction.
+        diffusion_eps, gate_outputs, alignments = self.recombine(diffusion_eps, gate_outputs, alignments, padding_added)
+        return diffusion_eps, gate_outputs, alignments
+
+
+class WaveTacotron2(nn.Module):
+    def __init__(self, hparams):
+        super().__init__()
+        self.mask_padding = hparams.mask_padding
+        self.fp16_run = hparams.fp16_run
+        self.n_mel_channels = hparams.n_mel_channels
+        self.n_frames_per_step = hparams.n_frames_per_step
+        self.embedding = nn.Embedding(
+            hparams.n_symbols, hparams.symbols_embedding_dim)
+        std = sqrt(2.0 / (hparams.n_symbols + hparams.symbols_embedding_dim))
+        val = sqrt(3.0) * std  # uniform bounds for std
+        self.embedding.weight.data.uniform_(-val, val)
+        self.encoder = Encoder(hparams)
+        self.decoder = WavDecoder(hparams.encoder_embedding_dim)
+
+    def parse_output(self, outputs, output_lengths=None):
+        if self.mask_padding and output_lengths is not None:
+            mask_fill = outputs[0].shape[-1]
+            mask = ~get_mask_from_lengths(output_lengths, mask_fill)
+            mask = mask.expand(mask.size(0), 2, mask.size(1))
+
+            outputs[0].data.masked_fill_(mask, 0.0)
+            outputs[0] = outputs[0].unsqueeze(1)  # Re-add channel dimension.
+            outputs[1].data.masked_fill_(mask[:,0], 1e3)  # gate energies
+
+        return outputs
+
+    def forward(self, wavs_diffused, wavs_corrected, timesteps, text_inputs, text_lengths, output_lengths):
+        # Squeeze the channel dimension out of the input wavs - we only handle single-channel audio here.
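+        # (b, 1, n) -> (b, n)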
+        wavs_diffused = wavs_diffused.squeeze(dim=1)
+        wavs_corrected = wavs_corrected.squeeze(dim=1)
+
+        text_lengths, output_lengths = text_lengths.data, output_lengths.data
+        embedded_inputs = self.embedding(text_inputs).transpose(1, 2)
+        encoder_outputs = checkpoint(self.encoder, embedded_inputs, text_lengths)
+        eps_pred, gate_outputs, alignments = self.decoder(
+            wavs_diffused, wavs_corrected, timesteps, encoder_outputs, memory_lengths=text_lengths)
+
+        return self.parse_output([eps_pred, gate_outputs, alignments], output_lengths)
+
+
+@register_model
+def register_diffusion_wavetron(opt_net, opt):
+    hparams = create_hparams()
+    hparams.update(opt_net)
+    hparams = munchify(hparams)
+    return WaveTacotron2(hparams)
+
+
+if __name__ == '__main__':
+    tron = register_diffusion_wavetron({}, {})
+    out = tron(wavs_diffused=torch.randn(2, 1, 22000),
+               wavs_corrected=torch.randn(2, 1, 22000),
+               timesteps=torch.LongTensor([555, 543]),
+               text_inputs=torch.randint(high=24, size=(2,12)),
+               text_lengths=torch.tensor([12, 12]),
+               output_lengths=torch.tensor([21995]))
+    print([o.shape for o in out])
\ No newline at end of file
diff --git a/codes/train.py b/codes/train.py
index 35e57a2d..b22bf7f5 100644
--- a/codes/train.py
+++ b/codes/train.py
@@ -300,7 +300,7 @@ class Trainer:
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_vqvae_xform_audio_lj.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_wave_tacotron_diffusion_lj.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
     parser.add_argument('--local_rank', type=int, default=0)
    args = parser.parse_args()
diff --git a/codes/trainer/injectors/base_injectors.py b/codes/trainer/injectors/base_injectors.py
index db191866..12fc126b 100644
--- a/codes/trainer/injectors/base_injectors.py
+++ b/codes/trainer/injectors/base_injectors.py
@@ -10,6 +10,15 @@ from utils.util import opt_get
 from utils.weight_scheduler import get_scheduler_for_opt
 
 
+class SqueezeInjector(Injector):
+    def __init__(self, opt, env):
+        super().__init__(opt, env)
+        self.dim = opt['dim']
+
+    def forward(self, state):
+        return {self.output: state[self.input].squeeze(dim=self.dim)}
+
+
 # Uses a generator to synthesize an image from [in] and injects the results into [out]
 # Note that results are *not* detached.
 class GeneratorInjector(Injector):
diff --git a/codes/trainer/injectors/gaussian_diffusion_injector.py b/codes/trainer/injectors/gaussian_diffusion_injector.py
index 69efda92..2d4969d9 100644
--- a/codes/trainer/injectors/gaussian_diffusion_injector.py
+++ b/codes/trainer/injectors/gaussian_diffusion_injector.py
@@ -35,6 +35,39 @@ class GaussianDiffusionInjector(Injector):
                 self.output_x_start_key: diffusion_outputs['x_start_predicted']}
 
 
+class AutoregressiveGaussianDiffusionInjector(Injector):
+    def __init__(self, opt, env):
+        super().__init__(opt, env)
+        self.generator = opt['generator']
+        self.output_variational_bounds_key = opt['out_key_vb_loss']
+        self.output_x_start_key = opt['out_key_x_start']
+        opt['diffusion_args']['betas'] = get_named_beta_schedule(**opt['beta_schedule'])
+        opt['diffusion_args']['use_timesteps'] = space_timesteps(opt['beta_schedule']['num_diffusion_timesteps'],
+                                                                 [opt['beta_schedule']['num_diffusion_timesteps']])
+        self.diffusion = SpacedDiffusion(**opt['diffusion_args'])
+        self.schedule_sampler = create_named_schedule_sampler(opt['sampler_type'], self.diffusion)
+        self.model_input_keys = opt_get(opt, ['model_input_keys'], [])
+        self.model_output_keys = opt['model_output_keys']
+        self.model_eps_pred_key = opt['prediction_key']
+
+    def forward(self, state):
+        gen = self.env['generators'][self.opt['generator']]
+        hq = state[self.input]
+        model_inputs = {k: state[v] for k, v in self.model_input_keys.items()}
+        t, weights = self.schedule_sampler.sample(hq.shape[0], hq.device)
+        diffusion_outputs = self.diffusion.autoregressive_training_losses(gen, hq, t, self.model_output_keys,
+                                                                          self.model_eps_pred_key,
+                                                                          model_kwargs=model_inputs)
+        if isinstance(self.schedule_sampler, LossAwareSampler):
+            self.schedule_sampler.update_with_local_losses(t, diffusion_outputs['losses'])
+        outputs = {k: diffusion_outputs[k] for k in self.model_output_keys}
+        outputs.update({self.output: diffusion_outputs['mse'],
+                        self.output_variational_bounds_key: diffusion_outputs['vb'],
+                        self.output_x_start_key: diffusion_outputs['x_start_predicted']})
+        return outputs
+
+
 # Performs inference using a network trained to predict a reverse diffusion process, which nets a image.
 class GaussianDiffusionInferenceInjector(Injector):
     def __init__(self, opt, env):
diff --git a/codes/trainer/losses.py b/codes/trainer/losses.py
index 7bd34136..dc434438 100644
--- a/codes/trainer/losses.py
+++ b/codes/trainer/losses.py
@@ -121,7 +121,13 @@ class CrossEntropy(ConfigurableLoss):
     def __init__(self, opt, env):
         super().__init__(opt, env)
         self.opt = opt
-        self.ce = nn.CrossEntropyLoss()
+        self.subtype = opt_get(opt, ['subtype'], 'ce')
+        if self.subtype == 'ce':
+            self.ce = nn.CrossEntropyLoss()
+        elif self.subtype == 'bce':
+            self.ce = nn.BCEWithLogitsLoss()
+        else:
+            assert False
 
     def forward(self, _, state):
         logits = state[self.opt['logits']]
@@ -135,8 +141,14 @@ class CrossEntropy(ConfigurableLoss):
             logits = logits * mask
         if self.opt['swap_channels']:
             logits = logits.permute(0,2,3,1).contiguous()
-        assert labels.max()+1 <= logits.shape[-1]
-        return self.ce(logits.view(-1, logits.size(-1)), labels.view(-1))
+        if self.subtype == 'bce':
+            logits = logits.reshape(-1, 1)
+            labels = labels.reshape(-1, 1)
+        else:
+            logits = logits.view(-1, logits.size(-1))
+            labels = labels.view(-1)
+            assert labels.max()+1 <= logits.shape[-1]
+        return self.ce(logits, labels)
 
 
 class PixLoss(ConfigurableLoss):
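
For reference, below is roughly the option block the new AutoregressiveGaussianDiffusionInjector reads, written as a Python dict (the project normally supplies this via the training YAML, e.g. the train_wave_tacotron_diffusion_lj.yml referenced in train.py above, which is not included in this commit). The key names come straight from the injector code in this patch; the concrete values, the trainer state-key names and the injector registry name are illustrative assumptions only.

# Illustrative sketch only -- values, state keys and the 'type' name are assumptions, not from the commit.
injector_opt = {
    'type': 'autoregressive_gaussian_diffusion',   # assumed registry name for the new injector
    'in': 'wavs_diffused',                         # state key holding x_start (the target waveform)
    'out': 'mse_loss',                             # receives diffusion_outputs['mse']
    'generator': 'generator',                      # name of the registered diffusion_wavetron network
    'out_key_vb_loss': 'vb_loss',                  # receives diffusion_outputs['vb']
    'out_key_x_start': 'x_start_pred',             # receives diffusion_outputs['x_start_predicted']
    'sampler_type': 'uniform',
    'beta_schedule': {'schedule_name': 'linear', 'num_diffusion_timesteps': 4000},  # assumed step count
    'diffusion_args': {'model_mean_type': 'epsilon',
                       'model_var_type': 'learned_range',
                       'loss_type': 'mse'},
    # Maps the remaining WaveTacotron2.forward kwargs to trainer state keys (state key names assumed).
    'model_input_keys': {'text_inputs': 'padded_text',
                         'text_lengths': 'text_lengths',
                         'output_lengths': 'output_lengths'},
    # Names assigned, in order, to the three outputs of WaveTacotron2.forward.
    'model_output_keys': ['eps_pred', 'gate_pred', 'alignments'],
    'prediction_key': 'eps_pred',                  # which named output carries the eps/variance prediction
}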