From 3e97abc8a9d72776f7c819c401beddc518ef96bc Mon Sep 17 00:00:00 2001
From: James Betker
Date: Fri, 1 Apr 2022 14:38:53 -0600
Subject: [PATCH] update flat0 to break out timestep-independent inference steps

---
 .../audio/tts/unet_diffusion_tts_flat0.py    | 84 ++++++++++++-------
 codes/utils/util.py                          |  9 ++
 2 files changed, 62 insertions(+), 31 deletions(-)

diff --git a/codes/models/audio/tts/unet_diffusion_tts_flat0.py b/codes/models/audio/tts/unet_diffusion_tts_flat0.py
index 05f8847e..bf526b27 100644
--- a/codes/models/audio/tts/unet_diffusion_tts_flat0.py
+++ b/codes/models/audio/tts/unet_diffusion_tts_flat0.py
@@ -187,7 +187,46 @@ class DiffusionTtsFlat(nn.Module):
         }
         return groups
 
-    def forward(self, x, timesteps, aligned_conditioning, conditioning_input, conditioning_free=False, return_code_pred=False):
+    def timestep_independent(self, aligned_conditioning, conditioning_input, return_code_pred):
+        # Shuffle aligned_latent to BxCxS format
+        if is_latent(aligned_conditioning):
+            aligned_conditioning = aligned_conditioning.permute(0, 2, 1)
+
+        # Note: this block does not need to be repeated on inference, since it is not timestep-dependent or x-dependent.
+        speech_conditioning_input = conditioning_input.unsqueeze(1) if len(
+            conditioning_input.shape) == 3 else conditioning_input
+        conds = []
+        for j in range(speech_conditioning_input.shape[1]):
+            conds.append(self.contextual_embedder(speech_conditioning_input[:, j]))
+        conds = torch.cat(conds, dim=-1)
+        cond_emb = conds.mean(dim=-1)
+        cond_scale, cond_shift = torch.chunk(cond_emb, 2, dim=1)
+        if is_latent(aligned_conditioning):
+            code_emb = self.latent_converter(aligned_conditioning)
+        else:
+            code_emb = self.code_embedding(aligned_conditioning).permute(0, 2, 1)
+            code_emb = self.code_converter(code_emb)
+        code_emb = self.code_norm(code_emb) * (1 + cond_scale.unsqueeze(-1)) + cond_shift.unsqueeze(-1)
+
+        unconditioned_batches = torch.zeros((code_emb.shape[0], 1, 1), device=code_emb.device)
+        # Mask out the conditioning branch for whole batch elements, implementing something similar to classifier-free guidance.
+        if self.training and self.unconditioned_percentage > 0:
+            unconditioned_batches = torch.rand((code_emb.shape[0], 1, 1),
+                                               device=code_emb.device) < self.unconditioned_percentage
+            code_emb = torch.where(unconditioned_batches, self.unconditioned_embedding.repeat(aligned_conditioning.shape[0], 1, 1),
+                                   code_emb)
+        expanded_code_emb = F.interpolate(code_emb, size=aligned_conditioning.shape[-1]*4, mode='nearest')
+
+        if not return_code_pred:
+            return expanded_code_emb
+        else:
+            mel_pred = self.mel_head(expanded_code_emb)
+            # Multiply mel_pred by !unconditioned_batches, which drops the gradient on unconditioned branches. This is because we don't want that gradient being used to train parameters through the codes_embedder as it unbalances contributions to that network from the MSE loss.
+            mel_pred = mel_pred * unconditioned_batches.logical_not()
+            return expanded_code_emb, mel_pred
+
+
+    def forward(self, x, timesteps, aligned_conditioning=None, conditioning_input=None, precomputed_aligned_embeddings=None, conditioning_free=False, return_code_pred=False):
         """
         Apply the model to an input batch.
 
@@ -195,49 +234,32 @@ class DiffusionTtsFlat(nn.Module):
         :param timesteps: a 1-D batch of timesteps.
         :param aligned_conditioning: an aligned latent or sequence of tokens providing useful data about the sample to be produced.
         :param conditioning_input: a full-resolution audio clip that is used as a reference to the style you want decoded.
+        :param precomputed_aligned_embeddings: Embeddings returned from self.timestep_independent()
         :param conditioning_free: When set, all conditioning inputs (including tokens and conditioning_input) will not be considered.
         :return: an [N x C x ...] Tensor of outputs.
         """
-        # Shuffle aligned_latent to BxCxS format
-        if is_latent(aligned_conditioning):
-            aligned_conditioning = aligned_conditioning.permute(0, 2, 1)
+        assert precomputed_aligned_embeddings is not None or (aligned_conditioning is not None and conditioning_input is not None)
+        assert not (return_code_pred and precomputed_aligned_embeddings is not None) # These two are mutually exclusive.
 
-        # Note: this block does not need to repeated on inference, since it is not timestep-dependent or x-dependent.
         unused_params = []
         if conditioning_free:
-            code_emb = self.unconditioned_embedding.repeat(x.shape[0], 1, 1)
+            code_emb = self.unconditioned_embedding.repeat(x.shape[0], 1, x.shape[-1])
+            unused_params.extend(list(self.code_converter.parameters()) + list(self.code_embedding.parameters()))
+            unused_params.extend(list(self.latent_converter.parameters()))
         else:
+            if precomputed_aligned_embeddings is not None:
+                code_emb = precomputed_aligned_embeddings
+            else:
+                code_emb, mel_pred = self.timestep_independent(aligned_conditioning, conditioning_input, True)
+
             unused_params.append(self.unconditioned_embedding)
-            speech_conditioning_input = conditioning_input.unsqueeze(1) if len(conditioning_input.shape) == 3 else conditioning_input
-            conds = []
-            for j in range(speech_conditioning_input.shape[1]):
-                conds.append(self.contextual_embedder(speech_conditioning_input[:, j]))
-            conds = torch.cat(conds, dim=-1)
-            cond_emb = conds.mean(dim=-1)
-            cond_scale, cond_shift = torch.chunk(cond_emb, 2, dim=1)
             if is_latent(aligned_conditioning):
-                code_emb = self.latent_converter(aligned_conditioning)
                 unused_params.extend(list(self.code_converter.parameters()) + list(self.code_embedding.parameters()))
             else:
-                code_emb = self.code_embedding(aligned_conditioning).permute(0,2,1)
-                code_emb = self.code_converter(code_emb)
                 unused_params.extend(list(self.latent_converter.parameters()))
-            code_emb = self.code_norm(code_emb) * (1 + cond_scale.unsqueeze(-1)) + cond_shift.unsqueeze(-1)
 
-        unconditioned_batches = torch.zeros((code_emb.shape[0], 1, 1), device=code_emb.device)
-        # Mask out the conditioning branch for whole batch elements, implementing something similar to classifier-free guidance.
-        if self.training and self.unconditioned_percentage > 0:
-            unconditioned_batches = torch.rand((code_emb.shape[0], 1, 1),
-                                               device=code_emb.device) < self.unconditioned_percentage
-            code_emb = torch.where(unconditioned_batches, self.unconditioned_embedding.repeat(x.shape[0], 1, 1),
-                                   code_emb)
-        expanded_code_emb = F.interpolate(code_emb, size=x.shape[-1], mode='nearest')
-        mel_pred = self.mel_head(expanded_code_emb)
-        # Multiply mel_pred by !unconditioned_branches, which drops the gradient on unconditioned branches. This is because we don't want that gradient being used to train parameters through the codes_embedder as it unbalances contributions to that network from the MSE loss.
-        mel_pred = mel_pred * unconditioned_batches.logical_not()
-        # Everything after this comment is timestep dependent.
         time_emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
-        code_emb = self.conditioning_timestep_integrator(expanded_code_emb, time_emb)
+        code_emb = self.conditioning_timestep_integrator(code_emb, time_emb)
         x = self.inp_block(x)
         x = torch.cat([x, code_emb], dim=1)
         x = self.integrating_conv(x)
@@ -272,7 +294,7 @@ def register_diffusion_tts_flat0(opt_net, opt):
 if __name__ == '__main__':
     clip = torch.randn(2, 100, 400)
     aligned_latent = torch.randn(2,388,512)
-    aligned_sequence = torch.randint(0,8192,(2,388))
+    aligned_sequence = torch.randint(0,8192,(2,100))
     cond = torch.randn(2, 100, 400)
     ts = torch.LongTensor([600, 600])
     model = DiffusionTtsFlat(512, layer_drop=.3, unconditioned_percentage=.5)
diff --git a/codes/utils/util.py b/codes/utils/util.py
index 5e32832d..bbd61533 100644
--- a/codes/utils/util.py
+++ b/codes/utils/util.py
@@ -591,6 +591,15 @@ def load_audio(audiopath, sampling_rate, raw_data=None):
     return audio
 
 
+def pad_or_truncate(t, length):
+    if t.shape[-1] == length:
+        return t
+    elif t.shape[-1] < length:
+        return F.pad(t, (0, length-t.shape[-1]))
+    else:
+        return t[..., :length]
+
+
 def load_wav_to_torch(full_path):
     import scipy.io.wavfile
     sampling_rate, data = scipy.io.wavfile.read(full_path)
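
Note (not part of the patch, illustration only): the point of splitting out timestep_independent() is that the conditioning/code embeddings no longer have to be recomputed at every denoising step during inference. A minimal sketch of the intended call pattern, assuming codes/ is on PYTHONPATH; the tensor shapes mirror the test harness above and the sampler loop is a placeholder:

    import torch
    from models.audio.tts.unet_diffusion_tts_flat0 import DiffusionTtsFlat

    model = DiffusionTtsFlat(512, layer_drop=.3, unconditioned_percentage=.5)
    model.eval()

    codes = torch.randint(0, 8192, (1, 100))  # aligned token sequence
    cond = torch.randn(1, 100, 400)           # reference clip (conditioning_input)

    with torch.no_grad():
        # Timestep-independent work happens exactly once per sample.
        precomputed = model.timestep_independent(codes, cond, return_code_pred=False)

        # Output-shaped noise; its length must match the 4x-expanded code embedding.
        x = torch.randn(1, 100, codes.shape[-1] * 4)
        for t in reversed(range(200)):        # stand-in for a real diffusion sampler loop
            ts = torch.LongTensor([t])
            out = model(x, ts, aligned_conditioning=codes,
                        precomputed_aligned_embeddings=precomputed)
            # ... a sampler would update x from `out` here ...

aligned_conditioning is still passed alongside the precomputed embeddings in this sketch because, at this revision, the unused-parameter bookkeeping in forward() still inspects it on the non-conditioning-free path; the embeddings themselves are not recomputed.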
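
Side note on the new helper in codes/utils/util.py: pad_or_truncate() simply fits a tensor's last dimension to a target length, zero-padding on the right or slicing it down. Illustrative only (import path assumes codes/ is on PYTHONPATH):

    import torch
    from utils.util import pad_or_truncate

    clip = torch.randn(2, 80, 350)
    pad_or_truncate(clip, 400).shape  # torch.Size([2, 80, 400]) - right-padded with zeros
    pad_or_truncate(clip, 300).shape  # torch.Size([2, 80, 300]) - truncated
    pad_or_truncate(clip, 350).shape  # torch.Size([2, 80, 350]) - returned unchanged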