unmarried the config.json to the bigvgan by downloading the right one

remotes/1710189933836426429/master
mrq 2023-03-07 13:37:45 +07:00
parent 26133c2031
commit fffea7fc03
3 changed files with 26 additions and 50 deletions

@ -43,8 +43,12 @@ MODELS = {
'vocoder.pth': 'https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/vocoder.pth',
'rlg_auto.pth': 'https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/rlg_auto.pth',
'rlg_diffuser.pth': 'https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/rlg_diffuser.pth',
'bigvgan_base_24khz_100band.pth': 'https://huggingface.co/ecker/tortoise-tts-models/resolve/main/models/bigvgan_base_24khz_100band.pth',
#'bigvgan_24khz_100band.pth': 'https://huggingface.co/ecker/tortoise-tts-models/resolve/main/models/bigvgan_24khz_100band.pth',
'bigvgan_24khz_100band.pth': 'https://huggingface.co/ecker/tortoise-tts-models/resolve/main/models/bigvgan_24khz_100band.pth',
'bigvgan_base_24khz_100band.json': 'https://huggingface.co/ecker/tortoise-tts-models/resolve/main/models/bigvgan_base_24khz_100band.json',
'bigvgan_24khz_100band.json': 'https://huggingface.co/ecker/tortoise-tts-models/resolve/main/models/bigvgan_24khz_100band.json',
}
def hash_file(path, algo="md5", buffer_size=0):
@ -361,7 +365,12 @@ class TextToSpeech:
self.vocoder_model_path = 'bigvgan_24khz_100band.pth'
if f'{vocoder_model}.pth' in MODELS:
self.vocoder_model_path = f'{vocoder_model}.pth'
self.vocoder = BigVGAN().cpu()
vocoder_config = 'bigvgan_24khz_100band.json'
if f'{vocoder_model}.json' in MODELS:
vocoder_config = f'{vocoder_model}.json'
vocoder_config = get_model_path(vocoder_config, self.models_dir)
self.vocoder = BigVGAN(config=vocoder_config).cpu()
#elif vocoder_model == "univnet":
else:
vocoder_key = 'model_g'

@ -129,14 +129,27 @@ class AttrDict(dict):
class BigVGAN(nn.Module):
# this is our main BigVGAN model. Applies anti-aliased periodic activation for resblocks.
def __init__(self):
def __init__(self, config=None, data=None):
super(BigVGAN, self).__init__()
"""
with open(os.path.join(os.path.dirname(__file__), 'config.json'), 'r') as f:
data = f.read()
"""
if config and data is None:
with open(config, 'r') as f:
data = f.read()
jsonConfig = json.loads(data)
elif data is not None:
if isinstance(data, str):
jsonConfig = json.loads(data)
else:
jsonConfig = data
else:
raise Exception("no config specified")
global h
jsonConfig = json.loads(data)
h = AttrDict(jsonConfig)
self.mel_channel = h.num_mels

@ -1,46 +0,0 @@
{
"resblock": "1",
"num_gpus": 0,
"batch_size": 32,
"learning_rate": 0.0001,
"adam_b1": 0.8,
"adam_b2": 0.99,
"lr_decay": 0.999,
"seed": 1234,
"upsample_rates": [8,8,2,2],
"upsample_kernel_sizes": [16,16,4,4],
"upsample_initial_channel": 512,
"resblock_kernel_sizes": [3,7,11],
"resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
"activation": "snakebeta",
"snake_logscale": true,
"discriminator": "mrd",
"resolutions": [[1024, 120, 600], [2048, 240, 1200], [512, 50, 240]],
"mpd_reshapes": [2, 3, 5, 7, 11],
"use_spectral_norm": false,
"discriminator_channel_mult": 1,
"segment_size": 8192,
"num_mels": 100,
"num_freq": 1025,
"n_fft": 1024,
"hop_size": 256,
"win_size": 1024,
"sampling_rate": 24000,
"fmin": 0,
"fmax": 12000,
"fmax_for_loss": null,
"num_workers": 4,
"dist_config": {
"dist_backend": "nccl",
"dist_url": "tcp://localhost:54321",
"world_size": 1
}
}