scope attention in tfd13 as well

This commit is contained in:
James Betker 2022-07-19 14:59:43 -06:00
parent b157b28c7b
commit c00398e955
3 changed files with 35 additions and 33 deletions

View File

@@ -439,6 +439,32 @@ class ResBlock(nn.Module):
        return self.skip_connection(x) + h

+def build_local_attention_mask(n, l, fixed_region):
+    """
+    Builds an attention mask that focuses attention on local region
+    Includes provisions for a "fixed_region" at the start of the sequence where full attention weights will be applied.
+    Args:
+        n: Size of returned matrix (maximum sequence size)
+        l: Size of local context (uni-directional, e.g. the total context is l*2)
+        fixed_region: The number of sequence elements at the start of the sequence that get full attention.
+    Returns:
+        A mask that can be applied to AttentionBlock to achieve local attention.
+    """
+    assert l*2 < n, f'Local context must be less than global context. {l}, {n}'
+    o = torch.arange(0,n)
+    c = o.unsqueeze(-1).repeat(1,n)
+    r = o.unsqueeze(0).repeat(n,1)
+    localized = ((-(r-c).abs())+l).clamp(0,l-1) / (l-1)
+    localized[:fixed_region] = 1
+    localized[:, :fixed_region] = 1
+    mask = localized > 0
+    return mask
+
+def test_local_attention_mask():
+    print(build_local_attention_mask(9,4,1))
class AttentionBlock(nn.Module):
    """
    An attention block that allows spatial positions to attend to each other.
@@ -492,10 +518,11 @@ class AttentionBlock(nn.Module):
    def _forward(self, x, mask=None):
        b, c, *spatial = x.shape
-        if len(mask.shape) == 2:
-            mask = mask.unsqueeze(0).repeat(x.shape[0],1,1)
-        if mask.shape[1] != x.shape[-1]:
-            mask = mask[:, :x.shape[-1], :x.shape[-1]]
+        if mask is not None:
+            if len(mask.shape) == 2:
+                mask = mask.unsqueeze(0).repeat(x.shape[0],1,1)
+            if mask.shape[1] != x.shape[-1]:
+                mask = mask[:, :x.shape[-1], :x.shape[-1]]
        x = x.reshape(b, c, -1)
        x = self.norm(x)
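A minimal sketch of how the new helper behaves, assuming the boolean convention True = "may attend" (which is what `localized > 0` produces); the masked_fill application at the end is illustrative only, since AttentionBlock's internal use of the mask is not shown in this hunk:

import torch
from models.arch_util import build_local_attention_mask

mask = build_local_attention_mask(n=4000, l=64, fixed_region=8)
print(mask.shape)               # torch.Size([4000, 4000]), dtype torch.bool
print(mask[2000, 1990].item())  # True: within the local window
print(mask[2000, 5].item())     # True: column 5 is inside fixed_region=8
print(mask[2000, 100].item())   # False: far outside the local window

# Illustrative only: one common way a boolean mask like this is applied to
# attention logits. The crop mirrors the new guard in _forward above.
seq_len = 1024
cropped = mask[:seq_len, :seq_len]
logits = torch.randn(1, seq_len, seq_len)
logits = logits.masked_fill(~cropped, float('-inf'))
weights = logits.softmax(dim=-1)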

View File

@@ -5,7 +5,7 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
-from models.arch_util import ResBlock, TimestepEmbedSequential, AttentionBlock
+from models.arch_util import ResBlock, TimestepEmbedSequential, AttentionBlock, build_local_attention_mask
from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear
from models.diffusion.unet_diffusion import TimestepBlock
from trainer.networks import register_model
@@ -29,10 +29,11 @@ class SubBlock(nn.Module):
        self.attnorm = nn.GroupNorm(8, contraction_dim)
        self.ff = nn.Conv1d(inp_dim+contraction_dim, contraction_dim, kernel_size=3, padding=1)
        self.ffnorm = nn.GroupNorm(8, contraction_dim)
+        self.mask = build_local_attention_mask(n=4000, l=64, fixed_region=8)

    def forward(self, x, blk_emb):
        blk_enc = self.blk_emb_proj(blk_emb)
-        ah = self.dropout(self.attn(torch.cat([blk_enc, x], dim=-1)))
+        ah = self.dropout(self.attn(torch.cat([blk_enc, x], dim=-1), mask=self.mask))
        ah = ah[:,:,blk_emb.shape[-1]:] # Strip off the blk_emb and re-align with x.
        ah = F.gelu(self.attnorm(ah))
        h = torch.cat([ah, x], dim=1)
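A sizing note, as a sketch rather than a statement from the commit: the mask is built once for n=4000 positions, and the guard added to AttentionBlock._forward (first file above) broadcasts it across the batch and crops it to the length actually attended over. The figures below (batch of 4, 8 block-embedding tokens, 1500 sequence tokens) are hypothetical and chosen only to show the arithmetic; whether fixed_region=8 is sized to cover the block embedding is an inference from the diff, not something the author states.

import torch
from models.arch_util import build_local_attention_mask

batch, blk_tokens, x_tokens = 4, 8, 1500   # hypothetical sizes for illustration
mask = build_local_attention_mask(n=4000, l=64, fixed_region=8)

# Mirrors the mask handling in AttentionBlock._forward: broadcast a 2-D mask
# over the batch, then crop it to the concatenated [blk_enc, x] length.
attn_len = blk_tokens + x_tokens
m = mask.unsqueeze(0).repeat(batch, 1, 1)[:, :attn_len, :attn_len]
print(m.shape)  # torch.Size([4, 1508, 1508])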

View File

@@ -4,7 +4,7 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
-from models.arch_util import AttentionBlock, TimestepEmbedSequential
+from models.arch_util import AttentionBlock, TimestepEmbedSequential, build_local_attention_mask
from models.audio.music.encoders import ResEncoder16x
from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear
from models.diffusion.unet_diffusion import TimestepBlock
@@ -12,32 +12,6 @@ from trainer.networks import register_model
from utils.util import checkpoint, print_network
-def build_local_attention_mask(n, l, fixed_region):
-    """
-    Builds an attention mask that focuses attention on local region
-    Includes provisions for a "fixed_region" at the start of the sequence where full attention weights will be applied.
-    Args:
-        n: Size of returned matrix (maximum sequence size)
-        l: Size of local context (uni-directional, e.g. the total context is l*2)
-        fixed_region: The number of sequence elements at the start of the sequence that get full attention.
-    Returns:
-        A mask that can be applied to AttentionBlock to achieve local attention.
-    """
-    assert l*2 < n, f'Local context must be less than global context. {l}, {n}'
-    o = torch.arange(0,n)
-    c = o.unsqueeze(-1).repeat(1,n)
-    r = o.unsqueeze(0).repeat(n,1)
-    localized = ((-(r-c).abs())+l).clamp(0,l-1) / (l-1)
-    localized[:fixed_region] = 1
-    localized[:, :fixed_region] = 1
-    mask = localized > 0
-    return mask
-
-def test_local_attention_mask():
-    print(build_local_attention_mask(9,4,1))
class SubBlock(nn.Module):
    def __init__(self, inp_dim, contraction_dim, blk_dim, heads, dropout, enable_attention_masking=False):
        super().__init__()