diff --git a/codes/models/gpt_voice/unet_diffusion_tts7.py b/codes/models/gpt_voice/unet_diffusion_tts7.py
index 77b76a36..90bc12af 100644
--- a/codes/models/gpt_voice/unet_diffusion_tts7.py
+++ b/codes/models/gpt_voice/unet_diffusion_tts7.py
@@ -18,6 +18,26 @@ from utils.util import checkpoint
 from x_transformers import Encoder, ContinuousTransformerWrapper
 
 
+def clustered_mask(probability, shape, dev, lateral_expansion_radius_max=3, inverted=False):
+    """
+    Produces a masking vector of the specified shape where each element has the given probability of being zero.
+    Neighbors within lateral_expansion_radius_max of a zeroed element each have a 50% chance of being zero as well.
+    Effectively, this produces clusters of masked elements averaging about 1+lateral_expansion_radius_max wide.
+    """
+    # Each masked token spreads out to 1+lateral_expansion_radius_max positions on average, so scale the per-element
+    # probability down accordingly.
+    probability = probability / (1+lateral_expansion_radius_max)
+
+    mask = torch.rand(shape, device=dev)
+    mask = (mask < probability).float()
+    kernel = torch.tensor([.5 for _ in range(lateral_expansion_radius_max)] + [1] + [.5 for _ in range(lateral_expansion_radius_max)], device=dev)
+    mask = F.conv1d(mask.unsqueeze(1), kernel.view(1,1,2*lateral_expansion_radius_max+1), padding=lateral_expansion_radius_max).squeeze(1)
+    if inverted:
+        return torch.bernoulli(torch.clamp(mask, 0, 1)) != 0
+    else:
+        return torch.bernoulli(torch.clamp(mask, 0, 1)) == 0
+
+
 class CheckpointedLayer(nn.Module):
     """
     Wraps a module. When forward() is called, passes kwargs that require_grad through torch.checkpoint() and bypasses
@@ -118,7 +138,6 @@ class ResBlock(TimestepBlock):
         h = self.out_layers(h)
         return self.skip_connection(x) + h
 
-
 class DiffusionTts(nn.Module):
     """
     The full UNet model with attention and timestep embedding.
@@ -428,7 +447,7 @@ class DiffusionTts(nn.Module):
         if tokens is not None:
             # Mask out guidance tokens for un-guided diffusion.
             if self.training and self.nil_guidance_fwd_proportion > 0:
-                token_mask = torch.rand(tokens.shape, device=tokens.device) < self.nil_guidance_fwd_proportion
+                token_mask = clustered_mask(self.nil_guidance_fwd_proportion, tokens.shape, tokens.device, inverted=True)
                 tokens = torch.where(token_mask, self.mask_token_id, tokens)
             code_emb = self.code_embedding(tokens).permute(0,2,1)
             cond_emb = cond_emb.unsqueeze(-1).repeat(1,1,code_emb.shape[-1])
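
For reference, a minimal sketch of how the new clustered_mask helper could be exercised on its own. The import path, tensor shape, probability, seed, and the 255 placeholder mask id are illustrative assumptions, not taken from the patch:

    import torch
    # Adjust the import path to your checkout layout; the module lives at codes/models/gpt_voice/unet_diffusion_tts7.py.
    from models.gpt_voice.unet_diffusion_tts7 import clustered_mask

    torch.manual_seed(0)
    tokens = torch.randint(0, 256, (2, 40))  # (batch, sequence) of token ids, purely illustrative
    # inverted=True returns True at positions selected for masking, matching the forward() usage in the patch.
    token_mask = clustered_mask(0.2, tokens.shape, tokens.device, inverted=True)
    masked_tokens = torch.where(token_mask, torch.tensor(255), tokens)  # 255 stands in for self.mask_token_id
    # Roughly 20% of positions end up masked, grouped into short contiguous runs rather than isolated points.
    print(token_mask.float().mean().item())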