forked from mrq/DL-Art-School
scope attention in tfd13 as well
commit c00398e955
parent b157b28c7b
@@ -439,6 +439,32 @@ class ResBlock(nn.Module):
         return self.skip_connection(x) + h
 
 
+def build_local_attention_mask(n, l, fixed_region):
+    """
+    Builds an attention mask that focuses attention on local region
+    Includes provisions for a "fixed_region" at the start of the sequence where full attention weights will be applied.
+    Args:
+        n: Size of returned matrix (maximum sequence size)
+        l: Size of local context (uni-directional, e.g. the total context is l*2)
+        fixed_region: The number of sequence elements at the start of the sequence that get full attention.
+    Returns:
+        A mask that can be applied to AttentionBlock to achieve local attention.
+    """
+    assert l*2 < n, f'Local context must be less than global context. {l}, {n}'
+    o = torch.arange(0,n)
+    c = o.unsqueeze(-1).repeat(1,n)
+    r = o.unsqueeze(0).repeat(n,1)
+    localized = ((-(r-c).abs())+l).clamp(0,l-1) / (l-1)
+    localized[:fixed_region] = 1
+    localized[:, :fixed_region] = 1
+    mask = localized > 0
+    return mask
+
+
+def test_local_attention_mask():
+    print(build_local_attention_mask(9,4,1))
+
+
 class AttentionBlock(nn.Module):
     """
     An attention block that allows spatial positions to attend to each other.
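For intuition, here is a hand-worked sketch of the mask this helper builds for the small case exercised by test_local_attention_mask(). The grid is my own illustration (1/0 standing in for the True/False of the returned bool tensor), assuming the repository root is importable:

    from models.arch_util import build_local_attention_mask

    m = build_local_attention_mask(9, 4, 1)
    # Expected pattern: rows are query positions, columns are key positions.
    # Row/column 0 form the fixed_region and are fully attendable; any other
    # query i can see keys j with |i - j| <= l - 1 = 3 (a distance of exactly
    # l clamps to zero, so the band is l - 1 wide on each side).
    # 1 1 1 1 1 1 1 1 1
    # 1 1 1 1 1 0 0 0 0
    # 1 1 1 1 1 1 0 0 0
    # 1 1 1 1 1 1 1 0 0
    # 1 1 1 1 1 1 1 1 0
    # 1 0 1 1 1 1 1 1 1
    # 1 0 0 1 1 1 1 1 1
    # 1 0 0 0 1 1 1 1 1
    # 1 0 0 0 0 1 1 1 1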
@@ -492,10 +518,11 @@ class AttentionBlock(nn.Module):
 
     def _forward(self, x, mask=None):
         b, c, *spatial = x.shape
-        if len(mask.shape) == 2:
-            mask = mask.unsqueeze(0).repeat(x.shape[0],1,1)
-        if mask.shape[1] != x.shape[-1]:
-            mask = mask[:, :x.shape[-1], :x.shape[-1]]
+        if mask is not None:
+            if len(mask.shape) == 2:
+                mask = mask.unsqueeze(0).repeat(x.shape[0],1,1)
+            if mask.shape[1] != x.shape[-1]:
+                mask = mask[:, :x.shape[-1], :x.shape[-1]]
 
         x = x.reshape(b, c, -1)
         x = self.norm(x)
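The None-guard keeps mask-free callers working unchanged, while callers that do pass a mask can precompute a single oversized one and reuse it at any sequence length. A rough sketch of the shapes involved, with sizes that are my own stand-ins rather than values from the repository:

    import torch
    from models.arch_util import build_local_attention_mask

    mask = build_local_attention_mask(n=4000, l=64, fixed_region=8)  # (4000, 4000) bool
    x = torch.randn(4, 512, 1024)  # (batch, channels, sequence) as seen by _forward
    # Inside _forward the 2-D mask is broadcast over the batch and then cropped
    # to the actual sequence length:
    #   mask.unsqueeze(0).repeat(4, 1, 1)  -> (4, 4000, 4000)
    #   mask[:, :1024, :1024]              -> (4, 1024, 1024), matching x.shape[-1]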
@@ -5,7 +5,7 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
-from models.arch_util import ResBlock, TimestepEmbedSequential, AttentionBlock
+from models.arch_util import ResBlock, TimestepEmbedSequential, AttentionBlock, build_local_attention_mask
 from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear
 from models.diffusion.unet_diffusion import TimestepBlock
 from trainer.networks import register_model
@@ -29,10 +29,11 @@ class SubBlock(nn.Module):
         self.attnorm = nn.GroupNorm(8, contraction_dim)
         self.ff = nn.Conv1d(inp_dim+contraction_dim, contraction_dim, kernel_size=3, padding=1)
         self.ffnorm = nn.GroupNorm(8, contraction_dim)
+        self.mask = build_local_attention_mask(n=4000, l=64, fixed_region=8)
 
     def forward(self, x, blk_emb):
         blk_enc = self.blk_emb_proj(blk_emb)
-        ah = self.dropout(self.attn(torch.cat([blk_enc, x], dim=-1)))
+        ah = self.dropout(self.attn(torch.cat([blk_enc, x], dim=-1), mask=self.mask))
         ah = ah[:,:,blk_emb.shape[-1]:] # Strip off the blk_emb and re-align with x.
         ah = F.gelu(self.attnorm(ah))
         h = torch.cat([ah, x], dim=1)
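With the precomputed mask passed to self.attn, each position of the concatenated [blk_enc, x] sequence attends only to a local window plus the first eight positions; fixed_region=8 appears intended to keep the prepended block embedding globally visible. A small illustration of which keys a single query can reach (the query index is arbitrary and the snippet assumes the repository root is importable):

    from models.arch_util import build_local_attention_mask

    mask = build_local_attention_mask(n=4000, l=64, fixed_region=8)
    visible = mask[500].nonzero().squeeze(-1)  # keys visible to query position 500
    # The 8 fixed positions plus a +/-63 window around the query:
    assert visible.min() == 0 and visible.max() == 563 and len(visible) == 8 + 127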
@@ -4,7 +4,7 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
-from models.arch_util import AttentionBlock, TimestepEmbedSequential
+from models.arch_util import AttentionBlock, TimestepEmbedSequential, build_local_attention_mask
 from models.audio.music.encoders import ResEncoder16x
 from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear
 from models.diffusion.unet_diffusion import TimestepBlock
@@ -12,32 +12,6 @@ from trainer.networks import register_model
 from utils.util import checkpoint, print_network
 
 
-def build_local_attention_mask(n, l, fixed_region):
-    """
-    Builds an attention mask that focuses attention on local region
-    Includes provisions for a "fixed_region" at the start of the sequence where full attention weights will be applied.
-    Args:
-        n: Size of returned matrix (maximum sequence size)
-        l: Size of local context (uni-directional, e.g. the total context is l*2)
-        fixed_region: The number of sequence elements at the start of the sequence that get full attention.
-    Returns:
-        A mask that can be applied to AttentionBlock to achieve local attention.
-    """
-    assert l*2 < n, f'Local context must be less than global context. {l}, {n}'
-    o = torch.arange(0,n)
-    c = o.unsqueeze(-1).repeat(1,n)
-    r = o.unsqueeze(0).repeat(n,1)
-    localized = ((-(r-c).abs())+l).clamp(0,l-1) / (l-1)
-    localized[:fixed_region] = 1
-    localized[:, :fixed_region] = 1
-    mask = localized > 0
-    return mask
-
-
-def test_local_attention_mask():
-    print(build_local_attention_mask(9,4,1))
-
-
 class SubBlock(nn.Module):
     def __init__(self, inp_dim, contraction_dim, blk_dim, heads, dropout, enable_attention_masking=False):
         super().__init__()