more fixes
This commit is contained in:
parent
0f3ca28e39
commit
8c255811ad
|
@ -54,9 +54,9 @@ class FastPairedVoiceDataset(torch.utils.data.Dataset):
|
|||
self.conditioning_candidates = opt_get(hparams, ['num_conditioning_candidates'], 1)
|
||||
self.conditioning_length = opt_get(hparams, ['conditioning_length'], 44100)
|
||||
self.debug_failures = opt_get(hparams, ['debug_loading_failures'], False)
|
||||
self.aligned_codes_to_audio_ratio = opt_get(hparams, ['aligned_codes_ratio'], 443)
|
||||
self.text_cleaners = hparams.text_cleaners
|
||||
self.sample_rate = hparams.sample_rate
|
||||
self.aligned_codes_to_audio_ratio = 443 * self.sample_rate // 22050
|
||||
self.max_wav_len = opt_get(hparams, ['max_wav_length'], None)
|
||||
if self.max_wav_len is not None:
|
||||
self.max_aligned_codes = self.max_wav_len // self.aligned_codes_to_audio_ratio
|
||||
|
|
|
@ -334,7 +334,7 @@ class DiffusionTts(nn.Module):
|
|||
emb = emb1
|
||||
|
||||
# Mask out guidance tokens for un-guided diffusion.
|
||||
if self.nil_guidance_fwd_proportion > 0:
|
||||
if self.training and self.nil_guidance_fwd_proportion > 0:
|
||||
token_mask = torch.rand(tokens.shape, device=tokens.device) < self.nil_guidance_fwd_proportion
|
||||
tokens = torch.where(token_mask, self.mask_token_id, tokens)
|
||||
|
||||
|
|
|
@ -299,7 +299,7 @@ class Trainer:
|
|||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_diffusion_tts_medium.yml')
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../experiments/train_diffusion_tts_experimental_fp16/train_diffusion_tts.yml')
|
||||
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
|
||||
parser.add_argument('--local_rank', type=int, default=0)
|
||||
args = parser.parse_args()
|
||||
|
|
Loading…
Reference in New Issue
Block a user