From 8c255811adc8c184f85d0572760d2770c1ad7ecb Mon Sep 17 00:00:00 2001
From: James Betker <jbetker@gmail.com>
Date: Tue, 25 Jan 2022 17:57:16 -0700
Subject: [PATCH] fix aligned codes ratio, train-only guidance masking, default opt path

---
 codes/data/audio/fast_paired_dataset.py                   | 2 +-
 codes/models/gpt_voice/unet_diffusion_tts_experimental.py | 2 +-
 codes/train.py                                            | 2 +-
 3 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/codes/data/audio/fast_paired_dataset.py b/codes/data/audio/fast_paired_dataset.py
index e1be1bbb..638ec81a 100644
--- a/codes/data/audio/fast_paired_dataset.py
+++ b/codes/data/audio/fast_paired_dataset.py
@@ -54,9 +54,9 @@ class FastPairedVoiceDataset(torch.utils.data.Dataset):
         self.conditioning_candidates = opt_get(hparams, ['num_conditioning_candidates'], 1)
         self.conditioning_length = opt_get(hparams, ['conditioning_length'], 44100)
         self.debug_failures = opt_get(hparams, ['debug_loading_failures'], False)
-        self.aligned_codes_to_audio_ratio = opt_get(hparams, ['aligned_codes_ratio'], 443)
         self.text_cleaners = hparams.text_cleaners
         self.sample_rate = hparams.sample_rate
+        self.aligned_codes_to_audio_ratio = 443 * self.sample_rate // 22050
         self.max_wav_len = opt_get(hparams, ['max_wav_length'], None)
         if self.max_wav_len is not None:
             self.max_aligned_codes = self.max_wav_len // self.aligned_codes_to_audio_ratio
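Note: a minimal sketch of what the new ratio line computes, under the assumption (not verified against the repo's configs) that 443 is the number of audio samples spanned by one aligned code at 22050 Hz; the patched line scales that span linearly with the configured sample rate instead of reading a fixed 'aligned_codes_ratio' from hparams. Constant and function names below are illustrative only:

    BASE_RATIO = 443           # audio samples per aligned code at 22050 Hz (assumed)
    BASE_SAMPLE_RATE = 22050

    def aligned_codes_ratio(sample_rate: int) -> int:
        # Integer division mirrors the '//' in the patched line.
        return BASE_RATIO * sample_rate // BASE_SAMPLE_RATE

    assert aligned_codes_ratio(22050) == 443
    assert aligned_codes_ratio(44100) == 886   # ratio doubles at 44.1 kHz

With max_wav_len set, max_aligned_codes then follows as max_wav_len // aligned_codes_ratio(sample_rate), matching the last context line of the hunk above.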
diff --git a/codes/models/gpt_voice/unet_diffusion_tts_experimental.py b/codes/models/gpt_voice/unet_diffusion_tts_experimental.py
index 6e2cd9d8..1074b706 100644
--- a/codes/models/gpt_voice/unet_diffusion_tts_experimental.py
+++ b/codes/models/gpt_voice/unet_diffusion_tts_experimental.py
@@ -334,7 +334,7 @@ class DiffusionTts(nn.Module):
             emb = emb1
 
         # Mask out guidance tokens for un-guided diffusion.
-        if self.nil_guidance_fwd_proportion > 0:
+        if self.training and self.nil_guidance_fwd_proportion > 0:
             token_mask = torch.rand(tokens.shape, device=tokens.device) < self.nil_guidance_fwd_proportion
             tokens = torch.where(token_mask, self.mask_token_id, tokens)
 
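Note: a minimal sketch of the behavior the added 'self.training' guard produces (the proportion and mask_token_id values here are illustrative, not the model's actual settings): a random subset of guidance tokens is blanked only while training, so the model also learns un-guided diffusion, while eval/inference keeps the full conditioning sequence intact:

    import torch

    def mask_guidance_tokens(tokens, proportion, mask_token_id, training):
        # Replace a random subset of tokens with the mask id, but only in training.
        if training and proportion > 0:
            token_mask = torch.rand(tokens.shape, device=tokens.device) < proportion
            tokens = torch.where(token_mask, torch.full_like(tokens, mask_token_id), tokens)
        return tokens

    tokens = torch.randint(1, 256, (2, 16))
    # Outside training the tokens pass through unchanged.
    assert torch.equal(mask_guidance_tokens(tokens, 0.1, 0, training=False), tokens)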
diff --git a/codes/train.py b/codes/train.py
index 812b4b56..9a2e8a80 100644
--- a/codes/train.py
+++ b/codes/train.py
@@ -299,7 +299,7 @@ class Trainer:
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_diffusion_tts_medium.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../experiments/train_diffusion_tts_experimental_fp16/train_diffusion_tts.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
     parser.add_argument('--local_rank', type=int, default=0)
     args = parser.parse_args()
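Note: a minimal usage sketch of how the parsed '-opt' argument is typically consumed (assumptions: the options file is plain YAML and PyYAML is installed; this is not the repo's actual option loader):

    import argparse
    import yaml  # assumption: PyYAML, since the help string says "option YAML file"

    parser = argparse.ArgumentParser()
    parser.add_argument('-opt', type=str, help='Path to option YAML file.',
                        default='../options/train_diffusion_tts_medium.yml')
    parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none')
    parser.add_argument('--local_rank', type=int, default=0)
    args = parser.parse_args()

    with open(args.opt) as f:
        opt = yaml.safe_load(f)  # dict of training options handed to the Trainer

Overriding the new experiment-specific default still works as before by passing -opt with any other YAML path on the command line.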