forked from mrq/DL-Art-School
tts9 mods
This commit is contained in:
parent
08599b4c75
commit
22c67ce8d3
|
@ -139,6 +139,7 @@ class DiffusionTts(nn.Module):
|
||||||
in_channels=1,
|
in_channels=1,
|
||||||
in_latent_channels=1024,
|
in_latent_channels=1024,
|
||||||
in_tokens=8193,
|
in_tokens=8193,
|
||||||
|
conditioning_expansion=4,
|
||||||
out_channels=2, # mean and variance
|
out_channels=2, # mean and variance
|
||||||
dropout=0,
|
dropout=0,
|
||||||
# res 1, 2, 4, 8,16,32,64,128,256,512, 1K, 2K
|
# res 1, 2, 4, 8,16,32,64,128,256,512, 1K, 2K
|
||||||
|
@ -232,6 +233,7 @@ class DiffusionTts(nn.Module):
|
||||||
AttentionBlock(conditioning_dim, num_heads=num_heads, num_head_channels=num_head_channels),
|
AttentionBlock(conditioning_dim, num_heads=num_heads, num_head_channels=num_head_channels),
|
||||||
ResBlock(conditioning_dim, time_embed_dim, dropout, out_channels=conditioning_dim, dims=dims, kernel_size=1, use_scale_shift_norm=use_scale_shift_norm),
|
ResBlock(conditioning_dim, time_embed_dim, dropout, out_channels=conditioning_dim, dims=dims, kernel_size=1, use_scale_shift_norm=use_scale_shift_norm),
|
||||||
)
|
)
|
||||||
|
self.conditioning_expansion = conditioning_expansion
|
||||||
|
|
||||||
self.input_blocks = nn.ModuleList(
|
self.input_blocks = nn.ModuleList(
|
||||||
[
|
[
|
||||||
|
@ -430,6 +432,7 @@ class DiffusionTts(nn.Module):
|
||||||
code_emb)
|
code_emb)
|
||||||
|
|
||||||
# Everything after this comment is timestep dependent.
|
# Everything after this comment is timestep dependent.
|
||||||
|
code_emb = torch.repeat_interleave(code_emb, self.conditioning_expansion, dim=-1)
|
||||||
code_emb = self.conditioning_timestep_integrator(code_emb, time_emb)
|
code_emb = self.conditioning_timestep_integrator(code_emb, time_emb)
|
||||||
|
|
||||||
first = True
|
first = True
|
||||||
|
@ -454,6 +457,13 @@ class DiffusionTts(nn.Module):
|
||||||
h = h.float()
|
h = h.float()
|
||||||
out = self.out(h)
|
out = self.out(h)
|
||||||
|
|
||||||
|
# Involve probabilistic or possibly unused parameters in loss so we don't get DDP errors.
|
||||||
|
extraneous_addition = 0
|
||||||
|
params = [self.aligned_latent_padding_embedding, self.unconditioned_embedding] + list(self.latent_converter.parameters()) + list(self.code_converter.parameters())
|
||||||
|
for p in params:
|
||||||
|
extraneous_addition = extraneous_addition + p.mean()
|
||||||
|
out = out + extraneous_addition * 0
|
||||||
|
|
||||||
return out[:, :, :orig_x_shape]
|
return out[:, :, :orig_x_shape]
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user