forked from mrq/DL-Art-School
Flat fixes
This commit is contained in:
parent
26dcf7f1a2
commit
1ad18d29a8
@@ -8,7 +8,7 @@ import torch.nn as nn
 from x_transformers.x_transformers import groupby_prefix_and_trim, FixedPositionalEmbedding, default, RotaryEmbedding, \
     DEFAULT_DIM_HEAD, RelativePositionBias, LearnedAlibiPositionalBias, AlibiPositionalBias, ScaleNorm, RMSNorm, Rezero, \
     exists, Attention, FeedForward, Scale, ShiftTokens, GRUGating, Residual, cast_tuple, equals, LayerIntermediates, \
-    AttentionLayers
+    AttentionLayers, not_equals


 class TimeIntegrationBlock(nn.Module):
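Note: the only functional change above is the extra `not_equals` import. In x_transformers, `equals` and `not_equals` are tiny curried predicates used to filter the layer-type tuple ('a' = self-attention, 'c' = cross-attention, 'f' = feedforward); a paraphrased sketch, not the library's verbatim source:

    def equals(val):
        return lambda x: x == val

    def not_equals(val):
        return lambda x: x != val

    layer_types = ('a', 'f', 'a', 'f')
    num_attn_layers = len(list(filter(equals('a'), layer_types)))       # 2
    num_other_layers = len(list(filter(not_equals('a'), layer_types)))  # 2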
@@ -138,19 +138,16 @@ class TimestepEmbeddingAttentionLayers(AttentionLayers):
         default_block = ('f',) + default_block

         # qk normalization
-
         if use_qk_norm_attn:
             attn_scale_init_value = -math.log(math.log2(qk_norm_attn_seq_len ** 2 - qk_norm_attn_seq_len)) if exists(qk_norm_attn_seq_len) else None
             attn_kwargs = {**attn_kwargs, 'qk_norm': True, 'scale_init_value': attn_scale_init_value}

         # zero init
-
         if zero_init_branch_output:
             attn_kwargs = {**attn_kwargs, 'zero_init_output': True}
             ff_kwargs = {**ff_kwargs, 'zero_init_output': True}

         # calculate layer block order
-
         if exists(custom_layers):
             layer_types = custom_layers
         elif exists(par_ratio):
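Context on the untouched qk-norm line: the query-key-normalization trick initializes the learned attention temperature to log2(L**2 - L) for sequence length L; the extra negative natural log suggests the scale parameter is stored in log space. A quick check of the arithmetic (L = 1024 is a hypothetical value):

    import math

    L = 1024                             # hypothetical qk_norm_attn_seq_len
    temperature = math.log2(L ** 2 - L)  # ~19.999, per the QK-norm heuristic
    init = -math.log(temperature)        # ~-3.0, passed as scale_init_value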
@@ -174,7 +171,6 @@ class TimestepEmbeddingAttentionLayers(AttentionLayers):
         self.num_attn_layers = len(list(filter(equals('a'), layer_types)))

         # calculate token shifting
-
         shift_tokens = cast_tuple(shift_tokens, len(layer_types))

         # iterate and construct layers
@@ -1,20 +1,13 @@
-import random
-
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from torch import autocast
 from x_transformers import Encoder

 from models.audio.tts.diffusion_encoder import TimestepEmbeddingAttentionLayers
-from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear
-from models.diffusion.unet_diffusion import AttentionBlock, TimestepEmbedSequential, \
-    Downsample, Upsample, TimestepBlock
 from models.audio.tts.mini_encoder import AudioMiniEncoder
 from models.audio.tts.unet_diffusion_tts7 import CheckpointedXTransformerEncoder
-from scripts.audio.gen.use_diffuse_tts import ceil_multiple
+from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear
 from trainer.networks import register_model
-from utils.util import checkpoint


 def is_latent(t):
@@ -139,8 +132,10 @@ class DiffusionTtsFlat(nn.Module):
                 ff_glu=True,
                 rotary_emb_dim=True,
                 layerdrop_percent=layer_drop,
+                zero_init_branch_output=True,
             )
         )
+        self.layers.transformer.norm = nn.Identity()  # We don't want the final norm for the main encoder.

         self.out = nn.Sequential(
             normalization(model_channels),
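`zero_init_branch_output=True` makes every attention/feedforward branch start with a zero-initialized output projection, so each residual block begins as an identity function and the branch fades in during training — the same idea as the `zero_module` helper this file imports. A minimal sketch of the pattern (illustrative, not this repo's exact code):

    import torch.nn as nn

    def zero_module(module: nn.Module) -> nn.Module:
        # Zero all parameters so the module contributes nothing at init.
        for p in module.parameters():
            p.detach().zero_()
        return module

    class ResidualBranch(nn.Module):
        def __init__(self, dim: int):
            super().__init__()
            self.body = nn.Sequential(
                nn.Linear(dim, dim * 4),
                nn.GELU(),
                zero_module(nn.Linear(dim * 4, dim)),  # output starts at zero
            )

        def forward(self, x):
            return x + self.body(x)  # exact identity at initialization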
@@ -151,11 +146,12 @@
     def get_grad_norm_parameter_groups(self):
         groups = {
             'minicoder': list(self.contextual_embedder.parameters()),
-            'layers': list(self.layers),
+            'conditioning_timestep_integrator': list(self.conditioning_timestep_integrator.parameters()),
+            'layers': list(self.layers.parameters()),
         }
         return groups

-    def forward(self, x, timesteps, aligned_conditioning, conditioning_input, lr_input=None, conditioning_free=False):
+    def forward(self, x, timesteps, aligned_conditioning, conditioning_input, conditioning_free=False):
         """
         Apply the model to an input batch.

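The `'layers'` group previously held the module objects themselves (`list(self.layers)` iterates submodules, not tensors); the fix returns parameters, and the conditioning timestep integrator gets its own group. A hedged sketch of how a trainer might consume these groups for per-group gradient-norm logging (`log_grad_norms` is a hypothetical helper, not part of this repo):

    import torch

    def log_grad_norms(model):
        # One L2 gradient norm per named parameter group, for logging.
        norms = {}
        for name, params in model.get_grad_norm_parameter_groups().items():
            grads = [p.grad.detach().flatten() for p in params if p.grad is not None]
            norms[name] = torch.cat(grads).norm(2).item() if grads else 0.0
        return norms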
@@ -267,10 +267,10 @@ if __name__ == '__main__':

     diffusion = load_model_from_config('X:\\dlas\\experiments\\train_diffusion_tts9_mel.yml', 'generator',
                                        also_load_savepoint=False,
-                                       load_path='X:\\dlas\\experiments\\train_diffusion_tts9_mel\\models\\5000_generator_ema.pth').cuda()
+                                       load_path='X:\\dlas\\experiments\\train_diffusion_tts9_mel\\models\\47500_generator_ema.pth').cuda()
     opt_eval = {'eval_tsv': 'Y:\\libritts\\test-clean\\transcribed-brief-w2v.tsv', 'diffusion_steps': 100,
-                'conditioning_free': False, 'conditioning_free_k': 1,
+                'conditioning_free': True, 'conditioning_free_k': 1,
                 'diffusion_schedule': 'linear', 'diffusion_type': 'tts9_mel'}
-    env = {'rank': 0, 'base_path': 'D:\\tmp\\test_eval', 'step': 555, 'device': 'cuda', 'opt': {}}
+    env = {'rank': 0, 'base_path': 'D:\\tmp\\test_eval', 'step': 557, 'device': 'cuda', 'opt': {}}
     eval = AudioDiffusionFid(diffusion, opt_eval, env)
     print(eval.perform_eval())
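The eval config now exercises the conditioning-free path. In the usual classifier-free-guidance formulation, the model is run twice and the two predictions are blended with a guidance strength k; a generic sketch (the exact weighting applied by `conditioning_free_k` in this repo may differ):

    def guided_prediction(eps_cond, eps_uncond, k=1.0):
        # k = 0 -> purely unconditional; k = 1 -> purely conditional;
        # k > 1 -> extrapolates away from the unconditional prediction.
        return eps_uncond + k * (eps_cond - eps_uncond)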