diff --git a/codes/models/arch_util.py b/codes/models/arch_util.py
index d701991c..e2c4c4a8 100644
--- a/codes/models/arch_util.py
+++ b/codes/models/arch_util.py
@@ -439,6 +439,32 @@ class ResBlock(nn.Module):
         return self.skip_connection(x) + h
 
 
+def build_local_attention_mask(n, l, fixed_region):
+    """
+    Builds an attention mask that focuses attention on local region
+    Includes provisions for a "fixed_region" at the start of the sequence where full attention weights will be applied.
+    Args:
+        n: Size of returned matrix (maximum sequence size)
+        l: Size of local context (uni-directional, e.g. the total context is l*2)
+        fixed_region: The number of sequence elements at the start of the sequence that get full attention.
+    Returns:
+        A mask that can be applied to AttentionBlock to achieve local attention.
+    """
+    assert l*2 < n, f'Local context must be less than global context. {l}, {n}'
+    o = torch.arange(0,n)
+    c = o.unsqueeze(-1).repeat(1,n)
+    r = o.unsqueeze(0).repeat(n,1)
+    localized = ((-(r-c).abs())+l).clamp(0,l-1) / (l-1)
+    localized[:fixed_region] = 1
+    localized[:, :fixed_region] = 1
+    mask = localized > 0
+    return mask
+
+
+def test_local_attention_mask():
+    print(build_local_attention_mask(9,4,1))
+
+
 class AttentionBlock(nn.Module):
     """
     An attention block that allows spatial positions to attend to each other.
@@ -492,10 +518,11 @@ class AttentionBlock(nn.Module):
 
     def _forward(self, x, mask=None):
         b, c, *spatial = x.shape
-        if len(mask.shape) == 2:
-            mask = mask.unsqueeze(0).repeat(x.shape[0],1,1)
-        if mask.shape[1] != x.shape[-1]:
-            mask = mask[:, :x.shape[-1], :x.shape[-1]]
+        if mask is not None:
+            if len(mask.shape) == 2:
+                mask = mask.unsqueeze(0).repeat(x.shape[0],1,1)
+            if mask.shape[1] != x.shape[-1]:
+                mask = mask[:, :x.shape[-1], :x.shape[-1]]
         x = x.reshape(b, c, -1)
         x = self.norm(x)
diff --git a/codes/models/audio/music/transformer_diffusion13.py b/codes/models/audio/music/transformer_diffusion13.py
index 6cc315ac..b2a252e2 100644
--- a/codes/models/audio/music/transformer_diffusion13.py
+++ b/codes/models/audio/music/transformer_diffusion13.py
@@ -5,7 +5,7 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
-from models.arch_util import ResBlock, TimestepEmbedSequential, AttentionBlock
+from models.arch_util import ResBlock, TimestepEmbedSequential, AttentionBlock, build_local_attention_mask
 from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear
 from models.diffusion.unet_diffusion import TimestepBlock
 from trainer.networks import register_model
@@ -29,10 +29,11 @@ class SubBlock(nn.Module):
         self.attnorm = nn.GroupNorm(8, contraction_dim)
         self.ff = nn.Conv1d(inp_dim+contraction_dim, contraction_dim, kernel_size=3, padding=1)
         self.ffnorm = nn.GroupNorm(8, contraction_dim)
+        self.mask = build_local_attention_mask(n=4000, l=64, fixed_region=8)
 
     def forward(self, x, blk_emb):
         blk_enc = self.blk_emb_proj(blk_emb)
-        ah = self.dropout(self.attn(torch.cat([blk_enc, x], dim=-1)))
+        ah = self.dropout(self.attn(torch.cat([blk_enc, x], dim=-1), mask=self.mask))
         ah = ah[:,:,blk_emb.shape[-1]:] # Strip off the blk_emb and re-align with x.
         ah = F.gelu(self.attnorm(ah))
         h = torch.cat([ah, x], dim=1)
diff --git a/codes/models/audio/music/transformer_diffusion14.py b/codes/models/audio/music/transformer_diffusion14.py
index 20c7ba15..26d7eced 100644
--- a/codes/models/audio/music/transformer_diffusion14.py
+++ b/codes/models/audio/music/transformer_diffusion14.py
@@ -4,7 +4,7 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
-from models.arch_util import AttentionBlock, TimestepEmbedSequential
+from models.arch_util import AttentionBlock, TimestepEmbedSequential, build_local_attention_mask
 from models.audio.music.encoders import ResEncoder16x
 from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear
 from models.diffusion.unet_diffusion import TimestepBlock
@@ -12,32 +12,6 @@ from trainer.networks import register_model
 from utils.util import checkpoint, print_network
 
 
-def build_local_attention_mask(n, l, fixed_region):
-    """
-    Builds an attention mask that focuses attention on local region
-    Includes provisions for a "fixed_region" at the start of the sequence where full attention weights will be applied.
-    Args:
-        n: Size of returned matrix (maximum sequence size)
-        l: Size of local context (uni-directional, e.g. the total context is l*2)
-        fixed_region: The number of sequence elements at the start of the sequence that get full attention.
-    Returns:
-        A mask that can be applied to AttentionBlock to achieve local attention.
-    """
-    assert l*2 < n, f'Local context must be less than global context. {l}, {n}'
-    o = torch.arange(0,n)
-    c = o.unsqueeze(-1).repeat(1,n)
-    r = o.unsqueeze(0).repeat(n,1)
-    localized = ((-(r-c).abs())+l).clamp(0,l-1) / (l-1)
-    localized[:fixed_region] = 1
-    localized[:, :fixed_region] = 1
-    mask = localized > 0
-    return mask
-
-
-def test_local_attention_mask():
-    print(build_local_attention_mask(9,4,1))
-
-
 class SubBlock(nn.Module):
     def __init__(self, inp_dim, contraction_dim, blk_dim, heads, dropout, enable_attention_masking=False):
         super().__init__()
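
For reference, a minimal standalone sketch (assuming PyTorch is available) of what the relocated helper produces and how the new `mask is not None` branch in AttentionBlock._forward broadcasts and crops it. The sizes used here are illustrative only, not the values SubBlock passes.

# Standalone sketch: same construction as build_local_attention_mask in this diff.
import torch

def build_local_attention_mask(n, l, fixed_region):
    o = torch.arange(0, n)
    c = o.unsqueeze(-1).repeat(1, n)   # c[i, j] = i (query position)
    r = o.unsqueeze(0).repeat(n, 1)    # r[i, j] = j (key position)
    # Positive inside the band |i - j| < l, zero outside; thresholded to bool below.
    localized = ((-(r - c).abs()) + l).clamp(0, l - 1) / (l - 1)
    localized[:fixed_region] = 1       # first rows: attend to everything
    localized[:, :fixed_region] = 1    # first columns: attended to by everything
    return localized > 0

mask = build_local_attention_mask(n=9, l=4, fixed_region=1)
print(mask.int())  # 9x9 band of ones around the diagonal, plus a full first row/column

# What the new `mask is not None` branch in AttentionBlock._forward does with it:
# broadcast a 2-D mask over the batch, then crop to the actual sequence length.
b, seq_len = 2, 6
if len(mask.shape) == 2:
    mask = mask.unsqueeze(0).repeat(b, 1, 1)
if mask.shape[1] != seq_len:
    mask = mask[:, :seq_len, :seq_len]
print(mask.shape)  # torch.Size([2, 6, 6])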
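
In SubBlock the mask is built once in __init__ at the maximum expected length (n=4000) and passed to the attention call on every forward pass; the crop in _forward then trims it to the length of the concatenated [blk_enc, x] sequence, so the same buffer serves any input up to 4000 positions. fixed_region=8 gives the first eight positions full attention in both directions, presumably sized to cover the prepended block-embedding tokens, while every other position attends only within the local l=64 window on either side plus that fixed region.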