propagate type

This commit is contained in:
James Betker 2022-05-27 11:12:03 -06:00
parent c46da0285c
commit bed3df4888

View File

@ -64,6 +64,7 @@ class TransformerDiffusion(nn.Module):
rotary_emb_dim=32,
token_count=8,
in_groups=None,
types=2,
out_channels=512, # mean and variance
dropout=0,
use_fp16=False,
@ -100,6 +101,7 @@ class TransformerDiffusion(nn.Module):
rotary_pos_emb=True,
)
self.clvp_encoder = nn.Linear(clvp_in_dim, model_channels)
self.type_embedding = nn.Embedding(types)
# Either code_converter or latent_converter is used, depending on what type of conditioning data is fed.
# This model is meant to be able to be trained on both for efficiency purposes - it is far less computationally
@ -190,12 +192,13 @@ class TransformerDiffusion(nn.Module):
return expanded_code_emb, cond_emb, mel_pred
def forward(self, x, timesteps, codes=None, conditioning_input=None, clvp_input=None, prenet_latent=None, precomputed_code_embeddings=None,
def forward(self, x, timesteps, codes=None, conditioning_input=None, clvp_input=None, type=None, prenet_latent=None, precomputed_code_embeddings=None,
precomputed_cond_embeddings=None, conditioning_free=False, return_code_pred=False):
if precomputed_code_embeddings is not None:
assert precomputed_cond_embeddings is not None, "Must specify both precomputed embeddings if one is specified"
assert codes is None and conditioning_input is None and prenet_latent is None, "Do not provide precomputed embeddings and the other parameters. It is unclear what you want me to do here."
assert not (return_code_pred and precomputed_code_embeddings is not None), "I cannot compute a code_pred output for you."
assert type is not None, "Type is required."
unused_params = []
if not return_code_pred:
@ -215,9 +218,10 @@ class TransformerDiffusion(nn.Module):
unused_params.append(self.unconditioned_embedding)
clvp_emb = torch.zeros_like(cond_emb) if clvp_input is None else self.clvp_encoder(clvp_input)
type_emb = self.type_embedding(type)
if clvp_input is None:
unused_params.extend(self.clvp_encoder.parameters())
blk_emb = self.time_embed(timestep_embedding(timesteps, self.model_channels)) + cond_emb + clvp_emb
blk_emb = self.time_embed(timestep_embedding(timesteps, self.model_channels)) + cond_emb + clvp_emb + type_emb
x = self.inp_block(x).permute(0,2,1)
rotary_pos_emb = self.rotary_embeddings(x.shape[1], x.device)
@ -239,7 +243,7 @@ class TransformerDiffusion(nn.Module):
@register_model
def register_transformer_diffusion5(opt_net, opt):
def register_transformer_diffusion_tts(opt_net, opt):
return TransformerDiffusion(**opt_net['kwargs'])