diff --git a/codes/models/audio/tts/unet_diffusion_tts9.py b/codes/models/audio/tts/unet_diffusion_tts9.py index 243feab3..3ea0da14 100644 --- a/codes/models/audio/tts/unet_diffusion_tts9.py +++ b/codes/models/audio/tts/unet_diffusion_tts9.py @@ -388,6 +388,23 @@ class DiffusionTts(nn.Module): } return groups + def fix_alignment(self, x, aligned_conditioning): + """ + The UNet requires that the input is a certain multiple of 2, defined by the UNet depth. Enforce this by + padding both and before forward propagation and removing the padding before returning. + """ + cm = ceil_multiple(x.shape[-1], self.alignment_size) + if cm != 0: + pc = (cm-x.shape[-1])/x.shape[-1] + x = F.pad(x, (0,cm-x.shape[-1])) + # Also fix aligned_latent, which is aligned to x. + if is_latent(aligned_conditioning): + aligned_conditioning = torch.cat([aligned_conditioning, + self.aligned_latent_padding_embedding.repeat(x.shape[0], 1, int(pc * aligned_conditioning.shape[-1]))], dim=-1) + else: + aligned_conditioning = F.pad(aligned_conditioning, (0,int(pc*aligned_conditioning.shape[-1]))) + return x, aligned_conditioning + def forward(self, x, timesteps, aligned_conditioning, conditioning_input, lr_input=None, conditioning_free=False): """ Apply the model to an input batch. @@ -415,16 +432,7 @@ class DiffusionTts(nn.Module): # Fix input size to the proper multiple of 2 so we don't get alignment errors going down and back up the U-net. orig_x_shape = x.shape[-1] - cm = ceil_multiple(x.shape[-1], self.alignment_size) - if cm != 0: - pc = (cm-x.shape[-1])/x.shape[-1] - x = F.pad(x, (0,cm-x.shape[-1])) - # Also fix aligned_latent, which is aligned to x. - if is_latent(aligned_conditioning): - aligned_conditioning = torch.cat([aligned_conditioning, - self.aligned_latent_padding_embedding.repeat(x.shape[0], 1, int(pc * aligned_conditioning.shape[-1]))], dim=-1) - else: - aligned_conditioning = F.pad(aligned_conditioning, (0,int(pc*aligned_conditioning.shape[-1]))) + x, aligned_conditioning = self.fix_alignment(x, aligned_conditioning) with autocast(x.device.type, enabled=self.enable_fp16): diff --git a/codes/trainer/ExtensibleTrainer.py b/codes/trainer/ExtensibleTrainer.py index 8c297152..706559e8 100644 --- a/codes/trainer/ExtensibleTrainer.py +++ b/codes/trainer/ExtensibleTrainer.py @@ -314,6 +314,11 @@ class ExtensibleTrainer(BaseModel): if hasattr(net.module, "before_step"): net.module.before_step(it) + # Unscale gradients within the step. (This is admittedly pretty messy but the API contract between step & ET is pretty much broken at this point) + # This is needed to accurately log the grad norms. + for opt in step.optimizers: + step.scaler.unscale_(opt) + if return_grad_norms and train_step: for name in nets_to_train: model = self.networks[name] diff --git a/codes/trainer/eval/audio_diffusion_fid.py b/codes/trainer/eval/audio_diffusion_fid.py index dcd8baab..d10ce903 100644 --- a/codes/trainer/eval/audio_diffusion_fid.py +++ b/codes/trainer/eval/audio_diffusion_fid.py @@ -125,7 +125,7 @@ class AudioDiffusionFid(evaluator.Evaluator): mel = wav_to_mel(audio) mel_codes = convert_mel_to_codes(self.local_modules['dvae'], mel) real_resampled = torchaudio.functional.resample(audio, 22050, SAMPLE_RATE).unsqueeze(0) - univnet_mel = wav_to_univnet_mel(audio, mel_norms_file=None) # to be used for a conditioning input + univnet_mel = wav_to_univnet_mel(real_resampled, do_normalization=False) # to be used for a conditioning input, but also guides output shape. output_size = univnet_mel.shape[-1] aligned_codes_compression_factor = output_size // mel_codes.shape[-1] diff --git a/codes/trainer/steps.py b/codes/trainer/steps.py index 12897f0b..95a9fa95 100644 --- a/codes/trainer/steps.py +++ b/codes/trainer/steps.py @@ -304,7 +304,7 @@ class ConfigurableStep(Module): return self.grads_generated = False for opt in self.optimizers: - self.scaler.unscale_(opt) + # self.scaler.unscale_(opt) It would be important to do this here, but ExtensibleTrainer currently does it. # Optimizers can be opted out in the early stages of training. after = opt._config['after'] if 'after' in opt._config.keys() else 0