forked from mrq/DL-Art-School
propagate type
parent c46da0285c
commit bed3df4888
@@ -64,6 +64,7 @@ class TransformerDiffusion(nn.Module):
             rotary_emb_dim=32,
             token_count=8,
             in_groups=None,
+            types=2,
             out_channels=512, # mean and variance
             dropout=0,
             use_fp16=False,
@@ -100,6 +101,7 @@ class TransformerDiffusion(nn.Module):
             rotary_pos_emb=True,
         )
         self.clvp_encoder = nn.Linear(clvp_in_dim, model_channels)
+        self.type_embedding = nn.Embedding(types, model_channels)

         # Either code_converter or latent_converter is used, depending on what type of conditioning data is fed.
         # This model is meant to be able to be trained on both for efficiency purposes - it is far less computationally
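The new types argument sizes a plain nn.Embedding lookup table: one learned vector per conditioning type, with the same width as the other conditioning embeddings so they can be summed together. A minimal sketch of the mechanism, not taken from the repo, using hypothetical sizes (model_channels=512 and types=2 mirror the defaults above):

import torch
import torch.nn as nn

model_channels, types = 512, 2

# One learned vector per type, sized to match the other conditioning
# embeddings so it can be summed into them directly.
type_embedding = nn.Embedding(types, model_channels)

# One integer type ID per batch element.
type_ids = torch.tensor([0, 1, 1])
type_emb = type_embedding(type_ids)
print(type_emb.shape)  # torch.Size([3, 512])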
@@ -190,12 +192,13 @@ class TransformerDiffusion(nn.Module):
         return expanded_code_emb, cond_emb, mel_pred


-    def forward(self, x, timesteps, codes=None, conditioning_input=None, clvp_input=None, prenet_latent=None, precomputed_code_embeddings=None,
+    def forward(self, x, timesteps, codes=None, conditioning_input=None, clvp_input=None, type=None, prenet_latent=None, precomputed_code_embeddings=None,
                 precomputed_cond_embeddings=None, conditioning_free=False, return_code_pred=False):
         if precomputed_code_embeddings is not None:
             assert precomputed_cond_embeddings is not None, "Must specify both precomputed embeddings if one is specified"
             assert codes is None and conditioning_input is None and prenet_latent is None, "Do not provide precomputed embeddings and the other parameters. It is unclear what you want me to do here."
         assert not (return_code_pred and precomputed_code_embeddings is not None), "I cannot compute a code_pred output for you."
+        assert type is not None, "Type is required."

         unused_params = []
         if not return_code_pred:
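One reading note on the new signature: the parameter name type shadows Python's builtin inside forward, which is legal but means the builtin type() is unavailable in that scope. A trivial, hypothetical stand-in showing the call shape the new keyword expects:

import torch

def forward_stub(x, timesteps, type=None):
    # Mirrors the new guard in forward(); all shapes are placeholders.
    assert type is not None, "Type is required."
    return x

out = forward_stub(torch.zeros(2, 4), torch.tensor([10, 20]), type=torch.tensor([0, 1]))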
@@ -215,9 +218,10 @@ class TransformerDiffusion(nn.Module):
             unused_params.append(self.unconditioned_embedding)

         clvp_emb = torch.zeros_like(cond_emb) if clvp_input is None else self.clvp_encoder(clvp_input)
+        type_emb = self.type_embedding(type)
         if clvp_input is None:
             unused_params.extend(self.clvp_encoder.parameters())
-        blk_emb = self.time_embed(timestep_embedding(timesteps, self.model_channels)) + cond_emb + clvp_emb
+        blk_emb = self.time_embed(timestep_embedding(timesteps, self.model_channels)) + cond_emb + clvp_emb + type_emb
         x = self.inp_block(x).permute(0,2,1)

         rotary_pos_emb = self.rotary_embeddings(x.shape[1], x.device)
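Taken together, the forward-pass changes fold the type vector into the per-block conditioning by plain addition. A self-contained sketch of that sum, with stand-in tensors for the conditioning and CLVP embeddings and a hypothetical sinusoidal timestep_embedding (the real model computes all of these internally):

import torch
import torch.nn as nn

model_channels, types, batch = 512, 2, 3

def timestep_embedding(t, dim):
    # Standard sinusoidal timestep embedding, common in diffusion models;
    # assumed here, since the diff only shows the call site.
    half = dim // 2
    freqs = torch.exp(torch.arange(half, dtype=torch.float32) * (-torch.log(torch.tensor(10000.0)) / half))
    args = t[:, None].float() * freqs[None]
    return torch.cat([torch.cos(args), torch.sin(args)], dim=-1)

time_embed = nn.Sequential(nn.Linear(model_channels, model_channels), nn.SiLU(),
                           nn.Linear(model_channels, model_channels))
type_embedding = nn.Embedding(types, model_channels)

timesteps = torch.randint(0, 1000, (batch,))
cond_emb = torch.randn(batch, model_channels)  # stand-in for the conditioning encoder output
clvp_emb = torch.zeros(batch, model_channels)  # stand-in for the CLVP projection
type_ids = torch.tensor([0, 1, 1])

# The change in this hunk: the type vector is simply summed into blk_emb.
blk_emb = time_embed(timestep_embedding(timesteps, model_channels)) \
          + cond_emb + clvp_emb + type_embedding(type_ids)
print(blk_emb.shape)  # torch.Size([3, 512])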
@@ -239,7 +243,7 @@ class TransformerDiffusion(nn.Module):


 @register_model
-def register_transformer_diffusion5(opt_net, opt):
+def register_transformer_diffusion_tts(opt_net, opt):
     return TransformerDiffusion(**opt_net['kwargs'])

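Since @register_model keys the registry off the decorated function's name, the rename also changes the name configs must reference. A hypothetical options blob for illustration; the exact config schema is an assumption, not taken from the repo:

opt_net = {
    'which_model': 'transformer_diffusion_tts',  # was 'transformer_diffusion5'
    'kwargs': {
        'types': 2,           # sizes the new type embedding table
        'out_channels': 512,  # mean and variance
    },
}
# register_transformer_diffusion_tts(opt_net, opt) then builds the model via
# TransformerDiffusion(**opt_net['kwargs']).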