Use clustered masking in udtts7
parent 7c17c8e674
commit ea500ad42a
@@ -18,6 +18,26 @@ from utils.util import checkpoint
 from x_transformers import Encoder, ContinuousTransformerWrapper
+
+
+def clustered_mask(probability, shape, dev, lateral_expansion_radius_max=3, inverted=False):
+    """
+    Produces a masking vector of the specified shape where each element has probability to be zero.
+    lateral_expansion_radius_max neighbors of any element that is zero also have a 50% chance to be zero.
+    Effectively, this produces clusters of masks tending to be lateral_expansion_radius_max wide.
+    """
+    # Each masked token spreads out to 1+lateral_expansion_radius_max on average, therefore reduce the probability in
+    # kind
+    probability = probability / (1+lateral_expansion_radius_max)
+
+    mask = torch.rand(shape, device=dev)
+    mask = (mask < probability).float()
+    kernel = torch.tensor([.5 for _ in range(lateral_expansion_radius_max)] + [1] + [.5 for _ in range(lateral_expansion_radius_max)], device=dev)
+    mask = F.conv1d(mask.unsqueeze(1), kernel.view(1,1,2*lateral_expansion_radius_max+1), padding=lateral_expansion_radius_max).squeeze(1)
+    if inverted:
+        return torch.bernoulli(torch.clamp(mask, 0, 1)) != 0
+    else:
+        return torch.bernoulli(torch.clamp(mask, 0, 1)) == 0
 
 
 class CheckpointedLayer(nn.Module):
     """
     Wraps a module. When forward() is called, passes kwargs that require_grad through torch.checkpoint() and bypasses
@@ -118,7 +138,6 @@ class ResBlock(TimestepBlock):
         h = self.out_layers(h)
         return self.skip_connection(x) + h
 
-
 class DiffusionTts(nn.Module):
     """
     The full UNet model with attention and timestep embedding.
@@ -428,7 +447,7 @@ class DiffusionTts(nn.Module):
         if tokens is not None:
             # Mask out guidance tokens for un-guided diffusion.
             if self.training and self.nil_guidance_fwd_proportion > 0:
-                token_mask = torch.rand(tokens.shape, device=tokens.device) < self.nil_guidance_fwd_proportion
+                token_mask = clustered_mask(self.nil_guidance_fwd_proportion, tokens.shape, tokens.device, inverted=True)
                 tokens = torch.where(token_mask, self.mask_token_id, tokens)
             code_emb = self.code_embedding(tokens).permute(0,2,1)
             cond_emb = cond_emb.unsqueeze(-1).repeat(1,1,code_emb.shape[-1])
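
Not part of the commit, just a reviewer-style sanity sketch of the new masking behaviour: it assumes clustered_mask is defined exactly as in the first hunk above (with torch and torch.nn.functional as F imported in that module), and the tensor shapes and the 255 placeholder mask id are made up for illustration.

    import torch

    probability = 0.2
    tokens = torch.randint(0, 256, (4, 100))

    # Mirrors the call in DiffusionTts.forward(): with inverted=True, True marks positions to be
    # replaced by the mask token. Roughly `probability` of positions end up True, but grouped into
    # short runs rather than scattered uniformly.
    token_mask = clustered_mask(probability, tokens.shape, tokens.device, inverted=True)
    masked_tokens = torch.where(token_mask, torch.full_like(tokens, 255), tokens)

    print(token_mask.float().mean())     # should hover around 0.2
    print(token_mask[0].int().tolist())  # 1s tend to appear in clusters a few tokens wide

The base probability is divided by (1+lateral_expansion_radius_max) inside the function, so after the convolution spreads each seed to its neighbors the overall masked fraction stays near the requested probability, just clustered.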