diff --git a/codes/models/audio/music/transformer_diffusion12.py b/codes/models/audio/music/transformer_diffusion12.py
index 9ce7dcbd..056e6426 100644
--- a/codes/models/audio/music/transformer_diffusion12.py
+++ b/codes/models/audio/music/transformer_diffusion12.py
@@ -609,7 +609,6 @@ def test_vqvae_model():
     o = model(clip, ts, cond)
     pg = model.get_grad_norm_parameter_groups()
 
-    """
     with torch.no_grad():
         proj = torch.randn(2, 100, 512).cuda()
         clip = clip.cuda()
@@ -618,10 +617,9 @@ def test_vqvae_model():
         model = model.cuda().eval()
         model.diff.enable_fp16 = True
         ti = model.diff.timestep_independent(proj, clip.shape[2])
-        for k in range(100):
+        for k in range(1000):
             model.diff(clip, ts, precomputed_code_embeddings=ti)
         print(f"Elapsed: {time()-start}")
-    """
 
 
 def test_multi_vqvae_model():
@@ -690,4 +688,5 @@ def extract_diff(in_f, out_f, remove_head=False):
 if __name__ == '__main__':
     #extract_diff('X:\\dlas\\experiments\\train_music_diffusion_tfd12\\models\\41000_generator_ema.pth', 'extracted_diff.pth', True)
     #test_cheater_model()
-    extract_diff('X:\\dlas\experiments\\train_music_diffusion_tfd_cheater_from_scratch\\models\\56500_generator_ema.pth', 'extracted.pth', remove_head=True)
+    test_vqvae_model()
+    #extract_diff('X:\\dlas\experiments\\train_music_diffusion_tfd_cheater_from_scratch\\models\\56500_generator_ema.pth', 'extracted.pth', remove_head=True)
diff --git a/codes/models/audio/music/transformer_diffusion13.py b/codes/models/audio/music/transformer_diffusion13.py
index f2fc1080..e641b3e1 100644
--- a/codes/models/audio/music/transformer_diffusion13.py
+++ b/codes/models/audio/music/transformer_diffusion13.py
@@ -19,8 +19,8 @@ class SubBlock(nn.Module):
         super().__init__()
         self.dropout = nn.Dropout(p=dropout)
         self.attn = AttentionBlock(inp_dim, out_channels=contraction_dim, num_heads=heads)
-        self.register_buffer('mask', build_local_attention_mask(n=4000, l=64), persistent=False)
-        self.pos_bias = RelativeQKBias(l=64)
+        self.register_buffer('mask', build_local_attention_mask(n=6000, l=64), persistent=False)
+        self.pos_bias = RelativeQKBias(l=64, max_positions=6000)
         ff_contract = contraction_dim//2
         self.ff1 = nn.Sequential(nn.Conv1d(inp_dim+contraction_dim, ff_contract, kernel_size=1),
                                  nn.GroupNorm(8, ff_contract),
diff --git a/codes/models/audio/tts/unet_diffusion_tts_flat.py b/codes/models/audio/tts/unet_diffusion_tts_flat.py
index ca02171f..2280d5ba 100644
--- a/codes/models/audio/tts/unet_diffusion_tts_flat.py
+++ b/codes/models/audio/tts/unet_diffusion_tts_flat.py
@@ -7,7 +7,8 @@ import torch.nn.functional as F
 from torch import autocast
 
 from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear
-from models.diffusion.unet_diffusion import AttentionBlock, TimestepEmbedSequential, TimestepBlock
+from models.diffusion.unet_diffusion import TimestepEmbedSequential, TimestepBlock, QKVAttentionLegacy
+from models.lucidrains.x_transformers import RelativePositionBias
 from trainer.networks import register_model
 from utils.util import checkpoint
 
@@ -19,6 +20,52 @@ def is_sequence(t):
     return t.dtype == torch.long
 
 
+class AttentionBlock(nn.Module):
+    """
+    An attention block that allows spatial positions to attend to each other.
+
+    Originally ported from here, but adapted to the N-d case.
+    https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
+    """
+
+    def __init__(
+        self,
+        channels,
+        num_heads=1,
+        num_head_channels=-1,
+        do_checkpoint=True,
+        relative_pos_embeddings=False,
+    ):
+        super().__init__()
+        self.channels = channels
+        self.do_checkpoint = do_checkpoint
+        if num_head_channels == -1:
+            self.num_heads = num_heads
+        else:
+            assert (
+                channels % num_head_channels == 0
+            ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
+            self.num_heads = channels // num_head_channels
+        self.norm = normalization(channels)
+        self.qkv = nn.Conv1d(channels, channels * 3, 1)
+        # split heads before split qkv
+        self.attention = QKVAttentionLegacy(self.num_heads)
+
+        self.proj_out = zero_module(nn.Conv1d(channels, channels, 1))
+        if relative_pos_embeddings:
+            self.relative_pos_embeddings = RelativePositionBias(scale=(channels // self.num_heads) ** .5, causal=False, heads=num_heads, num_buckets=32, max_distance=64)
+        else:
+            self.relative_pos_embeddings = None
+
+    def forward(self, x, mask=None):
+        b, c, *spatial = x.shape
+        x = x.reshape(b, c, -1)
+        qkv = self.qkv(self.norm(x))
+        h = self.attention(qkv, mask, self.relative_pos_embeddings)
+        h = self.proj_out(h)
+        return (x + h).reshape(b, c, *spatial)
+
+
 class ResBlock(TimestepBlock):
     def __init__(
         self,
@@ -336,7 +383,7 @@ if __name__ == '__main__':
         start = time()
         model = model.cuda().eval()
         ti = model.timestep_independent(proj, clip, clip.shape[2], False)
-        for k in range(100):
+        for k in range(1000):
             model(clip, ts, precomputed_aligned_embeddings=ti)
         print(f"Elapsed: {time()-start}")
 
diff --git a/codes/trainer/eval/music_diffusion_fid.py b/codes/trainer/eval/music_diffusion_fid.py
index 9c943a6e..d014c23c 100644
--- a/codes/trainer/eval/music_diffusion_fid.py
+++ b/codes/trainer/eval/music_diffusion_fid.py
@@ -137,8 +137,8 @@ class MusicDiffusionFid(evaluator.Evaluator):
         #    s = q9.clamp(1, 9999999999)
         #    x = x.clamp(-s, s) / s
         #    return x
-        perp = self.diffuser.p_sample_loop_for_log_perplexity(self.model, mel_norm,
-                                                               model_kwargs = {'truth_mel': mel_norm})
+        #perp = self.diffuser.p_sample_loop_for_log_perplexity(self.model, mel_norm,
+        #                                                      model_kwargs = {'truth_mel': mel_norm})
         sampler = self.diffuser.ddim_sample_loop if self.ddim else self.diffuser.p_sample_loop
         gen_mel = sampler(self.model, mel_norm.shape, model_kwargs={'truth_mel': mel_norm})
 
@@ -155,7 +155,7 @@ class MusicDiffusionFid(evaluator.Evaluator):
                                                         model_kwargs={'codes': mel})
         real_wav = pixel_shuffle_1d(real_wav, 16)
 
-        return gen_wav, real_wav.squeeze(0), gen_mel, mel_norm, sample_rate, perp
+        return gen_wav, real_wav.squeeze(0), gen_mel, mel_norm, sample_rate, torch.tensor([0])
 
     def perform_reconstruction_from_cheater_gen(self, audio, sample_rate=22050):
         audio = audio.unsqueeze(0)
@@ -303,7 +303,7 @@ class MusicDiffusionFid(evaluator.Evaluator):
         gen_projections = torch.stack(gen_projections, dim=0)
         real_projections = torch.stack(real_projections, dim=0)
         frechet_distance = torch.tensor(self.compute_frechet_distance(gen_projections, real_projections), device=self.env['device'])
-        perplexity = torch.stack(perplexities, dim=0).mean()
+        perplexity = torch.stack(perplexities, dim=0).float().mean()
 
         if distributed.is_initialized() and distributed.get_world_size() > 1:
             distributed.all_reduce(frechet_distance)
@@ -338,16 +338,17 @@ if __name__ == '__main__':
     # For TFD+cheater trainer
     diffusion = load_model_from_config('X:\\dlas\\experiments\\train_music_diffusion_tfd_and_cheater.yml', 'generator',
                                        also_load_savepoint=False, strict_load=False,
-                                       load_path='X:\\dlas\\experiments\\train_music_diffusion_tfd14_and_cheater_g2\\models\\120000_generator_ema.pth'
+                                       load_path='X:\\dlas\\experiments\\tfd14_and_cheater.pth'
                                        ).cuda()
-    opt_eval = {'path': 'Y:\\split\\yt-music-eval', # eval music, mostly electronica. :)
+    opt_eval = {#'path': 'Y:\\split\\yt-music-eval', # eval music, mostly electronica. :)
                 #'path': 'E:\\music_eval', # this is music from the training dataset, including a lot more variety.
+                'path': 'Y:\\separated\\tfd14_test',
                 'diffusion_steps': 256,
-                'conditioning_free': False, 'conditioning_free_k': 1, 'use_ddim': False, 'clip_audio': True,
+                'conditioning_free': True, 'conditioning_free_k': 1, 'use_ddim': False, 'clip_audio': True,
                 'diffusion_schedule': 'cosine', 'diffusion_type': 'from_codes_quant',
                 }
 
-    env = {'rank': 0, 'base_path': 'D:\\tmp\\test_eval_music', 'step': 13, 'device': 'cuda', 'opt': {}}
+    env = {'rank': 0, 'base_path': 'D:\\tmp\\test_eval_music', 'step': 18, 'device': 'cuda', 'opt': {}}
     eval = MusicDiffusionFid(diffusion, opt_eval, env)
     fds = []
     for i in range(2):
diff --git a/codes/trainer/injectors/audio_injectors.py b/codes/trainer/injectors/audio_injectors.py
index 512856ce..be187873 100644
--- a/codes/trainer/injectors/audio_injectors.py
+++ b/codes/trainer/injectors/audio_injectors.py
@@ -11,7 +11,7 @@ from utils.util import opt_get, load_model_from_config, pad_or_truncate
 
 MEL_MIN = -11.512925148010254
 TACOTRON_MEL_MAX = 2.3143386840820312
-TORCH_MEL_MAX = 4.82
+TORCH_MEL_MAX = 4.82 # FYI: this STILL isn't assertive enough...
 
 def normalize_torch_mel(mel):
     return 2 * ((mel - MEL_MIN) / (TORCH_MEL_MAX - MEL_MIN)) - 1
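A note on the final `audio_injectors.py` hunk: the added comment flags that `TORCH_MEL_MAX = 4.82` is an empirical ceiling rather than a hard bound, so `normalize_torch_mel` can still emit values slightly outside [-1, 1]. Below is a minimal, self-contained sketch of that behavior; the constants and function body are copied from the patch, while the sample inputs and the final clamp are illustrative only and not part of the change.

```python
import torch

# Constants from codes/trainer/injectors/audio_injectors.py (as patched above).
MEL_MIN = -11.512925148010254
TORCH_MEL_MAX = 4.82  # empirical ceiling; real mel values can occasionally exceed it

def normalize_torch_mel(mel):
    # Linearly map the range [MEL_MIN, TORCH_MEL_MAX] onto [-1, 1].
    return 2 * ((mel - MEL_MIN) / (TORCH_MEL_MAX - MEL_MIN)) - 1

# Illustrative inputs: the last value sits above TORCH_MEL_MAX, so it normalizes to > 1.
mel = torch.tensor([MEL_MIN, -4.0, TORCH_MEL_MAX, 5.1])
norm = normalize_torch_mel(mel)
print(norm)               # approximately tensor([-1.0000, -0.0800,  1.0000,  1.0343])
print(norm.clamp(-1, 1))  # an explicit guard, shown for illustration; not part of the patch
```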