forked from mrq/ai-voice-cloning
Compare commits
10 Commits
b989123bd4 ... 5487c28683
5487c28683
9fb4aa7917
740b5587df
68f4858ce9
e859a7c01d
e205322c8d
59773a7637
c956d81baf
534a761e49
5a41db978e

.gitignore (vendored): 7 changed lines
@@ -1,7 +1,8 @@
 # ignores user files
-/tortoise-venv/
-/tortoise/voices/
-/models/
 /venv/
+/voices/*
+/models/*
+/training/*
+/config/*
 
 # Byte-compiled / optimized / DLL files

dlas (submodule): 2 changed lines
@@ -1 +1 @@
-Subproject commit 0f04206aa20b1ab632c0cbf7bb6a43d5c1fd9eb0
+Subproject commit 71cc43e65cd47c6704d20c99006a3e78feb2400d

@@ -7,12 +7,12 @@ python -m pip install --upgrade pip
 python -m pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu117
 python -m pip install -r .\dlas\requirements.txt
 python -m pip install -r .\tortoise-tts\requirements.txt
-python -m pip install -r .\requirements.txt
+python -m pip install -e .\tortoise-tts\
+python -m pip install -r .\requirements.txt
 
 xcopy .\dlas\bitsandbytes_windows\* .\venv\Lib\site-packages\bitsandbytes\. /Y
 xcopy .\dlas\bitsandbytes_windows\cuda_setup\* .\venv\Lib\site-packages\bitsandbytes\cuda_setup\. /Y
 xcopy .\dlas\bitsandbytes_windows\nn\* .\venv\Lib\site-packages\bitsandbytes\nn\. /Y
 
-deactivate
-pause
+pause
+deactivate

@@ -1,14 +1,17 @@
 #!/bin/bash
+# get local dependencies
 git submodule init
 git submodule update --remote
 
+# setup venv
 python3 -m venv venv
 source ./venv/bin/activate
-python3 -m pip install --upgrade pip
-pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu116
-python3 -m pip install -r ./dlas/requirements.txt
-python3 -m pip install -r ./tortoise-tts/requirements.txt
-python3 -m pip install -r ./requirements.txt
-python3 -m pip install -e ./tortoise-tts/
+python3 -m pip install --upgrade pip # just to be safe
+# CUDA
+pip3 install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu117
+# install requirements
+python3 -m pip install -r ./dlas/requirements.txt # install DLAS requirements
+python3 -m pip install -r ./tortoise-tts/requirements.txt # install TorToiSe requirements
+python3 -m pip install -e ./tortoise-tts/ # install TorToiSe
+python3 -m pip install -r ./requirements.txt # install local requirements
 
 deactivate

@@ -4,10 +4,11 @@ git submodule update --remote
 python -m venv venv
 call .\venv\Scripts\activate.bat
 python -m pip install --upgrade pip
-python -m pip install torch torchvision torchaudio torch-directml==0.1.13.1.dev230119
+python -m pip install torch torchvision torchaudio torch-directml
 python -m pip install -r .\dlas\requirements.txt
 python -m pip install -r .\tortoise-tts\requirements.txt
-python -m pip install -r .\requirements.txt
-deactivate
-pause
+python -m pip install -e .\tortoise-tts\
+python -m pip install -r .\requirements.txt
 
+pause
+deactivate

setup-rocm-bnb.sh (new executable file): 8 lines
@@ -0,0 +1,8 @@
+#!/bin/bash
+source ./venv/bin/activate
+git clone https://git.ecker.tech/mrq/bitsandbytes-rocm
+cd bitsandbytes-rocm
+make hip
+CUDA_VERSION=gfx1030 python setup.py install # assumes you're using a 6XXX series card
+python3 -m bitsandbytes # to validate it works
+cd ..

@@ -1,14 +1,19 @@
 #!/bin/bash
+# get local dependencies
 git submodule init
 git submodule update --remote
 
+# setup venv
 python3 -m venv venv
 source ./venv/bin/activate
-python3 -m pip install --upgrade pip
-pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/rocm5.1.1 # 5.2 does not work for me desu
-python3 -m pip install -r ./dlas/requirements.txt
-python3 -m pip install -r ./tortoise-tts/requirements.txt
-python3 -m pip install -r ./requirements.txt
-python3 -m pip install -e ./tortoise-tts/
+python3 -m pip install --upgrade pip # just to be safe
+# ROCM
+pip3 install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/rocm5.1.1 # 5.2 does not work for me desu
+# install requirements
+python3 -m pip install -r ./dlas/requirements.txt # install DLAS requirements
+python3 -m pip install -r ./tortoise-tts/requirements.txt # install TorToiSe requirements
+python3 -m pip install -e ./tortoise-tts/ # install TorToiSe
+python3 -m pip install -r ./requirements.txt # install local requirements
+# swap to ROCm version of BitsAndBytes
+pip3 uninstall -y bitsandbytes
+./setup-rocm-bnb.sh
 deactivate

@@ -18,9 +18,12 @@ if __name__ == "__main__":
 	parser = argparse.ArgumentParser()
 	parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_vit_latent.yml', nargs='+') # ugh
 	parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
+	parser.add_argument('--local_rank', type=int, help='Rank Number')
 	args = parser.parse_args()
+	args.opt = " ".join(args.opt) # absolutely disgusting
+
+	os.environ['LOCAL_RANK'] = str(args.local_rank)
 
 	with open(args.opt, 'r') as file:
 		opt_config = yaml.safe_load(file)

@@ -71,7 +74,7 @@ def train(yaml, launcher='none'):
 		print('Disabled distributed training.')
 	else:
 		opt['dist'] = True
-		init_dist('nccl')
+		tr.init_dist('nccl')
 		trainer.world_size = torch.distributed.get_world_size()
 		trainer.rank = torch.distributed.get_rank()
 		torch.cuda.set_device(torch.distributed.get_rank())

src/utils.py: 181 changed lines
@@ -97,7 +97,11 @@ def generate(
 		voice_samples, conditioning_latents = None, tts.get_random_conditioning_latents()
 	else:
 		progress(0, desc="Loading voice...")
-		voice_samples, conditioning_latents = load_voice(voice)
+		# nasty check for users that, for whatever reason, updated the web UI but not mrq/tortoise-tts
+		if hasattr(tts, 'autoregressive_model_hash'):
+			voice_samples, conditioning_latents = load_voice(voice, model_hash=tts.autoregressive_model_hash)
+		else:
+			voice_samples, conditioning_latents = load_voice(voice)
 
 	if voice_samples and len(voice_samples) > 0:
 		sample_voice = torch.cat(voice_samples, dim=-1).squeeze().cpu()

@@ -107,7 +111,10 @@ def generate(
 			conditioning_latents = (conditioning_latents[0], conditioning_latents[1], conditioning_latents[2], None)
 
 		if voice != "microphone":
-			torch.save(conditioning_latents, f'{get_voice_dir()}/{voice}/cond_latents.pth')
+			if hasattr(tts, 'autoregressive_model_hash'):
+				torch.save(conditioning_latents, f'{get_voice_dir()}/{voice}/cond_latents_{tts.autoregressive_model_hash[:8]}.pth')
+			else:
+				torch.save(conditioning_latents, f'{get_voice_dir()}/{voice}/cond_latents.pth')
 		voice_samples = None
 	else:
 		if conditioning_latents is not None:

@@ -371,8 +378,16 @@
 
 	if voice and voice != "random" and conditioning_latents is not None:
-		with open(f'{get_voice_dir()}/{voice}/cond_latents.pth', 'rb') as f:
-			info['latents'] = base64.b64encode(f.read()).decode("ascii")
+		latents_path = f'{get_voice_dir()}/{voice}/cond_latents.pth'
+
+		if hasattr(tts, 'autoregressive_model_hash'):
+			latents_path = f'{get_voice_dir()}/{voice}/cond_latents_{tts.autoregressive_model_hash[:8]}.pth'
+
+		try:
+			with open(latents_path, 'rb') as f:
+				info['latents'] = base64.b64encode(f.read()).decode("ascii")
+		except Exception as e:
+			pass
 
 	if args.embed_output_metadata:
 		for name in progress.tqdm(audio_cache, desc="Embedding metadata..."):
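
Note: the blob this hunk embeds under info['latents'] is just the base64 of the raw cond_latents .pth bytes. A minimal sketch of reading it back (assuming `info` is the decoded metadata dict, a name not used by this hunk):

    import base64, io
    import torch
    # undo the b64encode(f.read()) above and hand the bytes to torch.load
    latents = torch.load(io.BytesIO(base64.b64decode(info['latents'])))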

@@ -413,6 +428,32 @@ def cancel_generate():
 	import tortoise.api
 	tortoise.api.STOP_SIGNAL = True
 
+def hash_file(path, algo="md5", buffer_size=0):
+	import hashlib
+
+	hash = None
+	if algo == "md5":
+		hash = hashlib.md5()
+	elif algo == "sha1":
+		hash = hashlib.sha1()
+	else:
+		raise Exception(f'Unknown hash algorithm specified: {algo}')
+
+	if not os.path.exists(path):
+		raise Exception(f'Path not found: {path}')
+
+	with open(path, 'rb') as f:
+		if buffer_size > 0:
+			while True:
+				data = f.read(buffer_size)
+				if not data:
+					break
+				hash.update(data)
+		else:
+			hash.update(f.read())
+
+	return "{0}".format(hash.hexdigest())
+
 def compute_latents(voice, voice_latents_chunks, progress=gr.Progress(track_tqdm=True)):
 	global tts
 	global args
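
A minimal usage sketch for the relocated hash_file helper; the checkpoint path is hypothetical, and buffer_size > 0 switches it to chunked reads so a large model is never pulled into memory at once:

    # MD5 a (hypothetical) finetune checkpoint in 64 KiB chunks
    digest = hash_file('./models/finetunes/example.pth', algo='md5', buffer_size=65536)
    print(digest[:8])  # an 8-character prefix like the one the cond_latents_* filenames use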

@@ -435,15 +476,16 @@ def compute_latents(voice, voice_latents_chunks, progress=gr.Progress(track_tqdm
 	if len(conditioning_latents) == 4:
 		conditioning_latents = (conditioning_latents[0], conditioning_latents[1], conditioning_latents[2], None)
 
-	torch.save(conditioning_latents, f'{get_voice_dir()}/{voice}/cond_latents.pth')
+	if hasattr(tts, 'autoregressive_model_hash'):
+		torch.save(conditioning_latents, f'{get_voice_dir()}/{voice}/cond_latents_{tts.autoregressive_model_hash[:8]}.pth')
+	else:
+		torch.save(conditioning_latents, f'{get_voice_dir()}/{voice}/cond_latents.pth')
 
 	return voice
 
 # superfluous, but it cleans up some things
 class TrainingState():
-	def __init__(self, config_path, keep_x_past_datasets=0):
-		self.cmd = ['train.bat', config_path] if os.name == "nt" else ['bash', './train.sh', config_path]
+	def __init__(self, config_path, keep_x_past_datasets=0, start=True, gpus=1):
 		# parse config to get its iteration
 		with open(config_path, 'r') as file:
 			self.config = yaml.safe_load(file)

@@ -487,17 +529,22 @@ class TrainingState():
 		self.eta = "?"
 		self.eta_hhmmss = "?"
 
+		self.last_info_check_at = 0
 		self.losses = []
 
 		self.load_losses()
-		self.cleanup_old(keep=keep_x_past_datasets)
-		self.spawn_process()
+		if keep_x_past_datasets > 0:
+			self.cleanup_old(keep=keep_x_past_datasets)
+		if start:
+			self.spawn_process(config_path=config_path, gpus=gpus)
+
+	def spawn_process(self, config_path, gpus=1):
+		self.cmd = ['train.bat', config_path] if os.name == "nt" else ['bash', './train.sh', str(int(gpus)), config_path]
 
-	def spawn_process(self):
 		print("Spawning process: ", " ".join(self.cmd))
 		self.process = subprocess.Popen(self.cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, universal_newlines=True)
 
-	def load_losses(self):
+	def load_losses(self, update=False):
 		if not os.path.isdir(f'{self.dataset_dir}/tb_logger/'):
 			return
 		try:

@@ -506,18 +553,26 @@ class TrainingState():
 		except Exception as e:
 			use_tensorboard = False
 
+		keys = ['loss_text_ce', 'loss_mel_ce', 'loss_gpt_total']
+		infos = {}
+		highest_step = self.last_info_check_at
+
 		if use_tensorboard:
 			logs = sorted([f'{self.dataset_dir}/tb_logger/{d}' for d in os.listdir(f'{self.dataset_dir}/tb_logger/') if d[:6] == "events" ])
-			infos = {}
+			if update:
+				logs = [logs[-1]]
 
 			for log in logs:
+				try:
 					ea = event_accumulator.EventAccumulator(log, size_guidance={event_accumulator.SCALARS: 0})
 					ea.Reload()
 
-				keys = ['loss_text_ce', 'loss_mel_ce', 'loss_gpt_total']
 					for key in keys:
 						scalar = ea.Scalars(key)
 						for s in scalar:
+							if update and s.step <= self.last_info_check_at:
+								continue
+							highest_step = max( highest_step, s.step )
 							self.losses.append( { "step": s.step, "value": s.value, "type": key } )
+				except Exception as e:
+					print("Failed to parse event log:", log)

@@ -525,7 +580,9 @@ class TrainingState():
 
 		else:
 			logs = sorted([f'{self.dataset_dir}/{d}' for d in os.listdir(self.dataset_dir) if d[-4:] == ".log" ])
-			infos = {}
+			if update:
+				logs = [logs[-1]]
 
 			for log in logs:
 				with open(log, 'r', encoding="utf-8") as f:
 					lines = f.readlines()

@@ -546,9 +603,13 @@ class TrainingState():
 
 			for k in infos:
 				if 'loss_gpt_total' in infos[k]:
-					self.losses.append({ "step": int(k), "value": infos[k]['loss_text_ce'], "type": "text_ce" })
-					self.losses.append({ "step": int(k), "value": infos[k]['loss_mel_ce'], "type": "mel_ce" })
-					self.losses.append({ "step": int(k), "value": infos[k]['loss_gpt_total'], "type": "gpt_total" })
+					for key in keys:
+						if update and int(k) <= self.last_info_check_at:
+							continue
+						highest_step = max( highest_step, s.step )
+						self.losses.append({ "step": int(k), "value": infos[k][key], "type": key })
+
+		self.last_info_check_at = highest_step
 
 	def cleanup_old(self, keep=2):
 		if keep <= 0:
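
Taken together, these load_losses changes make log parsing incremental: with update=True only the newest log is read, and any step at or below last_info_check_at is skipped. A sketch of the intended calling pattern (the `state` instance is illustrative):

    # first call: full parse of every tb_logger event file / .log in the dataset dir
    state.load_losses()
    # later polls: only the newest log, only steps past last_info_check_at
    state.load_losses(update=True)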

@@ -581,7 +642,8 @@ class TrainingState():
 			if line.find('Start training from epoch') >= 0:
 				self.epoch_time_start = time.time()
 				self.training_started = True # could just leverage the above variable, but this is python, and there's no point in these aggressive microoptimizations
+
 				should_return = True
 
 				match = re.findall(r'epoch: ([\d,]+)', line)
 				if match and len(match) > 0:
 					self.epoch = int(match[0].replace(",", ""))

@@ -662,12 +724,15 @@ class TrainingState():
 
 				if 'loss_gpt_total' in self.info:
 					self.status = f"Total loss at epoch {self.epoch}: {self.info['loss_gpt_total']}"
 
-					self.losses.append({ "step": self.it, "value": self.info['loss_text_ce'], "type": "text_ce" })
-					self.losses.append({ "step": self.it, "value": self.info['loss_mel_ce'], "type": "mel_ce" })
-					self.losses.append({ "step": self.it, "value": self.info['loss_gpt_total'], "type": "gpt_total" })
+					"""
+					self.losses.append({ "step": self.it, "value": self.info['loss_text_ce'], "type": "loss_text_ce" })
+					self.losses.append({ "step": self.it, "value": self.info['loss_mel_ce'], "type": "loss_mel_ce" })
+					self.losses.append({ "step": self.it, "value": self.info['loss_gpt_total'], "type": "loss_gpt_total" })
+					"""
 
 				should_return = True
 
+				self.load_losses(update=True)
+
 			elif line.find('Saving models and training states') >= 0:
 				self.checkpoint = self.checkpoint + 1

@@ -688,7 +753,7 @@ class TrainingState():
 		if should_return:
 			return "".join(self.buffer)
 
-def run_training(config_path, verbose=False, buffer_size=8, keep_x_past_datasets=0, progress=gr.Progress(track_tqdm=True)):
+def run_training(config_path, verbose=False, gpus=1, buffer_size=8, keep_x_past_datasets=0, progress=gr.Progress(track_tqdm=True)):
 	global training_state
 	if training_state and training_state.process:
 		return "Training already in progress"

@@ -700,7 +765,7 @@ def run_training(config_path, verbose=False, buffer_size=8, keep_x_past_datasets
 	unload_whisper()
 	unload_voicefixer()
 
-	training_state = TrainingState(config_path=config_path, keep_x_past_datasets=keep_x_past_datasets)
+	training_state = TrainingState(config_path=config_path, keep_x_past_datasets=keep_x_past_datasets, gpus=gpus)
 
 	for line in iter(training_state.process.stdout.readline, ""):

@@ -723,11 +788,21 @@ def get_training_losses():
 		return
 	return pd.DataFrame(training_state.losses)
 
-def update_training_dataplot():
+def update_training_dataplot(config_path=None):
 	global training_state
-	if not training_state or not training_state.losses:
-		return
-	return gr.LinePlot.update(value=pd.DataFrame(training_state.losses))
+	update = None
+
+	if not training_state:
+		if config_path:
+			training_state = TrainingState(config_path=config_path, start=False)
+			if training_state.losses:
+				update = gr.LinePlot.update(value=pd.DataFrame(training_state.losses))
+			del training_state
+			training_state = None
+	elif training_state.losses:
+		update = gr.LinePlot.update(value=pd.DataFrame(training_state.losses))
+
+	return update
 
 def reconnect_training(verbose=False, buffer_size=8, progress=gr.Progress(track_tqdm=True)):
 	global training_state
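
With config_path accepted, the View Losses button can plot a run that is not currently training: a throwaway TrainingState(start=False) parses the logs and is then discarded. A sketch of driving the same path directly (the config path is illustrative):

    import pandas as pd
    # parse losses for an offline run without spawning a trainer
    state = TrainingState(config_path='./training/example/train.yaml', start=False)
    df = pd.DataFrame(state.losses)  # the same frame gr.LinePlot.update() receives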

@@ -823,8 +898,11 @@ def prepare_dataset( files, outdir, language=None, progress=None ):
 
 			torchaudio.save(f"{outdir}/{sliced_name}", sliced_waveform, sampling_rate)
 
-			transcription.append(f"{sliced_name}|{segment['text'].strip()}")
 			idx = idx + 1
+			line = f"{sliced_name}|{segment['text'].strip()}"
+			transcription.append(line)
+			with open(f'{outdir}/train.txt', 'a', encoding="utf-8") as f:
+				f.write(f'{line}\n')
 
 	with open(f'{outdir}/whisper.json', 'w', encoding="utf-8") as f:
 		f.write(json.dumps(results, indent='\t'))

@@ -1035,32 +1113,6 @@ def get_voice_list(dir=get_voice_dir(), append_defaults=False):
 		res = res + ["random", "microphone"]
 	return res
 
-def hash_file(path, algo="md5", buffer_size=0):
-	import hashlib
-
-	hash = None
-	if algo == "md5":
-		hash = hashlib.md5()
-	elif algo == "sha1":
-		hash = hashlib.sha1()
-	else:
-		raise Exception(f'Unknown hash algorithm specified: {algo}')
-
-	if not os.path.exists(path):
-		raise Exception(f'Path not found: {path}')
-
-	with open(path, 'rb') as f:
-		if buffer_size > 0:
-			while True:
-				data = f.read(buffer_size)
-				if not data:
-					break
-				hash.update(data)
-		else:
-			hash.update(f.read())
-
-	return "{0}".format(hash.hexdigest())
-
 def get_autoregressive_models(dir="./models/finetunes/", prefixed=False):
 	os.makedirs(dir, exist_ok=True)
 	base = [get_model_path('autoregressive.pth')]

@@ -1190,6 +1242,7 @@ def setup_args():
 		'defer-tts-load': False,
 		'device-override': None,
 		'prune-nonfinal-outputs': True,
+		'use-bigvgan-vocoder': True,
 		'concurrency-count': 2,
 		'output-sample-rate': 44100,
 		'output-volume': 1,

@@ -1225,6 +1278,7 @@ def setup_args():
 	parser.add_argument("--force-cpu-for-conditioning-latents", default=default_arguments['force-cpu-for-conditioning-latents'], action='store_true', help="Forces computing conditional latents to be done on the CPU (if you constantly OOM on low chunk counts)")
 	parser.add_argument("--defer-tts-load", default=default_arguments['defer-tts-load'], action='store_true', help="Defers loading TTS model")
 	parser.add_argument("--prune-nonfinal-outputs", default=default_arguments['prune-nonfinal-outputs'], action='store_true', help="Deletes non-final output files on completing a generation")
+	parser.add_argument("--use-bigvgan-vocoder", default=default_arguments['use-bigvgan-vocoder'], action='store_true', help="Uses BigVGAN in place of the default vocoder")
 	parser.add_argument("--device-override", default=default_arguments['device-override'], help="A device string to override pass through Torch")
 	parser.add_argument("--sample-batch-size", default=default_arguments['sample-batch-size'], type=int, help="Sets how many batches to use during the autoregressive samples pass")
 	parser.add_argument("--concurrency-count", type=int, default=default_arguments['concurrency-count'], help="How many Gradio events to process at once")

@@ -1246,12 +1300,13 @@ def setup_args():
 	if not args.device_override:
 		set_device_name(args.device_override)
 
 	args.listen_host = None
 	args.listen_port = None
 	args.listen_path = None
 	if args.listen:
 		try:
-			match = re.findall(r"^(?:(.+?):(\d+))?(\/.+?)?$", args.listen)[0]
+			match = re.findall(r"^(?:(.+?):(\d+))?(\/.*?)?$", args.listen)[0]
 
 			args.listen_host = match[0] if match[0] != "" else "127.0.0.1"
 			args.listen_port = match[1] if match[1] != "" else None

@@ -1264,7 +1319,7 @@ def setup_args():
 
 	return args
 
-def update_args( listen, share, check_for_updates, models_from_local_only, low_vram, embed_output_metadata, latents_lean_and_mean, voice_fixer, voice_fixer_use_cuda, force_cpu_for_conditioning_latents, defer_tts_load, prune_nonfinal_outputs, device_override, sample_batch_size, concurrency_count, output_sample_rate, output_volume, autoregressive_model, whisper_model, whisper_cpp, training_default_halfp, training_default_bnb ):
+def update_args( listen, share, check_for_updates, models_from_local_only, low_vram, embed_output_metadata, latents_lean_and_mean, voice_fixer, voice_fixer_use_cuda, force_cpu_for_conditioning_latents, defer_tts_load, prune_nonfinal_outputs, use_bigvgan_vocoder, device_override, sample_batch_size, concurrency_count, output_sample_rate, output_volume, autoregressive_model, whisper_model, whisper_cpp, training_default_halfp, training_default_bnb ):
 	global args
 
 	args.listen = listen

@@ -1275,6 +1330,7 @@ def update_args( listen, share, check_for_updates, models_from_local_only, low_v
 	args.force_cpu_for_conditioning_latents = force_cpu_for_conditioning_latents
 	args.defer_tts_load = defer_tts_load
 	args.prune_nonfinal_outputs = prune_nonfinal_outputs
+	args.use_bigvgan_vocoder = use_bigvgan_vocoder
 	args.device_override = device_override
 	args.sample_batch_size = sample_batch_size
 	args.embed_output_metadata = embed_output_metadata

@@ -1297,7 +1353,7 @@ def update_args( listen, share, check_for_updates, models_from_local_only, low_v
 def save_args_settings():
 	global args
 	settings = {
-		'listen': None if args.listen else args.listen,
+		'listen': None if not args.listen else args.listen,
 		'share': args.share,
 		'low-vram':args.low_vram,
 		'check-for-updates':args.check_for_updates,

@@ -1305,6 +1361,7 @@ def save_args_settings():
 		'force-cpu-for-conditioning-latents': args.force_cpu_for_conditioning_latents,
 		'defer-tts-load': args.defer_tts_load,
 		'prune-nonfinal-outputs': args.prune_nonfinal_outputs,
+		'use-bigvgan-vocoder': args.use_bigvgan_vocoder,
 		'device-override': args.device_override,
 		'sample-batch-size': args.sample_batch_size,
 		'embed-output-metadata': args.embed_output_metadata,

@@ -1419,7 +1476,7 @@ def load_tts( restart=False, model=None ):
 
 	tts_loading = True
 	try:
-		tts = TextToSpeech(minor_optimizations=not args.low_vram, autoregressive_model_path=args.autoregressive_model)
+		tts = TextToSpeech(minor_optimizations=not args.low_vram, autoregressive_model_path=args.autoregressive_model, use_bigvgan=args.use_bigvgan_vocoder)
 	except Exception as e:
 		tts = TextToSpeech(minor_optimizations=not args.low_vram)
 		load_autoregressive_model(args.autoregressive_model)

@@ -1533,7 +1590,7 @@ def load_whisper_model(name=None, progress=None, language=b'en'):
 		notify_progress(f"Loading Whisper model: {args.whisper_model}", progress)
 
 	if args.whisper_cpp:
 		from whispercpp import Whisper
-		whisper_model = Whisper(name, models_dir='./models/', language=language)
+		whisper_model = Whisper(name, models_dir='./models/', language=language.encode('ascii'))
 	else:
 		import whisper
 		whisper_model = whisper.load_model(args.whisper_model)

src/webui.py: 39 changed lines
@@ -527,17 +527,9 @@ def setup_gradio():
 			with gr.Row():
 				with gr.Column():
 					training_configs = gr.Dropdown(label="Training Configuration", choices=get_training_list())
-					refresh_configs = gr.Button(value="Refresh Configurations")
-					with gr.Row():
-						start_training_button = gr.Button(value="Train")
-						stop_training_button = gr.Button(value="Stop")
-						reconnect_training_button = gr.Button(value="Reconnect")
-				with gr.Column():
-					training_output = gr.TextArea(label="Console Output", interactive=False, max_lines=8)
-					verbose_training = gr.Checkbox(label="Verbose Console Output", value=True)
-					training_buffer_size = gr.Slider(label="Console Buffer Size", minimum=4, maximum=32, value=8)
-					training_keep_x_past_datasets = gr.Slider(label="Keep X Previous States", minimum=0, maximum=8, value=0, step=1)
+					refresh_configs = gr.Button(value="Refresh Configurations")
 
 					training_loss_graph = gr.LinePlot(label="Training Metrics",
 						x="step",
 						y="value",

@@ -545,8 +537,20 @@ def setup_gradio():
 						color="type",
 						tooltip=['step', 'value', 'type'],
 						width=600,
-						height=350
+						height=350,
 					)
+					view_losses = gr.Button(value="View Losses")
+
+				with gr.Column():
+					training_output = gr.TextArea(label="Console Output", interactive=False, max_lines=8)
+					verbose_training = gr.Checkbox(label="Verbose Console Output", value=True)
+					training_buffer_size = gr.Slider(label="Console Buffer Size", minimum=4, maximum=32, value=8)
+					training_keep_x_past_datasets = gr.Slider(label="Keep X Previous States", minimum=0, maximum=8, value=0, step=1)
+					training_gpu_count = gr.Number(label="GPUs", value=1)
+					with gr.Row():
+						start_training_button = gr.Button(value="Train")
+						stop_training_button = gr.Button(value="Stop")
+						reconnect_training_button = gr.Button(value="Reconnect")
 		with gr.Tab("Settings"):
 			with gr.Row():
 				exec_inputs = []

@@ -564,6 +568,7 @@ def setup_gradio():
 					gr.Checkbox(label="Force CPU for Conditioning Latents", value=args.force_cpu_for_conditioning_latents),
 					gr.Checkbox(label="Do Not Load TTS On Startup", value=args.defer_tts_load),
 					gr.Checkbox(label="Delete Non-Final Output", value=args.prune_nonfinal_outputs),
+					gr.Checkbox(label="Use BigVGAN Vocoder", value=args.use_bigvgan_vocoder),
 					gr.Textbox(label="Device Override", value=args.device_override),
 				]
 				with gr.Column():

@@ -748,6 +753,7 @@ def setup_gradio():
 			inputs=[
 				training_configs,
 				verbose_training,
+				training_gpu_count,
 				training_buffer_size,
 				training_keep_x_past_datasets,
 			],

@@ -763,6 +769,17 @@ def setup_gradio():
 			],
 			show_progress=False,
 		)
+
+		view_losses.click(
+			fn=update_training_dataplot,
+			inputs=[
+				training_configs
+			],
+			outputs=[
+				training_loss_graph,
+			],
+		)
+
 		stop_training_button.click(stop_training,
 			inputs=None,
 			outputs=training_output #console_output

@@ -1,5 +1,4 @@
 call .\venv\Scripts\activate.bat
 set PATH=.\bin\;%PATH%
 python .\src\main.py %*
 deactivate
-pause

@@ -1 +1 @@
-Subproject commit 7cc0250a1a559da90965812fdefcba0d54a59c41
+Subproject commit aca32a71f798ebd8487c113d41d1b4e9ee15c315

@@ -1,4 +1,4 @@
 call .\venv\Scripts\activate.bat
 python ./src/train.py -opt "%1"
-deactivate
-pause
+pause
+deactivate

train.sh: 11 changed lines
@@ -1,4 +1,13 @@
 #!/bin/bash
 source ./venv/bin/activate
-python3 ./src/train.py -opt "$1"
+
+GPUS=$1
+CONFIG=$2
+PORT=1234
+
+if (( $GPUS > 1 )); then
+	python3 -m torch.distributed.launch --nproc_per_node=$GPUS --master_port=$PORT ./src/train.py -opt "$CONFIG" --launcher=pytorch
+else
+	python3 ./src/train.py -opt "$CONFIG"
+fi
 deactivate
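
train.sh now takes the GPU count as $1 and the config as $2, matching the updated self.cmd in TrainingState.spawn_process; anything above one GPU goes through torch.distributed.launch. A sketch of the equivalent launch from Python (paths illustrative):

    import subprocess
    # mirrors the non-Windows branch of TrainingState.spawn_process with gpus=2
    cmd = ['bash', './train.sh', str(2), './training/example/train.yaml']
    process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, universal_newlines=True)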

@@ -1,3 +1,15 @@
 git fetch --all
 git reset --hard origin/master
 call .\update.bat
+
+python -m venv venv
+call .\venv\Scripts\activate.bat
+
+python -m pip install --upgrade pip
+python -m pip install -U -r .\dlas\requirements.txt
+python -m pip install -U -r .\tortoise-tts\requirements.txt
+python -m pip install -U -e .\tortoise-tts
+python -m pip install -U -r .\requirements.txt
+
+pause
+deactivate

@@ -1,4 +1,17 @@
 #!/bin/bash
 git fetch --all
 git reset --hard origin/master
 ./update.sh
+
+# force install requirements
+python3 -m venv venv
+source ./venv/bin/activate
+
+python3 -m pip install --upgrade pip
+python3 -m pip install -r ./dlas/requirements.txt
+python3 -m pip install -r ./tortoise-tts/requirements.txt
+python3 -m pip install -e ./tortoise-tts
+python3 -m pip install -r ./requirements.txt
+
+deactivate

update.bat: 14 changed lines
@@ -1,14 +1,2 @@
 git pull
 git submodule update --remote
-
-python -m venv venv
-call .\venv\Scripts\activate.bat
-
-python -m pip install --upgrade pip
-python -m pip install -r .\dlas\requirements.txt
-python -m pip install -r .\tortoise-tts\requirements.txt
-python -m pip install -e .\tortoise-tts
-python -m pip install -r .\requirements.txt
-
-deactivate
-pause

update.sh: 10 changed lines
@@ -2,14 +2,6 @@
 git pull
 git submodule update --remote
 
-python3 -m venv venv
-source ./venv/bin/activate
-
-python3 -m pip install --upgrade pip
-python3 -m pip install -r ./dlas/requirements.txt
-python3 -m pip install -r ./tortoise-tts/requirements.txt
-python3 -m pip install -e ./tortoise-tts
-python3 -m pip install -r ./requirements.txt
+if python -m pip show whispercpp &>/dev/null; then python -m pip install -U git+https://git.ecker.tech/lightmare/whispercpp.py; fi
 
 deactivate