diff --git a/codes/models/gpt_voice/ctc_code_generator.py b/codes/models/gpt_voice/ctc_code_generator.py index e216807d..d45e993f 100644 --- a/codes/models/gpt_voice/ctc_code_generator.py +++ b/codes/models/gpt_voice/ctc_code_generator.py @@ -78,7 +78,8 @@ class CtcCodeGenerator(nn.Module): loss_mask = torch.ones_like(codes) for i, l in enumerate(unpadded_lengths): loss_mask[i, l:] = 0 - codes = clustered_mask(self.mask_probability, codes.shape, codes.device) * codes + if self.training: + codes = clustered_mask(self.mask_probability, codes.shape, codes.device) * codes if separators.max() > self.max_pad: print(f"Got unexpectedly long separators. Max: {separators.max()}, {separators}")