@@ -4,7 +4,7 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
-from models.arch_util import AttentionBlock, TimestepEmbedSequential
+from models.arch_util import AttentionBlock, TimestepEmbedSequential, build_local_attention_mask
 from models.audio.music.encoders import ResEncoder16x
 from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear
 from models.diffusion.unet_diffusion import TimestepBlock
@@ -12,32 +12,6 @@ from trainer.networks import register_model
 from utils.util import checkpoint, print_network
 
 
-def build_local_attention_mask(n, l, fixed_region):
-    """
-    Builds an attention mask that focuses attention on local region
-    Includes provisions for a "fixed_region" at the start of the sequence where full attention weights will be applied.
-    Args:
-        n: Size of returned matrix (maximum sequence size)
-        l: Size of local context (uni-directional, e.g. the total context is l*2)
-        fixed_region: The number of sequence elements at the start of the sequence that get full attention.
-    Returns:
-        A mask that can be applied to AttentionBlock to achieve local attention.
-    """
-    assert l*2 < n, f'Local context must be less than global context. {l}, {n}'
-    o = torch.arange(0,n)
-    c = o.unsqueeze(-1).repeat(1,n)
-    r = o.unsqueeze(0).repeat(n,1)
-    localized = ((-(r-c).abs())+l).clamp(0,l-1) / (l-1)
-    localized[:fixed_region] = 1
-    localized[:, :fixed_region] = 1
-    mask = localized > 0
-    return mask
-
-
-def test_local_attention_mask():
-    print(build_local_attention_mask(9,4,1))
-
-
 class SubBlock(nn.Module):
     def __init__(self, inp_dim, contraction_dim, blk_dim, heads, dropout, enable_attention_masking=False):
         super().__init__()
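
Note: build_local_attention_mask and its test are removed here because the helper was promoted to models/arch_util.py (see the updated import in the first hunk). A minimal usage sketch of its behavior, assuming the implementation shown in the removed lines above:

# Minimal usage sketch of build_local_attention_mask, now imported from
# models.arch_util (behavior per the removed implementation above).
import torch
from models.arch_util import build_local_attention_mask

mask = build_local_attention_mask(n=9, l=4, fixed_region=1)
assert mask.shape == (9, 9)
# The fixed region gives the first row and column full attention.
assert mask[0].all() and mask[:, 0].all()
# Elsewhere, positions may attend only within |i - j| <= l - 1.
assert mask[5, 2].item() and not mask[5, 1].item()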
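
The docstring states the mask "can be applied to AttentionBlock"; as a generic illustration only (not AttentionBlock's actual code, which lives in models/arch_util.py), a boolean keep-mask typically gates attention logits like this:

# Generic illustration of applying a boolean keep-mask to attention
# logits; NOT AttentionBlock's actual implementation.
import torch

def masked_attention(q, k, v, mask):
    # q, k, v: (seq, dim); mask: (seq, seq) bool, True = may attend.
    logits = (q @ k.transpose(-1, -2)) / (k.shape[-1] ** 0.5)
    logits = logits.masked_fill(~mask, float('-inf'))
    return torch.softmax(logits, dim=-1) @ v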