Updates to support inputting MELs into the conditioning encoder
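When in_channels is set to 80, the conditioning encoder now treats its input as an 80-bin MEL spectrogram: contextual_embedder becomes a strided Conv1d followed by a transformer encoder instead of the AudioMiniEncoder used for raw waveforms, and the forward pass collapses the resulting sequence down to a single conditioning vector.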

James Betker 2022-03-14 17:31:42 -06:00
parent e045fb0ad7
commit f8631ad4f7


@@ -139,6 +139,7 @@ class DiffusionTts(nn.Module):
             in_channels=1,
             in_latent_channels=1024,
             in_tokens=8193,
+            conditioning_dim_factor=8,
             conditioning_expansion=4,
             out_channels=2, # mean and variance
             dropout=0,
@@ -198,7 +199,7 @@ class DiffusionTts(nn.Module):
             linear(time_embed_dim, time_embed_dim),
         )
-        conditioning_dim = model_channels * 8
+        conditioning_dim = model_channels * conditioning_dim_factor
         # Either code_converter or latent_converter is used, depending on what type of conditioning data is fed.
         # This model is meant to be able to be trained on both for efficiency purposes - it is far less computationally
         # complex to generate tokens, while generating latents will normally mean propagating through a deep autoregressive
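The default conditioning_dim_factor=8 preserves the previously hard-coded width (e.g. model_channels=512 still yields conditioning_dim=4096); the change just makes the multiplier configurable.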
@@ -222,8 +223,26 @@ class DiffusionTts(nn.Module):
         ))
         self.latent_converter = nn.Conv1d(in_latent_channels, conditioning_dim, 1)
         self.aligned_latent_padding_embedding = nn.Parameter(torch.randn(1,in_latent_channels,1))
-        self.contextual_embedder = AudioMiniEncoder(1, conditioning_dim, base_channels=32, depth=6, resnet_blocks=1,
-                                                    attn_blocks=3, num_attn_heads=8, dropout=dropout, downsample_factor=4, kernel_size=5)
+        if in_channels == 80:
+            self.contextual_embedder = nn.Sequential(nn.Conv1d(80,conditioning_dim,3,padding=1,stride=2),
+                                                     CheckpointedXTransformerEncoder(
+                                                         needs_permute=True,
+                                                         max_seq_len=-1,
+                                                         use_pos_emb=False,
+                                                         attn_layers=Encoder(
+                                                             dim=conditioning_dim,
+                                                             depth=4,
+                                                             heads=num_heads,
+                                                             ff_dropout=dropout,
+                                                             attn_dropout=dropout,
+                                                             use_rmsnorm=True,
+                                                             ff_glu=True,
+                                                             rotary_emb_dim=True,
+                                                         )
+                                                     ))
+        else:
+            self.contextual_embedder = AudioMiniEncoder(1, conditioning_dim, base_channels=32, depth=6, resnet_blocks=1,
+                                                        attn_blocks=3, num_attn_heads=8, dropout=dropout, downsample_factor=4, kernel_size=5)
         self.conditioning_conv = nn.Conv1d(conditioning_dim*2, conditioning_dim, 1)
         self.unconditioned_embedding = nn.Parameter(torch.randn(1,conditioning_dim,1))
         self.conditioning_timestep_integrator = TimestepEmbedSequential(
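For reference, a minimal standalone sketch of the shapes flowing through the new MEL branch; the nn.Identity stand-in replaces this repo's CheckpointedXTransformerEncoder (which is shape-preserving here), and the concrete dimensions are assumptions:

import torch
import torch.nn as nn

conditioning_dim = 512 * 8  # model_channels * conditioning_dim_factor (assumed values)
mel_embedder = nn.Sequential(
    # 80 MEL bins -> conditioning_dim channels, halving the time axis via stride=2
    nn.Conv1d(80, conditioning_dim, 3, padding=1, stride=2),
    # Stand-in for CheckpointedXTransformerEncoder; any module mapping
    # (batch, channels, time) -> (batch, channels, time) shows the same shapes.
    nn.Identity(),
)
mel = torch.randn(2, 80, 400)  # (batch, n_mel, frames)
cond_emb = mel_embedder(mel)
print(cond_emb.shape)          # torch.Size([2, 4096, 200])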
@@ -418,6 +437,8 @@ class DiffusionTts(nn.Module):
             code_emb = self.unconditioned_embedding.repeat(x.shape[0], 1, 1)
         else:
             cond_emb = self.contextual_embedder(conditioning_input)
+            if len(cond_emb.shape) == 3: # Just take the first element.
+                cond_emb = cond_emb[:, :, 0]
             if is_latent(aligned_conditioning):
                 code_emb = self.latent_converter(aligned_conditioning)
             else:
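The waveform path (AudioMiniEncoder) already yields a (batch, dim) vector, whereas the new MEL path emits a (batch, dim, time) sequence, so the added check collapses the latter to its first frame. A sketch of just that reduction, with the shapes assumed:

import torch

cond_emb = torch.randn(2, 4096, 200)  # MEL-path output (assumed shape)
if len(cond_emb.shape) == 3:          # Just take the first element.
    cond_emb = cond_emb[:, :, 0]
print(cond_emb.shape)                 # torch.Size([2, 4096])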
@@ -459,7 +480,7 @@ class DiffusionTts(nn.Module):
         # Involve probabilistic or possibly unused parameters in loss so we don't get DDP errors.
         extraneous_addition = 0
-        params = [self.aligned_latent_padding_embedding, self.unconditioned_embedding] + list(self.latent_converter.parameters()) + list(self.code_converter.parameters())
+        params = [self.aligned_latent_padding_embedding, self.unconditioned_embedding] + list(self.latent_converter.parameters())
         for p in params:
             extraneous_addition = extraneous_addition + p.mean()
         out = out + extraneous_addition * 0
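code_converter is dropped from this list, presumably because the token path is now exercised often enough that DDP no longer flags its parameters as unused. A toy illustration of the zero-weighted trick itself, with made-up parameter names:

import torch

maybe_unused = [torch.nn.Parameter(torch.randn(4)) for _ in range(2)]
loss = torch.zeros(1, requires_grad=True)
extraneous_addition = 0
for p in maybe_unused:
    extraneous_addition = extraneous_addition + p.mean()
# Multiplying by zero leaves the loss value unchanged but makes autograd (and
# hence DistributedDataParallel) register every parameter as participating.
out = loss + extraneous_addition * 0
out.sum().backward()
print(maybe_unused[0].grad)  # all zeros, but populated -> no DDP unused-param error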