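more fixes: derive aligned_codes_to_audio_ratio from sample_rate, mask guidance tokens only while training, update default -opt path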

This commit is contained in:
James Betker 2022-01-25 17:57:16 -07:00
parent 0f3ca28e39
commit 8c255811ad
3 changed files with 3 additions and 3 deletions

View File

@ -54,9 +54,9 @@ class FastPairedVoiceDataset(torch.utils.data.Dataset):
self.conditioning_candidates = opt_get(hparams, ['num_conditioning_candidates'], 1)
self.conditioning_length = opt_get(hparams, ['conditioning_length'], 44100)
self.debug_failures = opt_get(hparams, ['debug_loading_failures'], False)
self.aligned_codes_to_audio_ratio = opt_get(hparams, ['aligned_codes_ratio'], 443)
self.text_cleaners = hparams.text_cleaners
self.sample_rate = hparams.sample_rate
self.aligned_codes_to_audio_ratio = 443 * self.sample_rate // 22050
self.max_wav_len = opt_get(hparams, ['max_wav_length'], None)
if self.max_wav_len is not None:
self.max_aligned_codes = self.max_wav_len // self.aligned_codes_to_audio_ratio

View File

@ -334,7 +334,7 @@ class DiffusionTts(nn.Module):
emb = emb1
# Mask out guidance tokens for un-guided diffusion.
if self.nil_guidance_fwd_proportion > 0:
if self.training and self.nil_guidance_fwd_proportion > 0:
token_mask = torch.rand(tokens.shape, device=tokens.device) < self.nil_guidance_fwd_proportion
tokens = torch.where(token_mask, self.mask_token_id, tokens)

View File

@ -299,7 +299,7 @@ class Trainer:
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_diffusion_tts_medium.yml')
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../experiments/train_diffusion_tts_experimental_fp16/train_diffusion_tts.yml')
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
parser.add_argument('--local_rank', type=int, default=0)
args = parser.parse_args()