Flat fixes

This commit is contained in:
James Betker 2022-03-21 14:43:52 -06:00
parent 26dcf7f1a2
commit 1ad18d29a8
3 changed files with 10 additions and 18 deletions

View File

@ -8,7 +8,7 @@ import torch.nn as nn
from x_transformers.x_transformers import groupby_prefix_and_trim, FixedPositionalEmbedding, default, RotaryEmbedding, \
DEFAULT_DIM_HEAD, RelativePositionBias, LearnedAlibiPositionalBias, AlibiPositionalBias, ScaleNorm, RMSNorm, Rezero, \
exists, Attention, FeedForward, Scale, ShiftTokens, GRUGating, Residual, cast_tuple, equals, LayerIntermediates, \
AttentionLayers
AttentionLayers, not_equals
class TimeIntegrationBlock(nn.Module):
@ -138,19 +138,16 @@ class TimestepEmbeddingAttentionLayers(AttentionLayers):
default_block = ('f',) + default_block
# qk normalization
if use_qk_norm_attn:
attn_scale_init_value = -math.log(math.log2(qk_norm_attn_seq_len ** 2 - qk_norm_attn_seq_len)) if exists(qk_norm_attn_seq_len) else None
attn_kwargs = {**attn_kwargs, 'qk_norm': True, 'scale_init_value': attn_scale_init_value}
# zero init
if zero_init_branch_output:
attn_kwargs = {**attn_kwargs, 'zero_init_output': True}
ff_kwargs = {**ff_kwargs, 'zero_init_output': True}
# calculate layer block order
if exists(custom_layers):
layer_types = custom_layers
elif exists(par_ratio):
@ -174,7 +171,6 @@ class TimestepEmbeddingAttentionLayers(AttentionLayers):
self.num_attn_layers = len(list(filter(equals('a'), layer_types)))
# calculate token shifting
shift_tokens = cast_tuple(shift_tokens, len(layer_types))
# iterate and construct layers

View File

@ -1,20 +1,13 @@
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import autocast
from x_transformers import Encoder
from models.audio.tts.diffusion_encoder import TimestepEmbeddingAttentionLayers
from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear
from models.diffusion.unet_diffusion import AttentionBlock, TimestepEmbedSequential, \
Downsample, Upsample, TimestepBlock
from models.audio.tts.mini_encoder import AudioMiniEncoder
from models.audio.tts.unet_diffusion_tts7 import CheckpointedXTransformerEncoder
from scripts.audio.gen.use_diffuse_tts import ceil_multiple
from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear
from trainer.networks import register_model
from utils.util import checkpoint
def is_latent(t):
@ -139,8 +132,10 @@ class DiffusionTtsFlat(nn.Module):
ff_glu=True,
rotary_emb_dim=True,
layerdrop_percent=layer_drop,
zero_init_branch_output=True,
)
)
self.layers.transformer.norm = nn.Identity() # We don't want the final norm for the main encoder.
self.out = nn.Sequential(
normalization(model_channels),
@ -151,11 +146,12 @@ class DiffusionTtsFlat(nn.Module):
def get_grad_norm_parameter_groups(self):
groups = {
'minicoder': list(self.contextual_embedder.parameters()),
'layers': list(self.layers),
'conditioning_timestep_integrator': list(self.conditioning_timestep_integrator.parameters()),
'layers': list(self.layers.parameters()),
}
return groups
def forward(self, x, timesteps, aligned_conditioning, conditioning_input, lr_input=None, conditioning_free=False):
def forward(self, x, timesteps, aligned_conditioning, conditioning_input, conditioning_free=False):
"""
Apply the model to an input batch.

View File

@ -267,10 +267,10 @@ if __name__ == '__main__':
diffusion = load_model_from_config('X:\\dlas\\experiments\\train_diffusion_tts9_mel.yml', 'generator',
also_load_savepoint=False,
load_path='X:\\dlas\\experiments\\train_diffusion_tts9_mel\\models\\5000_generator_ema.pth').cuda()
load_path='X:\\dlas\\experiments\\train_diffusion_tts9_mel\\models\\47500_generator_ema.pth').cuda()
opt_eval = {'eval_tsv': 'Y:\\libritts\\test-clean\\transcribed-brief-w2v.tsv', 'diffusion_steps': 100,
'conditioning_free': False, 'conditioning_free_k': 1,
'conditioning_free': True, 'conditioning_free_k': 1,
'diffusion_schedule': 'linear', 'diffusion_type': 'tts9_mel'}
env = {'rank': 0, 'base_path': 'D:\\tmp\\test_eval', 'step': 555, 'device': 'cuda', 'opt': {}}
env = {'rank': 0, 'base_path': 'D:\\tmp\\test_eval', 'step': 557, 'device': 'cuda', 'opt': {}}
eval = AudioDiffusionFid(diffusion, opt_eval, env)
print(eval.perform_eval())