diff --git a/api.py b/api.py index 557d557..6aa94cf 100644 --- a/api.py +++ b/api.py @@ -31,6 +31,7 @@ def download_models(specific_models=None): """ MODELS = { 'autoregressive.pth': 'https://huggingface.co/jbetker/tortoise-tts-v2/resolve/hf/.models/autoregressive.pth', + 'classifier.pth': 'https://huggingface.co/jbetker/tortoise-tts-v2/resolve/hf/.models/classifier.pth', 'clvp.pth': 'https://huggingface.co/jbetker/tortoise-tts-v2/resolve/hf/.models/clvp.pth', 'cvvp.pth': 'https://huggingface.co/jbetker/tortoise-tts-v2/resolve/hf/.models/cvvp.pth', 'diffusion_decoder.pth': 'https://huggingface.co/jbetker/tortoise-tts-v2/resolve/hf/.models/diffusion_decoder.pth', @@ -153,7 +154,7 @@ def classify_audio_clip(clip): :param clip: torch tensor containing audio waveform data (get it from load_audio) :return: True if the clip was classified as coming from Tortoise and false if it was classified as real. """ - download_models(['classifier']) + download_models(['classifier.pth']) classifier = AudioMiniEncoderWithClassifierHead(2, spec_dim=1, embedding_dim=512, depth=5, downsample_factor=4, resnet_blocks=2, attn_blocks=4, num_attn_heads=4, base_channels=32, dropout=0, kernel_size=5, distribute_zero_label=False)