scope attention in tfd13 as well
This commit is contained in:
parent
b157b28c7b
commit
c00398e955
|
@ -439,6 +439,32 @@ class ResBlock(nn.Module):
|
|||
return self.skip_connection(x) + h
|
||||
|
||||
|
||||
def build_local_attention_mask(n, l, fixed_region):
|
||||
"""
|
||||
Builds an attention mask that focuses attention on local region
|
||||
Includes provisions for a "fixed_region" at the start of the sequence where full attention weights will be applied.
|
||||
Args:
|
||||
n: Size of returned matrix (maximum sequence size)
|
||||
l: Size of local context (uni-directional, e.g. the total context is l*2)
|
||||
fixed_region: The number of sequence elements at the start of the sequence that get full attention.
|
||||
Returns:
|
||||
A mask that can be applied to AttentionBlock to achieve local attention.
|
||||
"""
|
||||
assert l*2 < n, f'Local context must be less than global context. {l}, {n}'
|
||||
o = torch.arange(0,n)
|
||||
c = o.unsqueeze(-1).repeat(1,n)
|
||||
r = o.unsqueeze(0).repeat(n,1)
|
||||
localized = ((-(r-c).abs())+l).clamp(0,l-1) / (l-1)
|
||||
localized[:fixed_region] = 1
|
||||
localized[:, :fixed_region] = 1
|
||||
mask = localized > 0
|
||||
return mask
|
||||
|
||||
|
||||
def test_local_attention_mask():
|
||||
print(build_local_attention_mask(9,4,1))
|
||||
|
||||
|
||||
class AttentionBlock(nn.Module):
|
||||
"""
|
||||
An attention block that allows spatial positions to attend to each other.
|
||||
|
@ -492,10 +518,11 @@ class AttentionBlock(nn.Module):
|
|||
|
||||
def _forward(self, x, mask=None):
|
||||
b, c, *spatial = x.shape
|
||||
if len(mask.shape) == 2:
|
||||
mask = mask.unsqueeze(0).repeat(x.shape[0],1,1)
|
||||
if mask.shape[1] != x.shape[-1]:
|
||||
mask = mask[:, :x.shape[-1], :x.shape[-1]]
|
||||
if mask is not None:
|
||||
if len(mask.shape) == 2:
|
||||
mask = mask.unsqueeze(0).repeat(x.shape[0],1,1)
|
||||
if mask.shape[1] != x.shape[-1]:
|
||||
mask = mask[:, :x.shape[-1], :x.shape[-1]]
|
||||
|
||||
x = x.reshape(b, c, -1)
|
||||
x = self.norm(x)
|
||||
|
|
|
@ -5,7 +5,7 @@ import torch
|
|||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from models.arch_util import ResBlock, TimestepEmbedSequential, AttentionBlock
|
||||
from models.arch_util import ResBlock, TimestepEmbedSequential, AttentionBlock, build_local_attention_mask
|
||||
from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear
|
||||
from models.diffusion.unet_diffusion import TimestepBlock
|
||||
from trainer.networks import register_model
|
||||
|
@ -29,10 +29,11 @@ class SubBlock(nn.Module):
|
|||
self.attnorm = nn.GroupNorm(8, contraction_dim)
|
||||
self.ff = nn.Conv1d(inp_dim+contraction_dim, contraction_dim, kernel_size=3, padding=1)
|
||||
self.ffnorm = nn.GroupNorm(8, contraction_dim)
|
||||
self.mask = build_local_attention_mask(n=4000, l=64, fixed_region=8)
|
||||
|
||||
def forward(self, x, blk_emb):
|
||||
blk_enc = self.blk_emb_proj(blk_emb)
|
||||
ah = self.dropout(self.attn(torch.cat([blk_enc, x], dim=-1)))
|
||||
ah = self.dropout(self.attn(torch.cat([blk_enc, x], dim=-1), mask=self.mask))
|
||||
ah = ah[:,:,blk_emb.shape[-1]:] # Strip off the blk_emb and re-align with x.
|
||||
ah = F.gelu(self.attnorm(ah))
|
||||
h = torch.cat([ah, x], dim=1)
|
||||
|
|
|
@ -4,7 +4,7 @@ import torch
|
|||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from models.arch_util import AttentionBlock, TimestepEmbedSequential
|
||||
from models.arch_util import AttentionBlock, TimestepEmbedSequential, build_local_attention_mask
|
||||
from models.audio.music.encoders import ResEncoder16x
|
||||
from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear
|
||||
from models.diffusion.unet_diffusion import TimestepBlock
|
||||
|
@ -12,32 +12,6 @@ from trainer.networks import register_model
|
|||
from utils.util import checkpoint, print_network
|
||||
|
||||
|
||||
def build_local_attention_mask(n, l, fixed_region):
|
||||
"""
|
||||
Builds an attention mask that focuses attention on local region
|
||||
Includes provisions for a "fixed_region" at the start of the sequence where full attention weights will be applied.
|
||||
Args:
|
||||
n: Size of returned matrix (maximum sequence size)
|
||||
l: Size of local context (uni-directional, e.g. the total context is l*2)
|
||||
fixed_region: The number of sequence elements at the start of the sequence that get full attention.
|
||||
Returns:
|
||||
A mask that can be applied to AttentionBlock to achieve local attention.
|
||||
"""
|
||||
assert l*2 < n, f'Local context must be less than global context. {l}, {n}'
|
||||
o = torch.arange(0,n)
|
||||
c = o.unsqueeze(-1).repeat(1,n)
|
||||
r = o.unsqueeze(0).repeat(n,1)
|
||||
localized = ((-(r-c).abs())+l).clamp(0,l-1) / (l-1)
|
||||
localized[:fixed_region] = 1
|
||||
localized[:, :fixed_region] = 1
|
||||
mask = localized > 0
|
||||
return mask
|
||||
|
||||
|
||||
def test_local_attention_mask():
|
||||
print(build_local_attention_mask(9,4,1))
|
||||
|
||||
|
||||
class SubBlock(nn.Module):
|
||||
def __init__(self, inp_dim, contraction_dim, blk_dim, heads, dropout, enable_attention_masking=False):
|
||||
super().__init__()
|
||||
|
|
Loading…
Reference in New Issue
Block a user