Use clustered masking in udtts7
This commit is contained in:
parent
7c17c8e674
commit
ea500ad42a
|
@ -18,6 +18,26 @@ from utils.util import checkpoint
|
|||
from x_transformers import Encoder, ContinuousTransformerWrapper
|
||||
|
||||
|
||||
def clustered_mask(probability, shape, dev, lateral_expansion_radius_max=3, inverted=False):
|
||||
"""
|
||||
Produces a masking vector of the specified shape where each element has probability to be zero.
|
||||
lateral_expansion_radius_max neighbors of any element that is zero also have a 50% chance to be zero.
|
||||
Effectively, this produces clusters of masks tending to be lateral_expansion_radius_max wide.
|
||||
"""
|
||||
# Each masked token spreads out to 1+lateral_expansion_radius_max on average, therefore reduce the probability in
|
||||
# kind
|
||||
probability = probability / (1+lateral_expansion_radius_max)
|
||||
|
||||
mask = torch.rand(shape, device=dev)
|
||||
mask = (mask < probability).float()
|
||||
kernel = torch.tensor([.5 for _ in range(lateral_expansion_radius_max)] + [1] + [.5 for _ in range(lateral_expansion_radius_max)], device=dev)
|
||||
mask = F.conv1d(mask.unsqueeze(1), kernel.view(1,1,2*lateral_expansion_radius_max+1), padding=lateral_expansion_radius_max).squeeze(1)
|
||||
if inverted:
|
||||
return torch.bernoulli(torch.clamp(mask, 0, 1)) != 0
|
||||
else:
|
||||
return torch.bernoulli(torch.clamp(mask, 0, 1)) == 0
|
||||
|
||||
|
||||
class CheckpointedLayer(nn.Module):
|
||||
"""
|
||||
Wraps a module. When forward() is called, passes kwargs that require_grad through torch.checkpoint() and bypasses
|
||||
|
@ -118,7 +138,6 @@ class ResBlock(TimestepBlock):
|
|||
h = self.out_layers(h)
|
||||
return self.skip_connection(x) + h
|
||||
|
||||
|
||||
class DiffusionTts(nn.Module):
|
||||
"""
|
||||
The full UNet model with attention and timestep embedding.
|
||||
|
@ -428,7 +447,7 @@ class DiffusionTts(nn.Module):
|
|||
if tokens is not None:
|
||||
# Mask out guidance tokens for un-guided diffusion.
|
||||
if self.training and self.nil_guidance_fwd_proportion > 0:
|
||||
token_mask = torch.rand(tokens.shape, device=tokens.device) < self.nil_guidance_fwd_proportion
|
||||
token_mask = clustered_mask(self.nil_guidance_fwd_proportion, tokens.shape, tokens.device, inverted=True)
|
||||
tokens = torch.where(token_mask, self.mask_token_id, tokens)
|
||||
code_emb = self.code_embedding(tokens).permute(0,2,1)
|
||||
cond_emb = cond_emb.unsqueeze(-1).repeat(1,1,code_emb.shape[-1])
|
||||
|
|
Loading…
Reference in New Issue
Block a user