From fc0b291b21e17f18dac1ae43afeb2edf93fb22c7 Mon Sep 17 00:00:00 2001 From: James Betker Date: Tue, 19 Jul 2022 16:32:17 -0600 Subject: [PATCH] do masking up proper --- codes/models/audio/music/transformer_diffusion13.py | 4 ++++ codes/models/audio/music/transformer_diffusion14.py | 4 +++- codes/trainer/eval/music_diffusion_fid.py | 1 + 3 files changed, 8 insertions(+), 1 deletion(-) diff --git a/codes/models/audio/music/transformer_diffusion13.py b/codes/models/audio/music/transformer_diffusion13.py index b2a252e2..1a7277bd 100644 --- a/codes/models/audio/music/transformer_diffusion13.py +++ b/codes/models/audio/music/transformer_diffusion13.py @@ -30,8 +30,12 @@ class SubBlock(nn.Module): self.ff = nn.Conv1d(inp_dim+contraction_dim, contraction_dim, kernel_size=3, padding=1) self.ffnorm = nn.GroupNorm(8, contraction_dim) self.mask = build_local_attention_mask(n=4000, l=64, fixed_region=8) + self.mask_initialized = False def forward(self, x, blk_emb): + if self.mask is not None and not self.mask_initialized: + self.mask = self.mask.to(x.device) + self.mask_initialized = True blk_enc = self.blk_emb_proj(blk_emb) ah = self.dropout(self.attn(torch.cat([blk_enc, x], dim=-1), mask=self.mask)) ah = ah[:,:,blk_emb.shape[-1]:] # Strip off the blk_emb and re-align with x. diff --git a/codes/models/audio/music/transformer_diffusion14.py b/codes/models/audio/music/transformer_diffusion14.py index 26d7eced..9143409e 100644 --- a/codes/models/audio/music/transformer_diffusion14.py +++ b/codes/models/audio/music/transformer_diffusion14.py @@ -25,12 +25,14 @@ class SubBlock(nn.Module): if self.enable_attention_masking: # All regions can attend to the first token, which will be the timestep embedding. Hence, fixed_region. self.mask = build_local_attention_mask(n=2000, l=48, fixed_region=1) + self.mask_initialized = False else: self.mask = None def forward(self, x, blk_emb): - if self.mask is not None: + if self.mask is not None and not self.mask_initialized: self.mask = self.mask.to(x.device) + self.mask_initialized = True blk_enc = self.blk_emb_proj(blk_emb) ah = self.dropout(self.attn(torch.cat([blk_enc, x], dim=-1), mask=self.mask)) ah = ah[:,:,blk_emb.shape[-1]:] # Strip off the blk_emb and re-align with x. diff --git a/codes/trainer/eval/music_diffusion_fid.py b/codes/trainer/eval/music_diffusion_fid.py index af79598e..b728b68a 100644 --- a/codes/trainer/eval/music_diffusion_fid.py +++ b/codes/trainer/eval/music_diffusion_fid.py @@ -165,6 +165,7 @@ class MusicDiffusionFid(evaluator.Evaluator): real_wav = pixel_shuffle_1d(real_wav, 16) return gen_wav, real_wav.squeeze(0), gen_mel, mel_norm, sample_rate + def perform_reconstruction_from_cheater_gen(self, audio, sample_rate=22050): audio = audio.unsqueeze(0)