forked from mrq/DL-Art-School
Flat fixes
This commit is contained in:
parent
26dcf7f1a2
commit
1ad18d29a8
|
@ -8,7 +8,7 @@ import torch.nn as nn
|
|||
from x_transformers.x_transformers import groupby_prefix_and_trim, FixedPositionalEmbedding, default, RotaryEmbedding, \
|
||||
DEFAULT_DIM_HEAD, RelativePositionBias, LearnedAlibiPositionalBias, AlibiPositionalBias, ScaleNorm, RMSNorm, Rezero, \
|
||||
exists, Attention, FeedForward, Scale, ShiftTokens, GRUGating, Residual, cast_tuple, equals, LayerIntermediates, \
|
||||
AttentionLayers
|
||||
AttentionLayers, not_equals
|
||||
|
||||
|
||||
class TimeIntegrationBlock(nn.Module):
|
||||
|
@ -138,19 +138,16 @@ class TimestepEmbeddingAttentionLayers(AttentionLayers):
|
|||
default_block = ('f',) + default_block
|
||||
|
||||
# qk normalization
|
||||
|
||||
if use_qk_norm_attn:
|
||||
attn_scale_init_value = -math.log(math.log2(qk_norm_attn_seq_len ** 2 - qk_norm_attn_seq_len)) if exists(qk_norm_attn_seq_len) else None
|
||||
attn_kwargs = {**attn_kwargs, 'qk_norm': True, 'scale_init_value': attn_scale_init_value}
|
||||
|
||||
# zero init
|
||||
|
||||
if zero_init_branch_output:
|
||||
attn_kwargs = {**attn_kwargs, 'zero_init_output': True}
|
||||
ff_kwargs = {**ff_kwargs, 'zero_init_output': True}
|
||||
|
||||
# calculate layer block order
|
||||
|
||||
if exists(custom_layers):
|
||||
layer_types = custom_layers
|
||||
elif exists(par_ratio):
|
||||
|
@ -174,7 +171,6 @@ class TimestepEmbeddingAttentionLayers(AttentionLayers):
|
|||
self.num_attn_layers = len(list(filter(equals('a'), layer_types)))
|
||||
|
||||
# calculate token shifting
|
||||
|
||||
shift_tokens = cast_tuple(shift_tokens, len(layer_types))
|
||||
|
||||
# iterate and construct layers
|
||||
|
|
|
@ -1,20 +1,13 @@
|
|||
import random
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch import autocast
|
||||
from x_transformers import Encoder
|
||||
|
||||
from models.audio.tts.diffusion_encoder import TimestepEmbeddingAttentionLayers
|
||||
from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear
|
||||
from models.diffusion.unet_diffusion import AttentionBlock, TimestepEmbedSequential, \
|
||||
Downsample, Upsample, TimestepBlock
|
||||
from models.audio.tts.mini_encoder import AudioMiniEncoder
|
||||
from models.audio.tts.unet_diffusion_tts7 import CheckpointedXTransformerEncoder
|
||||
from scripts.audio.gen.use_diffuse_tts import ceil_multiple
|
||||
from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear
|
||||
from trainer.networks import register_model
|
||||
from utils.util import checkpoint
|
||||
|
||||
|
||||
def is_latent(t):
|
||||
|
@ -139,8 +132,10 @@ class DiffusionTtsFlat(nn.Module):
|
|||
ff_glu=True,
|
||||
rotary_emb_dim=True,
|
||||
layerdrop_percent=layer_drop,
|
||||
zero_init_branch_output=True,
|
||||
)
|
||||
)
|
||||
self.layers.transformer.norm = nn.Identity() # We don't want the final norm for the main encoder.
|
||||
|
||||
self.out = nn.Sequential(
|
||||
normalization(model_channels),
|
||||
|
@ -151,11 +146,12 @@ class DiffusionTtsFlat(nn.Module):
|
|||
def get_grad_norm_parameter_groups(self):
|
||||
groups = {
|
||||
'minicoder': list(self.contextual_embedder.parameters()),
|
||||
'layers': list(self.layers),
|
||||
'conditioning_timestep_integrator': list(self.conditioning_timestep_integrator.parameters()),
|
||||
'layers': list(self.layers.parameters()),
|
||||
}
|
||||
return groups
|
||||
|
||||
def forward(self, x, timesteps, aligned_conditioning, conditioning_input, lr_input=None, conditioning_free=False):
|
||||
def forward(self, x, timesteps, aligned_conditioning, conditioning_input, conditioning_free=False):
|
||||
"""
|
||||
Apply the model to an input batch.
|
||||
|
||||
|
|
|
@ -267,10 +267,10 @@ if __name__ == '__main__':
|
|||
|
||||
diffusion = load_model_from_config('X:\\dlas\\experiments\\train_diffusion_tts9_mel.yml', 'generator',
|
||||
also_load_savepoint=False,
|
||||
load_path='X:\\dlas\\experiments\\train_diffusion_tts9_mel\\models\\5000_generator_ema.pth').cuda()
|
||||
load_path='X:\\dlas\\experiments\\train_diffusion_tts9_mel\\models\\47500_generator_ema.pth').cuda()
|
||||
opt_eval = {'eval_tsv': 'Y:\\libritts\\test-clean\\transcribed-brief-w2v.tsv', 'diffusion_steps': 100,
|
||||
'conditioning_free': False, 'conditioning_free_k': 1,
|
||||
'conditioning_free': True, 'conditioning_free_k': 1,
|
||||
'diffusion_schedule': 'linear', 'diffusion_type': 'tts9_mel'}
|
||||
env = {'rank': 0, 'base_path': 'D:\\tmp\\test_eval', 'step': 555, 'device': 'cuda', 'opt': {}}
|
||||
env = {'rank': 0, 'base_path': 'D:\\tmp\\test_eval', 'step': 557, 'device': 'cuda', 'opt': {}}
|
||||
eval = AudioDiffusionFid(diffusion, opt_eval, env)
|
||||
print(eval.perform_eval())
|
||||
|
|
Loading…
Reference in New Issue
Block a user