Flat fixes

James Betker 2022-03-21 14:43:52 -06:00
parent 26dcf7f1a2
commit 1ad18d29a8
3 changed files with 10 additions and 18 deletions

View File: models/audio/tts/diffusion_encoder.py

@@ -8,7 +8,7 @@ import torch.nn as nn
 from x_transformers.x_transformers import groupby_prefix_and_trim, FixedPositionalEmbedding, default, RotaryEmbedding, \
     DEFAULT_DIM_HEAD, RelativePositionBias, LearnedAlibiPositionalBias, AlibiPositionalBias, ScaleNorm, RMSNorm, Rezero, \
     exists, Attention, FeedForward, Scale, ShiftTokens, GRUGating, Residual, cast_tuple, equals, LayerIntermediates, \
-    AttentionLayers
+    AttentionLayers, not_equals


 class TimeIntegrationBlock(nn.Module):
@@ -138,19 +138,16 @@ class TimestepEmbeddingAttentionLayers(AttentionLayers):
             default_block = ('f',) + default_block

         # qk normalization
         if use_qk_norm_attn:
             attn_scale_init_value = -math.log(math.log2(qk_norm_attn_seq_len ** 2 - qk_norm_attn_seq_len)) if exists(qk_norm_attn_seq_len) else None
             attn_kwargs = {**attn_kwargs, 'qk_norm': True, 'scale_init_value': attn_scale_init_value}

         # zero init
         if zero_init_branch_output:
             attn_kwargs = {**attn_kwargs, 'zero_init_output': True}
             ff_kwargs = {**ff_kwargs, 'zero_init_output': True}

         # calculate layer block order
         if exists(custom_layers):
             layer_types = custom_layers
         elif exists(par_ratio):
@@ -174,7 +171,6 @@ class TimestepEmbeddingAttentionLayers(AttentionLayers):
         self.num_attn_layers = len(list(filter(equals('a'), layer_types)))

         # calculate token shifting
         shift_tokens = cast_tuple(shift_tokens, len(layer_types))

         # iterate and construct layers
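Aside (not part of the commit): the attn_scale_init_value expression in the hunk above is the QK-normalization temperature heuristic from x_transformers, which sizes the initial learned attention scale from the training sequence length. A minimal sketch of the same arithmetic, assuming a known seq_len:

    import math

    def qk_norm_scale_init(seq_len: int) -> float:
        # With QK normalization, q.k is bounded, so attention needs a learned
        # temperature; this heuristic derives its initial value from the number
        # of ordered (query, key) pairs, seq_len**2 - seq_len.
        return -math.log(math.log2(seq_len ** 2 - seq_len))

    print(qk_norm_scale_init(1024))  # ~ -3.0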

View File: models/audio/tts/unet_diffusion_tts_flat.py

@@ -1,20 +1,13 @@
-import random
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from torch import autocast
 from x_transformers import Encoder

 from models.audio.tts.diffusion_encoder import TimestepEmbeddingAttentionLayers
-from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear
-from models.diffusion.unet_diffusion import AttentionBlock, TimestepEmbedSequential, \
-    Downsample, Upsample, TimestepBlock
 from models.audio.tts.mini_encoder import AudioMiniEncoder
 from models.audio.tts.unet_diffusion_tts7 import CheckpointedXTransformerEncoder
-from scripts.audio.gen.use_diffuse_tts import ceil_multiple
+from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear
 from trainer.networks import register_model
-from utils.util import checkpoint


 def is_latent(t):
@@ -139,8 +132,10 @@ class DiffusionTtsFlat(nn.Module):
                 ff_glu=True,
                 rotary_emb_dim=True,
                 layerdrop_percent=layer_drop,
+                zero_init_branch_output=True,
             )
         )
+        self.layers.transformer.norm = nn.Identity()  # We don't want the final norm for the main encoder.

         self.out = nn.Sequential(
             normalization(model_channels),
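Aside (not part of the commit): zero_init_branch_output makes each attention/feed-forward branch start as a no-op, so every residual block is initially the identity; this mirrors what the zero_module helper imported above does on the U-Net side, and tends to stabilize early diffusion training. A minimal sketch of the idea (the Linear layer here is a hypothetical branch output projection):

    import torch.nn as nn

    def zero_module(module: nn.Module) -> nn.Module:
        # Zero every parameter so the branch contributes nothing at step 0;
        # the residual path carries the signal until the branch learns.
        for p in module.parameters():
            nn.init.zeros_(p)
        return module

    proj = zero_module(nn.Linear(512, 512))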
@@ -151,11 +146,12 @@ class DiffusionTtsFlat(nn.Module):
     def get_grad_norm_parameter_groups(self):
         groups = {
             'minicoder': list(self.contextual_embedder.parameters()),
-            'layers': list(self.layers),
+            'conditioning_timestep_integrator': list(self.conditioning_timestep_integrator.parameters()),
+            'layers': list(self.layers.parameters()),
         }
         return groups

-    def forward(self, x, timesteps, aligned_conditioning, conditioning_input, lr_input=None, conditioning_free=False):
+    def forward(self, x, timesteps, aligned_conditioning, conditioning_input, conditioning_free=False):
         """
         Apply the model to an input batch.
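Aside (not part of the commit): the 'layers' fix above matters because iterating an nn.Module yields its child modules, not tensors, while grad-norm reporting needs parameter tensors, hence .parameters(). A hedged sketch of how such a group is typically consumed (the helper name is hypothetical):

    import torch

    def group_grad_norm(params) -> float:
        # L2 norm across all gradients in one named parameter group;
        # params must be tensors, not submodules.
        total = 0.0
        for p in params:
            if p.grad is not None:
                total += p.grad.detach().float().norm(2).item() ** 2
        return total ** 0.5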

View File: trainer/eval/audio_diffusion_fid.py

@@ -267,10 +267,10 @@ if __name__ == '__main__':
     diffusion = load_model_from_config('X:\\dlas\\experiments\\train_diffusion_tts9_mel.yml', 'generator',
                                        also_load_savepoint=False,
-                                       load_path='X:\\dlas\\experiments\\train_diffusion_tts9_mel\\models\\5000_generator_ema.pth').cuda()
+                                       load_path='X:\\dlas\\experiments\\train_diffusion_tts9_mel\\models\\47500_generator_ema.pth').cuda()
     opt_eval = {'eval_tsv': 'Y:\\libritts\\test-clean\\transcribed-brief-w2v.tsv', 'diffusion_steps': 100,
-                'conditioning_free': False, 'conditioning_free_k': 1,
+                'conditioning_free': True, 'conditioning_free_k': 1,
                 'diffusion_schedule': 'linear', 'diffusion_type': 'tts9_mel'}
-    env = {'rank': 0, 'base_path': 'D:\\tmp\\test_eval', 'step': 555, 'device': 'cuda', 'opt': {}}
+    env = {'rank': 0, 'base_path': 'D:\\tmp\\test_eval', 'step': 557, 'device': 'cuda', 'opt': {}}
     eval = AudioDiffusionFid(diffusion, opt_eval, env)
     print(eval.perform_eval())
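Aside (not part of the commit): flipping 'conditioning_free' to True enables conditioning-free guidance at eval time, with 'conditioning_free_k' as the guidance weight. One common formulation of the combination step, an assumption here rather than this repo's verified sampler code:

    import torch

    def guided_output(cond: torch.Tensor, uncond: torch.Tensor, k: float) -> torch.Tensor:
        # Blend conditioned and unconditioned model outputs, pushing the result
        # away from the unconditioned prediction by weight k.
        # k = 0 reduces to the conditioned output alone.
        return (1.0 + k) * cond - k * uncond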