forked from mrq/tortoise-tts
load the model on CPU because torch doesn't like loading models directly to GPU (it just follows the default vocoder loading behavior)
This commit is contained in:
parent
2ba0e056cd
commit
06bdf72b89
|
@ -236,13 +236,6 @@ def classify_audio_clip(clip):
|
||||||
results = F.softmax(classifier(clip), dim=-1)
|
results = F.softmax(classifier(clip), dim=-1)
|
||||||
return results[0][0]
|
return results[0][0]
|
||||||
|
|
||||||
def load_checkpoint(filepath, device):
|
|
||||||
assert os.path.isfile(filepath)
|
|
||||||
print("Loading '{}'".format(filepath))
|
|
||||||
checkpoint_dict = torch.load(filepath, map_location=device)
|
|
||||||
print("Complete.")
|
|
||||||
return checkpoint_dict
|
|
||||||
|
|
||||||
class TextToSpeech:
|
class TextToSpeech:
|
||||||
"""
|
"""
|
||||||
Main entry point into Tortoise.
|
Main entry point into Tortoise.
|
||||||
|
@ -312,10 +305,9 @@ class TextToSpeech:
|
||||||
self.cvvp = None # CVVP model is only loaded if used.
|
self.cvvp = None # CVVP model is only loaded if used.
|
||||||
|
|
||||||
if use_bigvgan:
|
if use_bigvgan:
|
||||||
# credit https://github.com/deviandiceto / https://git.ecker.tech/mrq/ai-voice-cloning/issues/52
|
# credit to https://github.com/deviandice / https://git.ecker.tech/mrq/ai-voice-cloning/issues/52
|
||||||
self.vocoder = BigVGAN().cpu()
|
self.vocoder = BigVGAN().cpu()
|
||||||
state_dict_bigvgan = load_checkpoint(get_model_path('bigvgan_base_24khz_100band.pth', models_dir), self.device)
|
self.vocoder.load_state_dict(torch.load(get_model_path('bigvgan_base_24khz_100band.pth', models_dir), map_location=torch.device('cpu'))['generator'])
|
||||||
self.vocoder.load_state_dict(state_dict_bigvgan['generator'])
|
|
||||||
else:
|
else:
|
||||||
self.vocoder = UnivNetGenerator().cpu()
|
self.vocoder = UnivNetGenerator().cpu()
|
||||||
self.vocoder.load_state_dict(torch.load(get_model_path('vocoder.pth', models_dir), map_location=torch.device('cpu'))['model_g'])
|
self.vocoder.load_state_dict(torch.load(get_model_path('vocoder.pth', models_dir), map_location=torch.device('cpu'))['model_g'])
|
||||||
|
|
Loading…
Reference in New Issue
Block a user