do masking up proper

James Betker 2022-07-19 16:32:17 -06:00
parent b203a7dc97
commit fc0b291b21
3 changed files with 8 additions and 1 deletion


@@ -30,8 +30,12 @@ class SubBlock(nn.Module):
         self.ff = nn.Conv1d(inp_dim+contraction_dim, contraction_dim, kernel_size=3, padding=1)
         self.ffnorm = nn.GroupNorm(8, contraction_dim)
         self.mask = build_local_attention_mask(n=4000, l=64, fixed_region=8)
+        self.mask_initialized = False

     def forward(self, x, blk_emb):
+        if self.mask is not None and not self.mask_initialized:
+            self.mask = self.mask.to(x.device)
+            self.mask_initialized = True
         blk_enc = self.blk_emb_proj(blk_emb)
         ah = self.dropout(self.attn(torch.cat([blk_enc, x], dim=-1), mask=self.mask))
         ah = ah[:,:,blk_emb.shape[-1]:] # Strip off the blk_emb and re-align with x.
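For context on the hunk above: the mask tensor is created on CPU in the constructor, so it must reach the input's device before the first attention call. Below is a minimal sketch of that one-time transfer pattern; the class name and the stand-in triangular mask are hypothetical, not the repository's SubBlock or build_local_attention_mask.

import torch
import torch.nn as nn

class LazyMaskBlock(nn.Module):
    def __init__(self, n=4000):
        super().__init__()
        # Stand-in for build_local_attention_mask(...); built once, on CPU.
        self.mask = torch.tril(torch.ones(n, n, dtype=torch.bool))
        self.mask_initialized = False

    def forward(self, x):
        # Move the mask to x's device exactly once, on the first forward pass.
        if self.mask is not None and not self.mask_initialized:
            self.mask = self.mask.to(x.device)
            self.mask_initialized = True
        return x  # the attention that would consume self.mask is omitted here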


@@ -25,12 +25,14 @@ class SubBlock(nn.Module):
         if self.enable_attention_masking:
             # All regions can attend to the first token, which will be the timestep embedding. Hence, fixed_region.
             self.mask = build_local_attention_mask(n=2000, l=48, fixed_region=1)
+            self.mask_initialized = False
         else:
             self.mask = None

     def forward(self, x, blk_emb):
-        if self.mask is not None:
+        if self.mask is not None and not self.mask_initialized:
             self.mask = self.mask.to(x.device)
+            self.mask_initialized = True
         blk_enc = self.blk_emb_proj(blk_emb)
         ah = self.dropout(self.attn(torch.cat([blk_enc, x], dim=-1), mask=self.mask))
         ah = ah[:,:,blk_emb.shape[-1]:] # Strip off the blk_emb and re-align with x.
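The comment in this hunk explains the fixed_region argument: the first token (the timestep embedding) must be attendable from everywhere. As a reading aid, here is a hypothetical reconstruction of what build_local_attention_mask plausibly computes, inferred only from the call sites and that comment; the repository's actual implementation may differ.

import torch

def build_local_attention_mask(n, l, fixed_region=0):
    # Hypothetical reconstruction, not the repository's implementation:
    # a boolean (n, n) mask where token i can attend to tokens within a
    # local window of half-width l, and every token can additionally
    # attend to the first `fixed_region` tokens (the prepended embedding).
    idx = torch.arange(n)
    dist = (idx.unsqueeze(0) - idx.unsqueeze(1)).abs()
    mask = dist < l                  # local band
    mask[:, :fixed_region] = True    # globally visible prefix
    return mask

A design note: an alternative to the mask_initialized bookkeeping would be registering the mask with self.register_buffer(..., persistent=False), which lets model.to(device) move it automatically and keeps it out of the state dict.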


@@ -165,6 +165,7 @@ class MusicDiffusionFid(evaluator.Evaluator):
         real_wav = pixel_shuffle_1d(real_wav, 16)
         return gen_wav, real_wav.squeeze(0), gen_mel, mel_norm, sample_rate

     def perform_reconstruction_from_cheater_gen(self, audio, sample_rate=22050):
+        audio = audio.unsqueeze(0)
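The single added line here presumably promotes a bare waveform to a batched tensor before reconstruction; a toy illustration of the shape change (sizes hypothetical):

import torch

audio = torch.randn(22050)    # one second of mono audio, shape (22050,)
audio = audio.unsqueeze(0)    # shape (1, 22050): adds the leading batch dimension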