diff --git a/tortoise/api.py b/tortoise/api.py index 9765a87..b663fa6 100755 --- a/tortoise/api.py +++ b/tortoise/api.py @@ -43,8 +43,12 @@ MODELS = { 'vocoder.pth': 'https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/vocoder.pth', 'rlg_auto.pth': 'https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/rlg_auto.pth', 'rlg_diffuser.pth': 'https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/rlg_diffuser.pth', + 'bigvgan_base_24khz_100band.pth': 'https://huggingface.co/ecker/tortoise-tts-models/resolve/main/models/bigvgan_base_24khz_100band.pth', - #'bigvgan_24khz_100band.pth': 'https://huggingface.co/ecker/tortoise-tts-models/resolve/main/models/bigvgan_24khz_100band.pth', + 'bigvgan_24khz_100band.pth': 'https://huggingface.co/ecker/tortoise-tts-models/resolve/main/models/bigvgan_24khz_100band.pth', + + 'bigvgan_base_24khz_100band.json': 'https://huggingface.co/ecker/tortoise-tts-models/resolve/main/models/bigvgan_base_24khz_100band.json', + 'bigvgan_24khz_100band.json': 'https://huggingface.co/ecker/tortoise-tts-models/resolve/main/models/bigvgan_24khz_100band.json', } def hash_file(path, algo="md5", buffer_size=0): @@ -361,7 +365,12 @@ class TextToSpeech: self.vocoder_model_path = 'bigvgan_24khz_100band.pth' if f'{vocoder_model}.pth' in MODELS: self.vocoder_model_path = f'{vocoder_model}.pth' - self.vocoder = BigVGAN().cpu() + vocoder_config = 'bigvgan_24khz_100band.json' + if f'{vocoder_model}.json' in MODELS: + vocoder_config = f'{vocoder_model}.json' + vocoder_config = get_model_path(vocoder_config, self.models_dir) + + self.vocoder = BigVGAN(config=vocoder_config).cpu() #elif vocoder_model == "univnet": else: vocoder_key = 'model_g' diff --git a/tortoise/models/bigvgan.py b/tortoise/models/bigvgan.py index a7f90bd..29d7df6 100644 --- a/tortoise/models/bigvgan.py +++ b/tortoise/models/bigvgan.py @@ -129,14 +129,27 @@ class AttrDict(dict): class BigVGAN(nn.Module): # this is our main BigVGAN model. Applies anti-aliased periodic activation for resblocks. - def __init__(self): + def __init__(self, config=None, data=None): super(BigVGAN, self).__init__() + """ with open(os.path.join(os.path.dirname(__file__), 'config.json'), 'r') as f: data = f.read() + """ + if config and data is None: + with open(config, 'r') as f: + data = f.read() + jsonConfig = json.loads(data) + elif data is not None: + if isinstance(data, str): + jsonConfig = json.loads(data) + else: + jsonConfig = data + else: + raise Exception("no config specified") + global h - jsonConfig = json.loads(data) h = AttrDict(jsonConfig) self.mel_channel = h.num_mels diff --git a/tortoise/models/config.json b/tortoise/models/config.json deleted file mode 100644 index d3a8d3a..0000000 --- a/tortoise/models/config.json +++ /dev/null @@ -1,46 +0,0 @@ -{ - "resblock": "1", - "num_gpus": 0, - "batch_size": 32, - "learning_rate": 0.0001, - "adam_b1": 0.8, - "adam_b2": 0.99, - "lr_decay": 0.999, - "seed": 1234, - - "upsample_rates": [8,8,2,2], - "upsample_kernel_sizes": [16,16,4,4], - "upsample_initial_channel": 512, - "resblock_kernel_sizes": [3,7,11], - "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]], - - "activation": "snakebeta", - "snake_logscale": true, - - "discriminator": "mrd", - "resolutions": [[1024, 120, 600], [2048, 240, 1200], [512, 50, 240]], - "mpd_reshapes": [2, 3, 5, 7, 11], - "use_spectral_norm": false, - "discriminator_channel_mult": 1, - - "segment_size": 8192, - "num_mels": 100, - "num_freq": 1025, - "n_fft": 1024, - "hop_size": 256, - "win_size": 1024, - - "sampling_rate": 24000, - - "fmin": 0, - "fmax": 12000, - "fmax_for_loss": null, - - "num_workers": 4, - - "dist_config": { - "dist_backend": "nccl", - "dist_url": "tcp://localhost:54321", - "world_size": 1 - } -}