forked from mrq/DL-Art-School

commit fc0b291b21 ("do masking up proper")
parent b203a7dc97
@@ -30,8 +30,12 @@ class SubBlock(nn.Module):
         self.ff = nn.Conv1d(inp_dim+contraction_dim, contraction_dim, kernel_size=3, padding=1)
         self.ffnorm = nn.GroupNorm(8, contraction_dim)
         self.mask = build_local_attention_mask(n=4000, l=64, fixed_region=8)
+        self.mask_initialized = False
 
     def forward(self, x, blk_emb):
+        if self.mask is not None and not self.mask_initialized:
+            self.mask = self.mask.to(x.device)
+            self.mask_initialized = True
         blk_enc = self.blk_emb_proj(blk_emb)
         ah = self.dropout(self.attn(torch.cat([blk_enc, x], dim=-1), mask=self.mask))
         ah = ah[:,:,blk_emb.shape[-1]:] # Strip off the blk_emb and re-align with x.
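build_local_attention_mask is a repo helper whose definition is not part of this diff. Below is a minimal sketch consistent with how it is called here and with the comment in the second hunk (a local window of width l around each position, plus a fixed_region of leading tokens that every position may attend to); the banded-window reading is an assumption, not the repo's actual code:

    import torch

    def build_local_attention_mask(n, l, fixed_region=0):
        # Illustrative reimplementation (assumed semantics): True = attention allowed.
        i = torch.arange(n).unsqueeze(1)  # query positions, shape (n, 1)
        j = torch.arange(n).unsqueeze(0)  # key positions, shape (1, n)
        local = (i - j).abs() <= l        # banded window of half-width l
        fixed = j < fixed_region          # leading tokens visible to everyone
        return local | fixed              # (n, n) boolean mask

Under this reading, fixed_region=8 in the hunk above makes the first eight positions globally visible while l=64 keeps attention local elsewhere; whether l is a half-width or a full window size depends on the real implementation.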
@@ -25,12 +25,14 @@ class SubBlock(nn.Module):
         if self.enable_attention_masking:
             # All regions can attend to the first token, which will be the timestep embedding. Hence, fixed_region.
             self.mask = build_local_attention_mask(n=2000, l=48, fixed_region=1)
+            self.mask_initialized = False
         else:
             self.mask = None
 
     def forward(self, x, blk_emb):
-        if self.mask is not None:
+        if self.mask is not None and not self.mask_initialized:
             self.mask = self.mask.to(x.device)
+            self.mask_initialized = True
         blk_enc = self.blk_emb_proj(blk_emb)
         ah = self.dropout(self.attn(torch.cat([blk_enc, x], dim=-1), mask=self.mask))
         ah = ah[:,:,blk_emb.shape[-1]:] # Strip off the blk_emb and re-align with x.
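The -/+ pair in this hunk is the point of the commit: the mask.to(x.device) call previously ran on every forward pass, while now it runs exactly once, guarded by mask_initialized. A minimal sketch of the pattern, with hypothetical names (MaskedScores, q, k are illustrations, not repo code):

    import torch
    import torch.nn as nn

    class MaskedScores(nn.Module):
        def __init__(self, n=16):
            super().__init__()
            # Built on CPU; a plain attribute, so model.to(device) won't move it.
            self.mask = torch.tril(torch.ones(n, n, dtype=torch.bool))
            self.mask_initialized = False

        def forward(self, q, k):
            if self.mask is not None and not self.mask_initialized:
                # One-time transfer to wherever the inputs actually live.
                self.mask = self.mask.to(q.device)
                self.mask_initialized = True
            scores = q @ k.transpose(-2, -1)
            t = scores.shape[-1]
            scores = scores.masked_fill(~self.mask[:t, :t], float('-inf'))
            return scores.softmax(dim=-1)

An equivalent fix would be self.register_buffer('mask', mask, persistent=False), which lets model.to(device) carry the mask along with the parameters and makes the mask_initialized flag unnecessary.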
@@ -165,6 +165,7 @@ class MusicDiffusionFid(evaluator.Evaluator):
         real_wav = pixel_shuffle_1d(real_wav, 16)
 
         return gen_wav, real_wav.squeeze(0), gen_mel, mel_norm, sample_rate
 
     def perform_reconstruction_from_cheater_gen(self, audio, sample_rate=22050):
         audio = audio.unsqueeze(0)
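pixel_shuffle_1d is another repo helper; by analogy with torch.pixel_shuffle it presumably folds groups of channels into the time axis, so a factor of 16 turns (b, 16*c, t) into (b, c, 16*t). A sketch under that assumption:

    import torch

    def pixel_shuffle_1d(x, upscale_factor):
        # Assumed 1-D analogue of torch.pixel_shuffle: (b, c*r, t) -> (b, c, t*r).
        b, c, t = x.shape
        r = upscale_factor
        x = x.view(b, c // r, r, t)         # split the channel dim into (c, r)
        x = x.permute(0, 1, 3, 2)           # (b, c, t, r): r becomes the fastest axis
        return x.reshape(b, c // r, t * r)  # interleave the r samples into time

With real_wav shaped like (1, 16, t) this would produce (1, 1, 16*t), after which the squeeze(0) in the return drops the batch dimension; those shapes are an inference from the call site, not something the diff states.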