propagate type
This commit is contained in:
parent
c46da0285c
commit
bed3df4888
|
@ -64,6 +64,7 @@ class TransformerDiffusion(nn.Module):
|
||||||
rotary_emb_dim=32,
|
rotary_emb_dim=32,
|
||||||
token_count=8,
|
token_count=8,
|
||||||
in_groups=None,
|
in_groups=None,
|
||||||
|
types=2,
|
||||||
out_channels=512, # mean and variance
|
out_channels=512, # mean and variance
|
||||||
dropout=0,
|
dropout=0,
|
||||||
use_fp16=False,
|
use_fp16=False,
|
||||||
|
@ -100,6 +101,7 @@ class TransformerDiffusion(nn.Module):
|
||||||
rotary_pos_emb=True,
|
rotary_pos_emb=True,
|
||||||
)
|
)
|
||||||
self.clvp_encoder = nn.Linear(clvp_in_dim, model_channels)
|
self.clvp_encoder = nn.Linear(clvp_in_dim, model_channels)
|
||||||
|
self.type_embedding = nn.Embedding(types)
|
||||||
|
|
||||||
# Either code_converter or latent_converter is used, depending on what type of conditioning data is fed.
|
# Either code_converter or latent_converter is used, depending on what type of conditioning data is fed.
|
||||||
# This model is meant to be able to be trained on both for efficiency purposes - it is far less computationally
|
# This model is meant to be able to be trained on both for efficiency purposes - it is far less computationally
|
||||||
|
@ -190,12 +192,13 @@ class TransformerDiffusion(nn.Module):
|
||||||
return expanded_code_emb, cond_emb, mel_pred
|
return expanded_code_emb, cond_emb, mel_pred
|
||||||
|
|
||||||
|
|
||||||
def forward(self, x, timesteps, codes=None, conditioning_input=None, clvp_input=None, prenet_latent=None, precomputed_code_embeddings=None,
|
def forward(self, x, timesteps, codes=None, conditioning_input=None, clvp_input=None, type=None, prenet_latent=None, precomputed_code_embeddings=None,
|
||||||
precomputed_cond_embeddings=None, conditioning_free=False, return_code_pred=False):
|
precomputed_cond_embeddings=None, conditioning_free=False, return_code_pred=False):
|
||||||
if precomputed_code_embeddings is not None:
|
if precomputed_code_embeddings is not None:
|
||||||
assert precomputed_cond_embeddings is not None, "Must specify both precomputed embeddings if one is specified"
|
assert precomputed_cond_embeddings is not None, "Must specify both precomputed embeddings if one is specified"
|
||||||
assert codes is None and conditioning_input is None and prenet_latent is None, "Do not provide precomputed embeddings and the other parameters. It is unclear what you want me to do here."
|
assert codes is None and conditioning_input is None and prenet_latent is None, "Do not provide precomputed embeddings and the other parameters. It is unclear what you want me to do here."
|
||||||
assert not (return_code_pred and precomputed_code_embeddings is not None), "I cannot compute a code_pred output for you."
|
assert not (return_code_pred and precomputed_code_embeddings is not None), "I cannot compute a code_pred output for you."
|
||||||
|
assert type is not None, "Type is required."
|
||||||
|
|
||||||
unused_params = []
|
unused_params = []
|
||||||
if not return_code_pred:
|
if not return_code_pred:
|
||||||
|
@ -215,9 +218,10 @@ class TransformerDiffusion(nn.Module):
|
||||||
unused_params.append(self.unconditioned_embedding)
|
unused_params.append(self.unconditioned_embedding)
|
||||||
|
|
||||||
clvp_emb = torch.zeros_like(cond_emb) if clvp_input is None else self.clvp_encoder(clvp_input)
|
clvp_emb = torch.zeros_like(cond_emb) if clvp_input is None else self.clvp_encoder(clvp_input)
|
||||||
|
type_emb = self.type_embedding(type)
|
||||||
if clvp_input is None:
|
if clvp_input is None:
|
||||||
unused_params.extend(self.clvp_encoder.parameters())
|
unused_params.extend(self.clvp_encoder.parameters())
|
||||||
blk_emb = self.time_embed(timestep_embedding(timesteps, self.model_channels)) + cond_emb + clvp_emb
|
blk_emb = self.time_embed(timestep_embedding(timesteps, self.model_channels)) + cond_emb + clvp_emb + type_emb
|
||||||
x = self.inp_block(x).permute(0,2,1)
|
x = self.inp_block(x).permute(0,2,1)
|
||||||
|
|
||||||
rotary_pos_emb = self.rotary_embeddings(x.shape[1], x.device)
|
rotary_pos_emb = self.rotary_embeddings(x.shape[1], x.device)
|
||||||
|
@ -239,7 +243,7 @@ class TransformerDiffusion(nn.Module):
|
||||||
|
|
||||||
|
|
||||||
@register_model
|
@register_model
|
||||||
def register_transformer_diffusion5(opt_net, opt):
|
def register_transformer_diffusion_tts(opt_net, opt):
|
||||||
return TransformerDiffusion(**opt_net['kwargs'])
|
return TransformerDiffusion(**opt_net['kwargs'])
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user