This commit is contained in:
James Betker 2022-01-22 08:23:29 -07:00
parent 851070075a
commit 8f48848f91
4 changed files with 8 additions and 8 deletions

View File

@ -11,7 +11,6 @@ import torch.nn.functional as F
from models.gpt_voice.mini_encoder import AudioMiniEncoder, EmbeddingCombiner from models.gpt_voice.mini_encoder import AudioMiniEncoder, EmbeddingCombiner
from scripts.audio.gen.use_diffuse_tts import ceil_multiple from scripts.audio.gen.use_diffuse_tts import ceil_multiple
from trainer.networks import register_model from trainer.networks import register_model
from utils.util import get_mask_from_lengths
from utils.util import checkpoint from utils.util import checkpoint

View File

@ -20,9 +20,9 @@ def ceil_multiple(base, multiple):
if __name__ == '__main__': if __name__ == '__main__':
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to options YAML file used to train the diffusion model', default='../options/train_diffusion_tts.yml') parser.add_argument('-opt', type=str, help='Path to options YAML file used to train the diffusion model', default='../options/train_diffusion_tts_medium.yml')
parser.add_argument('-diffusion_model_name', type=str, help='Name of the diffusion model in opt.', default='generator') parser.add_argument('-diffusion_model_name', type=str, help='Name of the diffusion model in opt.', default='generator')
parser.add_argument('-diffusion_model_path', type=str, help='Path to saved model weights', default='../experiments/train_diffusion_tts_experimental_fp16\\models\\17800_generator_ema.pth') parser.add_argument('-diffusion_model_path', type=str, help='Path to saved model weights', default='X:\\dlas\\experiments\\train_diffusion_tts_medium\\models\\5200_generator.pth')
parser.add_argument('-aligned_codes', type=str, help='Comma-delimited list of integer codes that defines text & prosody. Get this by apply W2V to an existing audio clip or from a bespoke generator.', parser.add_argument('-aligned_codes', type=str, help='Comma-delimited list of integer codes that defines text & prosody. Get this by apply W2V to an existing audio clip or from a bespoke generator.',
default='0,0,0,0,10,10,0,4,0,7,0,17,4,4,0,25,5,0,13,13,0,22,4,4,0,21,15,15,7,0,0,14,4,4,6,8,4,4,0,0,12,5,0,0,5,0,4,4,22,22,8,16,16,0,4,4,4,0,0,0,0,0,0,0') # Default: 'i am very glad to see you', libritts/train-clean-100/103/1241/103_1241_000017_000001.wav. default='0,0,0,0,10,10,0,4,0,7,0,17,4,4,0,25,5,0,13,13,0,22,4,4,0,21,15,15,7,0,0,14,4,4,6,8,4,4,0,0,12,5,0,0,5,0,4,4,22,22,8,16,16,0,4,4,4,0,0,0,0,0,0,0') # Default: 'i am very glad to see you', libritts/train-clean-100/103/1241/103_1241_000017_000001.wav.
parser.add_argument('-cond', type=str, help='Path to the conditioning input audio file.', default='Y:\\clips\\books1\\754_Dan Simmons - The Rise Of Endymion 356 of 450\\00026.wav') parser.add_argument('-cond', type=str, help='Path to the conditioning input audio file.', default='Y:\\clips\\books1\\754_Dan Simmons - The Rise Of Endymion 356 of 450\\00026.wav')

View File

@ -82,7 +82,7 @@ if __name__ == '__main__':
'ed_sheeran': ['D:\\data\\audio\\sample_voices\\ed_sheeran.wav'], 'ed_sheeran': ['D:\\data\\audio\\sample_voices\\ed_sheeran.wav'],
'simmons': ['Y:\\clips\\books1\\754_Dan Simmons - The Rise Of Endymion 356 of 450\\00026.wav'], 'simmons': ['Y:\\clips\\books1\\754_Dan Simmons - The Rise Of Endymion 356 of 450\\00026.wav'],
'news_girl': ['Y:\\clips\\podcasts-0\\8288_20210113-Is More Violence Coming_\\00022.wav', 'Y:\\clips\\podcasts-0\\8288_20210113-Is More Violence Coming_\\00016.wav'], 'news_girl': ['Y:\\clips\\podcasts-0\\8288_20210113-Is More Violence Coming_\\00022.wav', 'Y:\\clips\\podcasts-0\\8288_20210113-Is More Violence Coming_\\00016.wav'],
'dan_carlin': ['Y:\\clips\\books1\\5_dchha06 Shield of the West\\00476.wav'], 'dan_carlin': ['Y:\\clips\\books1\\5_dchha06 Shield of the West\\00476.wav', 'Y:\\clips\\books1\\15_dchha16 Nazi Tidbits\\00036.wav'],
'libri_test': ['Y:\\libritts\\test-clean\\672\\122797\\672_122797_000057_000002.wav'], 'libri_test': ['Y:\\libritts\\test-clean\\672\\122797\\672_122797_000057_000002.wav'],
'myself': ['D:\\data\\audio\\sample_voices\\myself1.wav', 'D:\\data\\audio\\sample_voices\\myself2.wav'], 'myself': ['D:\\data\\audio\\sample_voices\\myself1.wav', 'D:\\data\\audio\\sample_voices\\myself2.wav'],
} }
@ -90,7 +90,7 @@ if __name__ == '__main__':
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('-opt_diffuse', type=str, help='Path to options YAML file used to train the diffusion model', default='X:\\dlas\\experiments\\train_diffusion_vocoder_22k_level.yml') parser.add_argument('-opt_diffuse', type=str, help='Path to options YAML file used to train the diffusion model', default='X:\\dlas\\experiments\\train_diffusion_vocoder_22k_level.yml')
parser.add_argument('-diffusion_model_name', type=str, help='Name of the diffusion model in opt.', default='generator') parser.add_argument('-diffusion_model_name', type=str, help='Name of the diffusion model in opt.', default='generator')
parser.add_argument('-diffusion_model_path', type=str, help='Diffusion model checkpoint to load.', default='X:\\dlas\\experiments\\train_diffusion_vocoder_22k_level\\models\\12000_generator_ema.pth') parser.add_argument('-diffusion_model_path', type=str, help='Diffusion model checkpoint to load.', default='X:\\dlas\\experiments\\train_diffusion_vocoder_22k_level\\models\\15000_generator_ema.pth')
parser.add_argument('-dvae_model_name', type=str, help='Name of the DVAE model in opt.', default='dvae') parser.add_argument('-dvae_model_name', type=str, help='Name of the DVAE model in opt.', default='dvae')
parser.add_argument('-opt_gpt_tts', type=str, help='Path to options YAML file used to train the GPT-TTS model', default='X:\\dlas\\experiments\\train_gpt_tts_unified.yml') parser.add_argument('-opt_gpt_tts', type=str, help='Path to options YAML file used to train the GPT-TTS model', default='X:\\dlas\\experiments\\train_gpt_tts_unified.yml')
parser.add_argument('-gpt_tts_model_name', type=str, help='Name of the GPT TTS model in opt.', default='gpt') parser.add_argument('-gpt_tts_model_name', type=str, help='Name of the GPT TTS model in opt.', default='gpt')
@ -131,8 +131,8 @@ if __name__ == '__main__':
print("Performing GPT inference..") print("Performing GPT inference..")
samples = [] samples = []
for b in tqdm(range(args.num_batches)): for b in tqdm(range(args.num_batches)):
codes = gpt.inference_speech(conds, text, num_beams=1, repetition_penalty=1.0, do_sample=True, top_k=20, top_p=.95, codes = gpt.inference_speech(conds, text, num_beams=1, repetition_penalty=1.0, do_sample=True, top_k=50, top_p=.95,
num_return_sequences=args.num_samples//args.num_batches, length_penalty=1) temperature=.9, num_return_sequences=args.num_samples//args.num_batches, length_penalty=1)
padding_needed = 250 - codes.shape[1] padding_needed = 250 - codes.shape[1]
codes = F.pad(codes, (0, padding_needed), value=stop_mel_token) codes = F.pad(codes, (0, padding_needed), value=stop_mel_token)
samples.append(codes) samples.append(codes)

View File

@ -78,6 +78,7 @@ class GeneratorInjector(Injector):
self.grad = opt['grad'] if 'grad' in opt.keys() else True self.grad = opt['grad'] if 'grad' in opt.keys() else True
self.method = opt_get(opt, ['method'], None) # If specified, this method is called instead of __call__() self.method = opt_get(opt, ['method'], None) # If specified, this method is called instead of __call__()
self.args = opt_get(opt, ['args'], {}) self.args = opt_get(opt, ['args'], {})
self.fp16_override = opt_get(opt, ['fp16'], True)
def forward(self, state): def forward(self, state):
gen = self.env['generators'][self.opt['generator']] gen = self.env['generators'][self.opt['generator']]
@ -86,7 +87,7 @@ class GeneratorInjector(Injector):
gen = gen.module # Dereference DDP wrapper. gen = gen.module # Dereference DDP wrapper.
method = gen if self.method is None else getattr(gen, self.method) method = gen if self.method is None else getattr(gen, self.method)
with autocast(enabled=self.env['opt']['fp16']): with autocast(enabled=self.env['opt']['fp16'] and self.fp16_override):
if isinstance(self.input, list): if isinstance(self.input, list):
params = extract_params_from_state(self.input, state) params = extract_params_from_state(self.input, state)
else: else: