tts9 mods

This commit is contained in:
James Betker 2022-03-13 10:25:55 -06:00
parent 08599b4c75
commit 22c67ce8d3

View File

@ -139,6 +139,7 @@ class DiffusionTts(nn.Module):
in_channels=1, in_channels=1,
in_latent_channels=1024, in_latent_channels=1024,
in_tokens=8193, in_tokens=8193,
conditioning_expansion=4,
out_channels=2, # mean and variance out_channels=2, # mean and variance
dropout=0, dropout=0,
# res 1, 2, 4, 8,16,32,64,128,256,512, 1K, 2K # res 1, 2, 4, 8,16,32,64,128,256,512, 1K, 2K
@ -232,6 +233,7 @@ class DiffusionTts(nn.Module):
AttentionBlock(conditioning_dim, num_heads=num_heads, num_head_channels=num_head_channels), AttentionBlock(conditioning_dim, num_heads=num_heads, num_head_channels=num_head_channels),
ResBlock(conditioning_dim, time_embed_dim, dropout, out_channels=conditioning_dim, dims=dims, kernel_size=1, use_scale_shift_norm=use_scale_shift_norm), ResBlock(conditioning_dim, time_embed_dim, dropout, out_channels=conditioning_dim, dims=dims, kernel_size=1, use_scale_shift_norm=use_scale_shift_norm),
) )
self.conditioning_expansion = conditioning_expansion
self.input_blocks = nn.ModuleList( self.input_blocks = nn.ModuleList(
[ [
@ -430,6 +432,7 @@ class DiffusionTts(nn.Module):
code_emb) code_emb)
# Everything after this comment is timestep dependent. # Everything after this comment is timestep dependent.
code_emb = torch.repeat_interleave(code_emb, self.conditioning_expansion, dim=-1)
code_emb = self.conditioning_timestep_integrator(code_emb, time_emb) code_emb = self.conditioning_timestep_integrator(code_emb, time_emb)
first = True first = True
@ -454,6 +457,13 @@ class DiffusionTts(nn.Module):
h = h.float() h = h.float()
out = self.out(h) out = self.out(h)
# Involve probabilistic or possibly unused parameters in loss so we don't get DDP errors.
extraneous_addition = 0
params = [self.aligned_latent_padding_embedding, self.unconditioned_embedding] + list(self.latent_converter.parameters()) + list(self.code_converter.parameters())
for p in params:
extraneous_addition = extraneous_addition + p.mean()
out = out + extraneous_addition * 0
return out[:, :, :orig_x_shape] return out[:, :, :orig_x_shape]