forked from mrq/tortoise-tts
do not reload AR/vocoder if already loaded
parent e2db36af60
commit 26133c2031
@@ -279,14 +279,16 @@ class TextToSpeech:
         self.tokenizer = VoiceBpeTokenizer()
 
-        self.autoregressive_model_path = autoregressive_model_path if autoregressive_model_path and os.path.exists(autoregressive_model_path) else get_model_path('autoregressive.pth', models_dir)
 
         if os.path.exists(f'{models_dir}/autoregressive.ptt'):
             # Assume this is a traced directory.
             self.autoregressive = torch.jit.load(f'{models_dir}/autoregressive.ptt')
             self.diffusion = torch.jit.load(f'{models_dir}/diffusion_decoder.ptt')
         else:
-            self.load_autoregressive_model(self.autoregressive_model_path)
+            if not autoregressive_model_path or not os.path.exists(autoregressive_model_path):
+                autoregressive_model_path = get_model_path('autoregressive.pth', models_dir)
+
+            self.load_autoregressive_model(autoregressive_model_path)
 
             self.diffusion = DiffusionTts(model_channels=1024, num_layers=10, in_channels=100, out_channels=200,
                                           in_latent_channels=1024, in_tokens=8193, dropout=0, use_fp16=False, num_heads=16,
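The constructor now falls back to the stock autoregressive.pth when no usable path is supplied, then hands off to load_autoregressive_model. A minimal standalone sketch of that fallback, using an illustrative resolve_autoregressive_path helper that is not part of this repo:

import os

def resolve_autoregressive_path(requested_path, models_dir):
    # Prefer the caller-supplied checkpoint if it actually exists on disk;
    # otherwise fall back to the default autoregressive.pth under models_dir.
    if requested_path and os.path.exists(requested_path):
        return requested_path
    return os.path.join(models_dir, 'autoregressive.pth')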
@@ -316,11 +318,14 @@ class TextToSpeech:
         self.loading = False
 
     def load_autoregressive_model(self, autoregressive_model_path):
+        if hasattr(self,"autoregressive_model_path") and self.autoregressive_model_path == autoregressive_model_path:
+            return
+
         self.loading = True
 
-        previous_path = self.autoregressive_model_path
         self.autoregressive_model_path = autoregressive_model_path if autoregressive_model_path and os.path.exists(autoregressive_model_path) else get_model_path('autoregressive.pth', self.models_dir)
         self.autoregressive_model_hash = hash_file(self.autoregressive_model_path)
+        print(f"Loading autoregressive model: {self.autoregressive_model_path}")
 
         if hasattr(self, 'autoregressive'):
             del self.autoregressive
@@ -335,9 +340,14 @@ class TextToSpeech:
             self.autoregressive = self.autoregressive.to(self.device)
 
         self.loading = False
+        print(f"Loaded autoregressive model")
 
     def load_vocoder_model(self, vocoder_model):
+        if hasattr(self,"vocoder_model_path") and self.vocoder_model_path == vocoder_model:
+            return
+
         self.loading = True
 
         if hasattr(self, 'vocoder'):
             del self.vocoder
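Both loaders now return early when asked for the checkpoint that is already resident, which is the point of this commit. A minimal sketch of that guard in isolation, with the actual weight loading elided (ModelHost and its attributes are illustrative stand-ins, not the real class):

class ModelHost:
    # Illustrative stand-in for TextToSpeech's loader methods.
    def load_autoregressive_model(self, autoregressive_model_path):
        # Skip the reload entirely if this exact checkpoint was loaded before.
        if hasattr(self, "autoregressive_model_path") and self.autoregressive_model_path == autoregressive_model_path:
            return
        self.loading = True
        self.autoregressive_model_path = autoregressive_model_path
        # ... construct the model, load the state dict, move it to the device ...
        self.loading = False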
@@ -358,13 +368,14 @@ class TextToSpeech:
             self.vocoder_model_path = 'vocoder.pth'
             self.vocoder = UnivNetGenerator().cpu()
 
-        print(vocoder_model, vocoder_key, self.vocoder_model_path)
+        print(f"Loading vocoder model: {self.vocoder_model_path}")
         self.vocoder.load_state_dict(torch.load(get_model_path(self.vocoder_model_path, self.models_dir), map_location=torch.device('cpu'))[vocoder_key])
 
         self.vocoder.eval(inference=True)
         if self.preloaded_tensors:
            self.vocoder = self.vocoder.to(self.device)
         self.loading = False
+        print(f"Loaded vocoder model")
 
     def load_cvvp(self):
         """Load CVVP model."""
@@ -427,8 +438,7 @@ class TextToSpeech:
 
         if slices == 0:
             slices = 1
-        else:
-            if max_chunk_size is not None and chunk_size > max_chunk_size:
+        elif max_chunk_size is not None and chunk_size > max_chunk_size:
                 slices = 1
                 while int(chunk_size / slices) > max_chunk_size:
                     slices = slices + 1
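The last hunk only collapses a nested if into an elif; the slice arithmetic is unchanged. A small standalone sketch of that arithmetic (clamp_slices is an illustrative name, not a function in the repo):

def clamp_slices(slices, chunk_size, max_chunk_size=None):
    # Guarantee at least one slice, then add slices until each piece of the
    # conditioning chunk fits under max_chunk_size.
    if slices == 0:
        slices = 1
    elif max_chunk_size is not None and chunk_size > max_chunk_size:
        slices = 1
        while int(chunk_size / slices) > max_chunk_size:
            slices = slices + 1
    return slices

# Example: clamp_slices(1, chunk_size=1000, max_chunk_size=300) returns 4, since 1000/4 = 250 <= 300.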