diff --git a/codes/models/gpt_voice/unet_diffusion_tts9.py b/codes/models/gpt_voice/unet_diffusion_tts9.py index b9520638..427a0f16 100644 --- a/codes/models/gpt_voice/unet_diffusion_tts9.py +++ b/codes/models/gpt_voice/unet_diffusion_tts9.py @@ -139,6 +139,7 @@ class DiffusionTts(nn.Module): in_channels=1, in_latent_channels=1024, in_tokens=8193, + conditioning_expansion=4, out_channels=2, # mean and variance dropout=0, # res 1, 2, 4, 8,16,32,64,128,256,512, 1K, 2K @@ -232,6 +233,7 @@ class DiffusionTts(nn.Module): AttentionBlock(conditioning_dim, num_heads=num_heads, num_head_channels=num_head_channels), ResBlock(conditioning_dim, time_embed_dim, dropout, out_channels=conditioning_dim, dims=dims, kernel_size=1, use_scale_shift_norm=use_scale_shift_norm), ) + self.conditioning_expansion = conditioning_expansion self.input_blocks = nn.ModuleList( [ @@ -430,6 +432,7 @@ class DiffusionTts(nn.Module): code_emb) # Everything after this comment is timestep dependent. + code_emb = torch.repeat_interleave(code_emb, self.conditioning_expansion, dim=-1) code_emb = self.conditioning_timestep_integrator(code_emb, time_emb) first = True @@ -454,6 +457,13 @@ class DiffusionTts(nn.Module): h = h.float() out = self.out(h) + # Involve probabilistic or possibly unused parameters in loss so we don't get DDP errors. + extraneous_addition = 0 + params = [self.aligned_latent_padding_embedding, self.unconditioned_embedding] + list(self.latent_converter.parameters()) + list(self.code_converter.parameters()) + for p in params: + extraneous_addition = extraneous_addition + p.mean() + out = out + extraneous_addition * 0 + return out[:, :, :orig_x_shape]