Use clustered masking in udtts7

James Betker 2022-02-24 07:57:26 -07:00
parent 7c17c8e674
commit ea500ad42a


@@ -18,6 +18,26 @@ from utils.util import checkpoint
from x_transformers import Encoder, ContinuousTransformerWrapper

def clustered_mask(probability, shape, dev, lateral_expansion_radius_max=3, inverted=False):
    """
    Produces a masking vector of the specified shape where each element has the given probability of being zero.
    The lateral_expansion_radius_max neighbors on either side of any zeroed element also have a 50% chance of
    being zero. Effectively, this produces clusters of masks tending to be lateral_expansion_radius_max wide.
    """
    # Each masked seed spreads out to 1+lateral_expansion_radius_max masked tokens on average, so reduce the
    # seed probability accordingly.
    probability = probability / (1+lateral_expansion_radius_max)
    mask = torch.rand(shape, device=dev)
    mask = (mask < probability).float()
    kernel = torch.tensor([.5 for _ in range(lateral_expansion_radius_max)] + [1] + [.5 for _ in range(lateral_expansion_radius_max)], device=dev)
    mask = F.conv1d(mask.unsqueeze(1), kernel.view(1,1,2*lateral_expansion_radius_max+1), padding=lateral_expansion_radius_max).squeeze(1)
    if inverted:
        return torch.bernoulli(torch.clamp(mask, 0, 1)) != 0
    else:
        return torch.bernoulli(torch.clamp(mask, 0, 1)) == 0
class CheckpointedLayer(nn.Module):
    """
    Wraps a module. When forward() is called, passes kwargs that require_grad through torch.checkpoint() and bypasses
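
Not part of the commit, but a minimal sketch of how the new clustered_mask behaves compared to independent per-token masking, assuming torch is installed and clustered_mask is importable from this module (module path not shown here):

import torch
# from <this module> import clustered_mask   # import path assumed

# Roughly 20% of positions end up masked overall, but grouped into short runs
# rather than scattered independently.
m = clustered_mask(0.2, (1, 64), 'cpu', lateral_expansion_radius_max=3)
print(m.int())                # 0s (masked positions) appear in contiguous clusters
m_inv = clustered_mask(0.2, (1, 64), 'cpu', inverted=True)
print(m_inv.float().mean())   # fraction of masked positions, close to 0.2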
@@ -118,7 +138,6 @@ class ResBlock(TimestepBlock):
        h = self.out_layers(h)
        return self.skip_connection(x) + h

class DiffusionTts(nn.Module):
    """
    The full UNet model with attention and timestep embedding.
@@ -428,7 +447,7 @@ class DiffusionTts(nn.Module):
        if tokens is not None:
            # Mask out guidance tokens for un-guided diffusion.
            if self.training and self.nil_guidance_fwd_proportion > 0:
                token_mask = torch.rand(tokens.shape, device=tokens.device) < self.nil_guidance_fwd_proportion
                token_mask = clustered_mask(self.nil_guidance_fwd_proportion, tokens.shape, tokens.device, inverted=True)
                tokens = torch.where(token_mask, self.mask_token_id, tokens)
            code_emb = self.code_embedding(tokens).permute(0,2,1)
            cond_emb = cond_emb.unsqueeze(-1).repeat(1,1,code_emb.shape[-1])
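
A hedged sketch (not from the commit) of what this one-line change does at the call site: the replaced line dropped guidance tokens independently per position, while clustered_mask(..., inverted=True) drops them in contiguous runs. The tensors and the mask_token_id value below are illustrative stand-ins for the module's attributes:

import torch
# from <this module> import clustered_mask   # import path assumed

tokens = torch.randint(1, 256, (2, 40))      # stand-in (batch, sequence) guidance tokens
mask_token_id = 0                            # stand-in for self.mask_token_id
nil_guidance_fwd_proportion = 0.2            # stand-in for self.nil_guidance_fwd_proportion

# Old behaviour: each token is dropped independently.
uniform_mask = torch.rand(tokens.shape) < nil_guidance_fwd_proportion

# New behaviour: dropped tokens come in clusters roughly lateral_expansion_radius_max wide.
cluster_mask = clustered_mask(nil_guidance_fwd_proportion, tokens.shape, tokens.device, inverted=True)

masked_tokens = torch.where(cluster_mask, torch.full_like(tokens, mask_token_id), tokens)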