diff --git a/codes/models/audio/tts/diffusion_encoder.py b/codes/models/audio/tts/diffusion_encoder.py
index b408e6fc..629bfaed 100644
--- a/codes/models/audio/tts/diffusion_encoder.py
+++ b/codes/models/audio/tts/diffusion_encoder.py
@@ -8,7 +8,7 @@ import torch.nn as nn
 from x_transformers.x_transformers import groupby_prefix_and_trim, FixedPositionalEmbedding, default, RotaryEmbedding, \
     DEFAULT_DIM_HEAD, RelativePositionBias, LearnedAlibiPositionalBias, AlibiPositionalBias, ScaleNorm, RMSNorm, Rezero, \
     exists, Attention, FeedForward, Scale, ShiftTokens, GRUGating, Residual, cast_tuple, equals, LayerIntermediates, \
-    AttentionLayers
+    AttentionLayers, not_equals
 
 
 class TimeIntegrationBlock(nn.Module):
@@ -138,19 +138,16 @@ class TimestepEmbeddingAttentionLayers(AttentionLayers):
             default_block = ('f',) + default_block
 
         # qk normalization
-
         if use_qk_norm_attn:
             attn_scale_init_value = -math.log(math.log2(qk_norm_attn_seq_len ** 2 - qk_norm_attn_seq_len)) if exists(qk_norm_attn_seq_len) else None
             attn_kwargs = {**attn_kwargs, 'qk_norm': True, 'scale_init_value': attn_scale_init_value}
 
         # zero init
-
         if zero_init_branch_output:
             attn_kwargs = {**attn_kwargs, 'zero_init_output': True}
             ff_kwargs = {**ff_kwargs, 'zero_init_output': True}
 
         # calculate layer block order
-
         if exists(custom_layers):
             layer_types = custom_layers
         elif exists(par_ratio):
@@ -174,7 +171,6 @@ class TimestepEmbeddingAttentionLayers(AttentionLayers):
         self.num_attn_layers = len(list(filter(equals('a'), layer_types)))
 
         # calculate token shifting
-
         shift_tokens = cast_tuple(shift_tokens, len(layer_types))
 
         # iterate and construct layers
diff --git a/codes/models/audio/tts/unet_diffusion_tts_flat.py b/codes/models/audio/tts/unet_diffusion_tts_flat.py
index b7c91031..21a208b5 100644
--- a/codes/models/audio/tts/unet_diffusion_tts_flat.py
+++ b/codes/models/audio/tts/unet_diffusion_tts_flat.py
@@ -1,20 +1,13 @@
-import random
-
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from torch import autocast
 from x_transformers import Encoder
 
 from models.audio.tts.diffusion_encoder import TimestepEmbeddingAttentionLayers
-from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear
-from models.diffusion.unet_diffusion import AttentionBlock, TimestepEmbedSequential, \
-    Downsample, Upsample, TimestepBlock
 from models.audio.tts.mini_encoder import AudioMiniEncoder
 from models.audio.tts.unet_diffusion_tts7 import CheckpointedXTransformerEncoder
-from scripts.audio.gen.use_diffuse_tts import ceil_multiple
+from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear
 from trainer.networks import register_model
-from utils.util import checkpoint
 
 
 def is_latent(t):
@@ -139,8 +132,10 @@ class DiffusionTtsFlat(nn.Module):
                 ff_glu=True,
                 rotary_emb_dim=True,
                 layerdrop_percent=layer_drop,
+                zero_init_branch_output=True,
             )
         )
+        self.layers.transformer.norm = nn.Identity()  # We don't want the final norm for the main encoder.
 
         self.out = nn.Sequential(
             normalization(model_channels),
@@ -151,11 +146,12 @@ class DiffusionTtsFlat(nn.Module):
     def get_grad_norm_parameter_groups(self):
         groups = {
             'minicoder': list(self.contextual_embedder.parameters()),
-            'layers': list(self.layers),
+            'conditioning_timestep_integrator': list(self.conditioning_timestep_integrator.parameters()),
+            'layers': list(self.layers.parameters()),
         }
         return groups
 
-    def forward(self, x, timesteps, aligned_conditioning, conditioning_input, lr_input=None, conditioning_free=False):
+    def forward(self, x, timesteps, aligned_conditioning, conditioning_input, conditioning_free=False):
         """
         Apply the model to an input batch.
 
diff --git a/codes/trainer/eval/audio_diffusion_fid.py b/codes/trainer/eval/audio_diffusion_fid.py
index 3f27e2f5..c91a93da 100644
--- a/codes/trainer/eval/audio_diffusion_fid.py
+++ b/codes/trainer/eval/audio_diffusion_fid.py
@@ -267,10 +267,10 @@ if __name__ == '__main__':
     diffusion = load_model_from_config('X:\\dlas\\experiments\\train_diffusion_tts9_mel.yml', 'generator',
                                        also_load_savepoint=False,
-                                       load_path='X:\\dlas\\experiments\\train_diffusion_tts9_mel\\models\\5000_generator_ema.pth').cuda()
+                                       load_path='X:\\dlas\\experiments\\train_diffusion_tts9_mel\\models\\47500_generator_ema.pth').cuda()
     opt_eval = {'eval_tsv': 'Y:\\libritts\\test-clean\\transcribed-brief-w2v.tsv', 'diffusion_steps': 100,
-                'conditioning_free': False, 'conditioning_free_k': 1,
+                'conditioning_free': True, 'conditioning_free_k': 1,
                 'diffusion_schedule': 'linear', 'diffusion_type': 'tts9_mel'}
-    env = {'rank': 0, 'base_path': 'D:\\tmp\\test_eval', 'step': 555, 'device': 'cuda', 'opt': {}}
+    env = {'rank': 0, 'base_path': 'D:\\tmp\\test_eval', 'step': 557, 'device': 'cuda', 'opt': {}}
     eval = AudioDiffusionFid(diffusion, opt_eval, env)
     print(eval.perform_eval())
 