tts9 mods
This commit is contained in:
parent
08599b4c75
commit
22c67ce8d3
|
@ -139,6 +139,7 @@ class DiffusionTts(nn.Module):
|
|||
in_channels=1,
|
||||
in_latent_channels=1024,
|
||||
in_tokens=8193,
|
||||
conditioning_expansion=4,
|
||||
out_channels=2, # mean and variance
|
||||
dropout=0,
|
||||
# res 1, 2, 4, 8,16,32,64,128,256,512, 1K, 2K
|
||||
|
@ -232,6 +233,7 @@ class DiffusionTts(nn.Module):
|
|||
AttentionBlock(conditioning_dim, num_heads=num_heads, num_head_channels=num_head_channels),
|
||||
ResBlock(conditioning_dim, time_embed_dim, dropout, out_channels=conditioning_dim, dims=dims, kernel_size=1, use_scale_shift_norm=use_scale_shift_norm),
|
||||
)
|
||||
self.conditioning_expansion = conditioning_expansion
|
||||
|
||||
self.input_blocks = nn.ModuleList(
|
||||
[
|
||||
|
@ -430,6 +432,7 @@ class DiffusionTts(nn.Module):
|
|||
code_emb)
|
||||
|
||||
# Everything after this comment is timestep dependent.
|
||||
code_emb = torch.repeat_interleave(code_emb, self.conditioning_expansion, dim=-1)
|
||||
code_emb = self.conditioning_timestep_integrator(code_emb, time_emb)
|
||||
|
||||
first = True
|
||||
|
@ -454,6 +457,13 @@ class DiffusionTts(nn.Module):
|
|||
h = h.float()
|
||||
out = self.out(h)
|
||||
|
||||
# Involve probabilistic or possibly unused parameters in loss so we don't get DDP errors.
|
||||
extraneous_addition = 0
|
||||
params = [self.aligned_latent_padding_embedding, self.unconditioned_embedding] + list(self.latent_converter.parameters()) + list(self.code_converter.parameters())
|
||||
for p in params:
|
||||
extraneous_addition = extraneous_addition + p.mean()
|
||||
out = out + extraneous_addition * 0
|
||||
|
||||
return out[:, :, :orig_x_shape]
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user