From 06bdf72b89af78609059a9e0bc87a1ac88bb1c44 Mon Sep 17 00:00:00 2001 From: mrq Date: Fri, 3 Mar 2023 13:53:21 +0000 Subject: [PATCH] load the BigVGAN vocoder checkpoint on CPU because torch doesn't like loading models directly to GPU (this now matches the default vocoder's loading behavior) --- tortoise/api.py | 12 ++---------- 1 file changed, 2 insertions(+), 10 deletions(-) diff --git a/tortoise/api.py b/tortoise/api.py index 01bcc48..df75a80 100755 --- a/tortoise/api.py +++ b/tortoise/api.py @@ -236,13 +236,6 @@ def classify_audio_clip(clip): results = F.softmax(classifier(clip), dim=-1) return results[0][0] -def load_checkpoint(filepath, device): - assert os.path.isfile(filepath) - print("Loading '{}'".format(filepath)) - checkpoint_dict = torch.load(filepath, map_location=device) - print("Complete.") - return checkpoint_dict - class TextToSpeech: """ Main entry point into Tortoise. @@ -312,10 +305,9 @@ class TextToSpeech: self.cvvp = None # CVVP model is only loaded if used. if use_bigvgan: - # credit https://github.com/deviandiceto / https://git.ecker.tech/mrq/ai-voice-cloning/issues/52 + # credit to https://github.com/deviandice / https://git.ecker.tech/mrq/ai-voice-cloning/issues/52 self.vocoder = BigVGAN().cpu() - state_dict_bigvgan = load_checkpoint(get_model_path('bigvgan_base_24khz_100band.pth', models_dir), self.device) - self.vocoder.load_state_dict(state_dict_bigvgan['generator']) + self.vocoder.load_state_dict(torch.load(get_model_path('bigvgan_base_24khz_100band.pth', models_dir), map_location=torch.device('cpu'))['generator']) else: self.vocoder = UnivNetGenerator().cpu() self.vocoder.load_state_dict(torch.load(get_model_path('vocoder.pth', models_dir), map_location=torch.device('cpu'))['model_g'])