diff --git a/codes/models/audio/tts/transformer_diffusion_tts.py b/codes/models/audio/tts/transformer_diffusion_tts.py
index 5bde0bb5..9e14a7a7 100644
--- a/codes/models/audio/tts/transformer_diffusion_tts.py
+++ b/codes/models/audio/tts/transformer_diffusion_tts.py
@@ -64,6 +64,7 @@ class TransformerDiffusion(nn.Module):
             rotary_emb_dim=32,
             token_count=8,
             in_groups=None,
+            types=2,
             out_channels=512, # mean and variance
             dropout=0,
             use_fp16=False,
@@ -100,6 +101,7 @@ class TransformerDiffusion(nn.Module):
             rotary_pos_emb=True,
         )
         self.clvp_encoder = nn.Linear(clvp_in_dim, model_channels)
+        self.type_embedding = nn.Embedding(types, model_channels)

         # Either code_converter or latent_converter is used, depending on what type of conditioning data is fed.
         # This model is meant to be able to be trained on both for efficiency purposes - it is far less computationally
@@ -190,12 +192,13 @@ class TransformerDiffusion(nn.Module):

         return expanded_code_emb, cond_emb, mel_pred

-    def forward(self, x, timesteps, codes=None, conditioning_input=None, clvp_input=None, prenet_latent=None, precomputed_code_embeddings=None,
+    def forward(self, x, timesteps, codes=None, conditioning_input=None, clvp_input=None, type=None, prenet_latent=None, precomputed_code_embeddings=None,
                 precomputed_cond_embeddings=None, conditioning_free=False, return_code_pred=False):
         if precomputed_code_embeddings is not None:
             assert precomputed_cond_embeddings is not None, "Must specify both precomputed embeddings if one is specified"
             assert codes is None and conditioning_input is None and prenet_latent is None, "Do not provide precomputed embeddings and the other parameters. It is unclear what you want me to do here."
         assert not (return_code_pred and precomputed_code_embeddings is not None), "I cannot compute a code_pred output for you."
+        assert type is not None, "Type is required."

         unused_params = []
         if not return_code_pred:
@@ -215,9 +218,10 @@ class TransformerDiffusion(nn.Module):
             unused_params.append(self.unconditioned_embedding)

         clvp_emb = torch.zeros_like(cond_emb) if clvp_input is None else self.clvp_encoder(clvp_input)
+        type_emb = self.type_embedding(type)
         if clvp_input is None:
             unused_params.extend(self.clvp_encoder.parameters())
-        blk_emb = self.time_embed(timestep_embedding(timesteps, self.model_channels)) + cond_emb + clvp_emb
+        blk_emb = self.time_embed(timestep_embedding(timesteps, self.model_channels)) + cond_emb + clvp_emb + type_emb
         x = self.inp_block(x).permute(0,2,1)

         rotary_pos_emb = self.rotary_embeddings(x.shape[1], x.device)
@@ -239,7 +243,7 @@ class TransformerDiffusion(nn.Module):


 @register_model
-def register_transformer_diffusion5(opt_net, opt):
+def register_transformer_diffusion_tts(opt_net, opt):
     return TransformerDiffusion(**opt_net['kwargs'])
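
Review note: the change conditions the diffusion model on an integer `type` ID by looking up a learned `nn.Embedding` vector (width `model_channels`, so it can be summed with the timestep, conditioning, and CLVP embeddings) and folding it into `blk_emb`. Below is a minimal, standalone sketch of that pattern; the channel width, batch size, and the meaning of the type IDs (e.g. code- vs. latent-conditioned batches) are illustrative assumptions, not values taken from the model.

import torch
import torch.nn as nn

# Illustrative values only; the real model receives types/model_channels via kwargs.
model_channels, types, batch = 512, 2, 4

# New module from the diff: one learned vector per input type.
type_embedding = nn.Embedding(types, model_channels)

# Stand-in for the timestep + conditioning + CLVP embeddings the model already sums.
blk_emb = torch.randn(batch, model_channels)

# One integer type ID per sample; 0 vs. 1 might distinguish code- vs. latent-conditioned
# inputs (assumed meaning, not stated in the diff).
type_ids = torch.zeros(batch, dtype=torch.long)

# Mirrors `blk_emb = ... + cond_emb + clvp_emb + type_emb` in forward().
blk_emb = blk_emb + type_embedding(type_ids)
print(blk_emb.shape)  # torch.Size([4, 512])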