prep flat0 for feeding from autoregressive_latent_converter
This commit is contained in:
parent
3e97abc8a9
commit
f6a8b0a5ca
|
@ -118,6 +118,7 @@ class DiffusionTtsFlat(nn.Module):
|
||||||
dropout=0,
|
dropout=0,
|
||||||
use_fp16=False,
|
use_fp16=False,
|
||||||
num_heads=16,
|
num_heads=16,
|
||||||
|
freeze_everything_except_autoregressive_inputs=False,
|
||||||
# Parameters for regularization.
|
# Parameters for regularization.
|
||||||
layer_drop=.1,
|
layer_drop=.1,
|
||||||
unconditioned_percentage=.1, # This implements a mechanism similar to what is used in classifier-free training.
|
unconditioned_percentage=.1, # This implements a mechanism similar to what is used in classifier-free training.
|
||||||
|
@ -151,7 +152,11 @@ class DiffusionTtsFlat(nn.Module):
|
||||||
AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
|
AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
|
||||||
)
|
)
|
||||||
self.code_norm = normalization(model_channels)
|
self.code_norm = normalization(model_channels)
|
||||||
self.latent_converter = nn.Conv1d(in_latent_channels, model_channels, 1)
|
self.autoregressive_latent_converter = nn.Sequential(nn.Conv1d(in_latent_channels, model_channels, 1),
|
||||||
|
AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
|
||||||
|
AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
|
||||||
|
AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
|
||||||
|
)
|
||||||
self.contextual_embedder = nn.Sequential(nn.Conv1d(in_channels,model_channels,3,padding=1,stride=2),
|
self.contextual_embedder = nn.Sequential(nn.Conv1d(in_channels,model_channels,3,padding=1,stride=2),
|
||||||
nn.Conv1d(model_channels, model_channels*2,3,padding=1,stride=2),
|
nn.Conv1d(model_channels, model_channels*2,3,padding=1,stride=2),
|
||||||
AttentionBlock(model_channels*2, num_heads, relative_pos_embeddings=True, do_checkpoint=False),
|
AttentionBlock(model_channels*2, num_heads, relative_pos_embeddings=True, do_checkpoint=False),
|
||||||
|
@ -177,11 +182,19 @@ class DiffusionTtsFlat(nn.Module):
|
||||||
zero_module(conv_nd(1, model_channels, out_channels, 3, padding=1)),
|
zero_module(conv_nd(1, model_channels, out_channels, 3, padding=1)),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if freeze_everything_except_autoregressive_inputs:
|
||||||
|
for ap in list(self.autoregressive_latent_converter.parameters()):
|
||||||
|
ap.ALLOWED_IN_FLAT = True
|
||||||
|
for p in self.parameters():
|
||||||
|
if not hasattr(p, 'ALLOWED_IN_FLAT'):
|
||||||
|
p.requires_grad = False
|
||||||
|
p.DO_NOT_TRAIN = True
|
||||||
|
|
||||||
def get_grad_norm_parameter_groups(self):
|
def get_grad_norm_parameter_groups(self):
|
||||||
groups = {
|
groups = {
|
||||||
'minicoder': list(self.contextual_embedder.parameters()),
|
'minicoder': list(self.contextual_embedder.parameters()),
|
||||||
'layers': list(self.layers.parameters()),
|
'layers': list(self.layers.parameters()),
|
||||||
'code_converters': list(self.code_embedding.parameters()) + list(self.code_converter.parameters()) + list(self.latent_converter.parameters()) + list(self.latent_converter.parameters()),
|
'code_converters': list(self.code_embedding.parameters()) + list(self.code_converter.parameters()) + list(self.autoregressive_latent_converter.parameters()) + list(self.autoregressive_latent_converter.parameters()),
|
||||||
'timestep_integrator': list(self.conditioning_timestep_integrator.parameters()) + list(self.integrating_conv.parameters()),
|
'timestep_integrator': list(self.conditioning_timestep_integrator.parameters()) + list(self.integrating_conv.parameters()),
|
||||||
'time_embed': list(self.time_embed.parameters()),
|
'time_embed': list(self.time_embed.parameters()),
|
||||||
}
|
}
|
||||||
|
@ -202,7 +215,7 @@ class DiffusionTtsFlat(nn.Module):
|
||||||
cond_emb = conds.mean(dim=-1)
|
cond_emb = conds.mean(dim=-1)
|
||||||
cond_scale, cond_shift = torch.chunk(cond_emb, 2, dim=1)
|
cond_scale, cond_shift = torch.chunk(cond_emb, 2, dim=1)
|
||||||
if is_latent(aligned_conditioning):
|
if is_latent(aligned_conditioning):
|
||||||
code_emb = self.latent_converter(aligned_conditioning)
|
code_emb = self.autoregressive_latent_converter(aligned_conditioning)
|
||||||
else:
|
else:
|
||||||
code_emb = self.code_embedding(aligned_conditioning).permute(0, 2, 1)
|
code_emb = self.code_embedding(aligned_conditioning).permute(0, 2, 1)
|
||||||
code_emb = self.code_converter(code_emb)
|
code_emb = self.code_converter(code_emb)
|
||||||
|
@ -245,7 +258,7 @@ class DiffusionTtsFlat(nn.Module):
|
||||||
if conditioning_free:
|
if conditioning_free:
|
||||||
code_emb = self.unconditioned_embedding.repeat(x.shape[0], 1, x.shape[-1])
|
code_emb = self.unconditioned_embedding.repeat(x.shape[0], 1, x.shape[-1])
|
||||||
unused_params.extend(list(self.code_converter.parameters()) + list(self.code_embedding.parameters()))
|
unused_params.extend(list(self.code_converter.parameters()) + list(self.code_embedding.parameters()))
|
||||||
unused_params.extend(list(self.latent_converter.parameters()))
|
unused_params.extend(list(self.autoregressive_latent_converter.parameters()))
|
||||||
else:
|
else:
|
||||||
if precomputed_aligned_embeddings is not None:
|
if precomputed_aligned_embeddings is not None:
|
||||||
code_emb = precomputed_aligned_embeddings
|
code_emb = precomputed_aligned_embeddings
|
||||||
|
@ -256,7 +269,7 @@ class DiffusionTtsFlat(nn.Module):
|
||||||
if is_latent(aligned_conditioning):
|
if is_latent(aligned_conditioning):
|
||||||
unused_params.extend(list(self.code_converter.parameters()) + list(self.code_embedding.parameters()))
|
unused_params.extend(list(self.code_converter.parameters()) + list(self.code_embedding.parameters()))
|
||||||
else:
|
else:
|
||||||
unused_params.extend(list(self.latent_converter.parameters()))
|
unused_params.extend(list(self.autoregressive_latent_converter.parameters()))
|
||||||
|
|
||||||
time_emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
|
time_emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
|
||||||
code_emb = self.conditioning_timestep_integrator(code_emb, time_emb)
|
code_emb = self.conditioning_timestep_integrator(code_emb, time_emb)
|
||||||
|
@ -297,7 +310,7 @@ if __name__ == '__main__':
|
||||||
aligned_sequence = torch.randint(0,8192,(2,100))
|
aligned_sequence = torch.randint(0,8192,(2,100))
|
||||||
cond = torch.randn(2, 100, 400)
|
cond = torch.randn(2, 100, 400)
|
||||||
ts = torch.LongTensor([600, 600])
|
ts = torch.LongTensor([600, 600])
|
||||||
model = DiffusionTtsFlat(512, layer_drop=.3, unconditioned_percentage=.5)
|
model = DiffusionTtsFlat(512, layer_drop=.3, unconditioned_percentage=.5, freeze_everything_except_autoregressive_inputs=True)
|
||||||
# Test with latent aligned conditioning
|
# Test with latent aligned conditioning
|
||||||
#o = model(clip, ts, aligned_latent, cond)
|
#o = model(clip, ts, aligned_latent, cond)
|
||||||
# Test with sequence aligned conditioning
|
# Test with sequence aligned conditioning
|
||||||
|
|
Loading…
Reference in New Issue
Block a user