forked from mrq/DL-Art-School
do masking up proper
This commit is contained in:
parent b203a7dc97
commit fc0b291b21
@@ -30,8 +30,12 @@ class SubBlock(nn.Module):
         self.ff = nn.Conv1d(inp_dim+contraction_dim, contraction_dim, kernel_size=3, padding=1)
         self.ffnorm = nn.GroupNorm(8, contraction_dim)
         self.mask = build_local_attention_mask(n=4000, l=64, fixed_region=8)
+        self.mask_initialized = False

     def forward(self, x, blk_emb):
+        if self.mask is not None and not self.mask_initialized:
+            self.mask = self.mask.to(x.device)
+            self.mask_initialized = True
         blk_enc = self.blk_emb_proj(blk_emb)
         ah = self.dropout(self.attn(torch.cat([blk_enc, x], dim=-1), mask=self.mask))
         ah = ah[:,:,blk_emb.shape[-1]:] # Strip off the blk_emb and re-align with x.
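Both SubBlock hunks apply the same lazy device-placement pattern: the attention mask is built once (on CPU) at construction time and moved to the input's device only on the first forward() call, guarded by a mask_initialized flag so the transfer is not repeated every step. A minimal, self-contained sketch of that pattern (the names below are illustrative stand-ins, not code from this repository):

# Minimal sketch of the lazy mask-placement pattern shown above; LocalAttnBlock and
# toy_mask are illustrative stand-ins, not code from this repository.
import torch
import torch.nn as nn

def toy_mask(n: int) -> torch.Tensor:
    # Stand-in for build_local_attention_mask(); a simple causal mask is used here
    # only so the sketch runs end to end.
    return torch.tril(torch.ones(n, n))

class LocalAttnBlock(nn.Module):
    def __init__(self, n_tokens: int = 16):
        super().__init__()
        self.mask = toy_mask(n_tokens)    # built once, on CPU, at construction time
        self.mask_initialized = False

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Move the mask to the input's device exactly once instead of on every call.
        if self.mask is not None and not self.mask_initialized:
            self.mask = self.mask.to(x.device)
            self.mask_initialized = True
        scores = x @ x.transpose(-1, -2)                        # (B, T, T) toy attention scores
        t = x.shape[1]
        scores = scores.masked_fill(self.mask[:t, :t] == 0, float('-inf'))
        return torch.softmax(scores, dim=-1) @ x

# Usage: LocalAttnBlock()(torch.randn(2, 16, 8)) -> shape (2, 16, 8)

Since self.mask is a plain tensor attribute rather than a buffer registered with register_buffer(), Module.to() and Module.cuda() will not move it automatically, which would explain the explicit first-use transfer and the mask_initialized guard.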
@@ -25,12 +25,14 @@ class SubBlock(nn.Module):
         if self.enable_attention_masking:
             # All regions can attend to the first token, which will be the timestep embedding. Hence, fixed_region.
             self.mask = build_local_attention_mask(n=2000, l=48, fixed_region=1)
+            self.mask_initialized = False
         else:
             self.mask = None

     def forward(self, x, blk_emb):
-        if self.mask is not None:
+        if self.mask is not None and not self.mask_initialized:
             self.mask = self.mask.to(x.device)
+            self.mask_initialized = True
         blk_enc = self.blk_emb_proj(blk_emb)
         ah = self.dropout(self.attn(torch.cat([blk_enc, x], dim=-1), mask=self.mask))
         ah = ah[:,:,blk_emb.shape[-1]:] # Strip off the blk_emb and re-align with x.
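build_local_attention_mask() itself is not touched by this commit and its implementation is not shown here. The sketch below is only an assumed reconstruction, based on its call sites (n, l, fixed_region) and the comment above that every position can attend to the first token(s):

# Assumed behaviour of build_local_attention_mask(); not the repository's implementation.
import torch

def sketch_local_attention_mask(n: int, l: int, fixed_region: int = 0) -> torch.Tensor:
    # Returns an (n, n) mask with 1 where attention is allowed and 0 elsewhere
    # (the 0/1 convention and whether l is a radius or a total width are assumptions).
    i = torch.arange(n).unsqueeze(1)   # query positions
    j = torch.arange(n).unsqueeze(0)   # key positions
    band = (i - j).abs() <= l          # local neighbourhood around each position
    fixed = j < fixed_region           # everyone may attend to the first fixed_region tokens
    return (band | fixed).float()

# e.g. sketch_local_attention_mask(2000, 48, fixed_region=1) mirrors the call in the hunk above.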
@@ -165,6 +165,7 @@ class MusicDiffusionFid(evaluator.Evaluator):
         real_wav = pixel_shuffle_1d(real_wav, 16)

         return gen_wav, real_wav.squeeze(0), gen_mel, mel_norm, sample_rate

     def perform_reconstruction_from_cheater_gen(self, audio, sample_rate=22050):
         audio = audio.unsqueeze(0)
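pixel_shuffle_1d(), used in the evaluator hunk above to fold 16 channels back into the time axis, is also outside this diff. As context, here is a hedged sketch of the usual 1-D pixel shuffle it presumably performs, assuming the common (B, C*r, T) -> (B, C, T*r) convention:

# Hedged sketch of a 1-D pixel shuffle; not necessarily the repository's pixel_shuffle_1d().
import torch

def pixel_shuffle_1d_sketch(x: torch.Tensor, upscale_factor: int) -> torch.Tensor:
    b, c, t = x.shape
    r = upscale_factor
    x = x.view(b, c // r, r, t)         # split the channel dim into (C, r)
    x = x.permute(0, 1, 3, 2)           # (B, C, T, r)
    return x.reshape(b, c // r, t * r)  # interleave the r sub-channels along time

# e.g. pixel_shuffle_1d_sketch(torch.randn(1, 16, 4000), 16) -> shape (1, 1, 64000)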