From fffea7fc038cc20945aa2208faf91c5434719b69 Mon Sep 17 00:00:00 2001
From: mrq <barry.quiggles@protonmail.com>
Date: Tue, 7 Mar 2023 13:37:45 +0000
Subject: [PATCH] decoupled the config.json from the bigvgan by downloading the
 right one

---
 tortoise/api.py             | 13 +++++++++--
 tortoise/models/bigvgan.py  | 17 ++++++++++++--
 tortoise/models/config.json | 46 -------------------------------------
 3 files changed, 26 insertions(+), 50 deletions(-)
 delete mode 100644 tortoise/models/config.json

diff --git a/tortoise/api.py b/tortoise/api.py
index 9765a87..b663fa6 100755
--- a/tortoise/api.py
+++ b/tortoise/api.py
@@ -43,8 +43,12 @@ MODELS = {
     'vocoder.pth': 'https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/vocoder.pth',
     'rlg_auto.pth': 'https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/rlg_auto.pth',
     'rlg_diffuser.pth': 'https://huggingface.co/jbetker/tortoise-tts-v2/resolve/main/.models/rlg_diffuser.pth',
+    
     'bigvgan_base_24khz_100band.pth': 'https://huggingface.co/ecker/tortoise-tts-models/resolve/main/models/bigvgan_base_24khz_100band.pth',
-    #'bigvgan_24khz_100band.pth': 'https://huggingface.co/ecker/tortoise-tts-models/resolve/main/models/bigvgan_24khz_100band.pth',
+    'bigvgan_24khz_100band.pth': 'https://huggingface.co/ecker/tortoise-tts-models/resolve/main/models/bigvgan_24khz_100band.pth',
+
+    'bigvgan_base_24khz_100band.json': 'https://huggingface.co/ecker/tortoise-tts-models/resolve/main/models/bigvgan_base_24khz_100band.json',
+    'bigvgan_24khz_100band.json': 'https://huggingface.co/ecker/tortoise-tts-models/resolve/main/models/bigvgan_24khz_100band.json',
 }
 
 def hash_file(path, algo="md5", buffer_size=0):
@@ -361,7 +365,12 @@ class TextToSpeech:
             self.vocoder_model_path = 'bigvgan_24khz_100band.pth'
             if f'{vocoder_model}.pth' in MODELS:
                 self.vocoder_model_path = f'{vocoder_model}.pth'
-            self.vocoder = BigVGAN().cpu()
+            vocoder_config = 'bigvgan_24khz_100band.json'
+            if f'{vocoder_model}.json' in MODELS:
+                vocoder_config = f'{vocoder_model}.json'
+            vocoder_config = get_model_path(vocoder_config, self.models_dir)
+
+            self.vocoder = BigVGAN(config=vocoder_config).cpu()
         #elif vocoder_model == "univnet":
         else:
             vocoder_key = 'model_g'
diff --git a/tortoise/models/bigvgan.py b/tortoise/models/bigvgan.py
index a7f90bd..29d7df6 100644
--- a/tortoise/models/bigvgan.py
+++ b/tortoise/models/bigvgan.py
@@ -129,14 +129,27 @@ class AttrDict(dict):
 
 class BigVGAN(nn.Module):
     # this is our main BigVGAN model. Applies anti-aliased periodic activation for resblocks.
-    def __init__(self):
+    def __init__(self, config=None, data=None):
         super(BigVGAN, self).__init__()
 
+        """
         with open(os.path.join(os.path.dirname(__file__), 'config.json'), 'r') as f:
             data = f.read()
+        """
+        if config and data is None:
+            with open(config, 'r') as f:
+                data = f.read()
+            jsonConfig = json.loads(data)
+        elif data is not None:
+            if isinstance(data, str):
+                jsonConfig = json.loads(data)
+            else:
+                jsonConfig = data
+        else:
+            raise Exception("no config specified")
+
 
         global h
-        jsonConfig = json.loads(data)
         h = AttrDict(jsonConfig)
 
         self.mel_channel = h.num_mels
diff --git a/tortoise/models/config.json b/tortoise/models/config.json
deleted file mode 100644
index d3a8d3a..0000000
--- a/tortoise/models/config.json
+++ /dev/null
@@ -1,46 +0,0 @@
-{
-    "resblock": "1",
-    "num_gpus": 0,
-    "batch_size": 32,
-    "learning_rate": 0.0001,
-    "adam_b1": 0.8,
-    "adam_b2": 0.99,
-    "lr_decay": 0.999,
-    "seed": 1234,
-
-    "upsample_rates": [8,8,2,2],
-    "upsample_kernel_sizes": [16,16,4,4],
-    "upsample_initial_channel": 512,
-    "resblock_kernel_sizes": [3,7,11],
-    "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
-
-    "activation": "snakebeta",
-    "snake_logscale": true,
-
-    "discriminator": "mrd",
-    "resolutions": [[1024, 120, 600], [2048, 240, 1200], [512, 50, 240]],
-    "mpd_reshapes": [2, 3, 5, 7, 11],
-    "use_spectral_norm": false,
-    "discriminator_channel_mult": 1,
-
-    "segment_size": 8192,
-    "num_mels": 100,
-    "num_freq": 1025,
-    "n_fft": 1024,
-    "hop_size": 256,
-    "win_size": 1024,
-
-    "sampling_rate": 24000,
-
-    "fmin": 0,
-    "fmax": 12000,
-    "fmax_for_loss": null,
-
-    "num_workers": 4,
-
-    "dist_config": {
-        "dist_backend": "nccl",
-        "dist_url": "tcp://localhost:54321",
-        "world_size": 1
-    }
-}