diff --git a/tortoise/api.py b/tortoise/api.py index 01bcc48..df75a80 100755 --- a/tortoise/api.py +++ b/tortoise/api.py @@ -236,13 +236,6 @@ def classify_audio_clip(clip): results = F.softmax(classifier(clip), dim=-1) return results[0][0] -def load_checkpoint(filepath, device): - assert os.path.isfile(filepath) - print("Loading '{}'".format(filepath)) - checkpoint_dict = torch.load(filepath, map_location=device) - print("Complete.") - return checkpoint_dict - class TextToSpeech: """ Main entry point into Tortoise. @@ -312,10 +305,9 @@ class TextToSpeech: self.cvvp = None # CVVP model is only loaded if used. if use_bigvgan: - # credit https://github.com/deviandiceto / https://git.ecker.tech/mrq/ai-voice-cloning/issues/52 + # credit to https://github.com/deviandice / https://git.ecker.tech/mrq/ai-voice-cloning/issues/52 self.vocoder = BigVGAN().cpu() - state_dict_bigvgan = load_checkpoint(get_model_path('bigvgan_base_24khz_100band.pth', models_dir), self.device) - self.vocoder.load_state_dict(state_dict_bigvgan['generator']) + self.vocoder.load_state_dict(torch.load(get_model_path('bigvgan_base_24khz_100band.pth', models_dir), map_location=torch.device('cpu'))['generator']) else: self.vocoder = UnivNetGenerator().cpu() self.vocoder.load_state_dict(torch.load(get_model_path('vocoder.pth', models_dir), map_location=torch.device('cpu'))['model_g'])