Compare commits

...

10 Commits

18 changed files with 238 additions and 130 deletions

7
.gitignore vendored
View File

@ -1,7 +1,8 @@
# ignores user files
/tortoise-venv/
/tortoise/voices/
/models/
/venv/
/voices/*
/models/*
/training/*
/config/*
# Byte-compiled / optimized / DLL files

2
dlas

@ -1 +1 @@
Subproject commit 0f04206aa20b1ab632c0cbf7bb6a43d5c1fd9eb0
Subproject commit 71cc43e65cd47c6704d20c99006a3e78feb2400d

View File

@ -7,12 +7,12 @@ python -m pip install --upgrade pip
python -m pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu117
python -m pip install -r .\dlas\requirements.txt
python -m pip install -r .\tortoise-tts\requirements.txt
python -m pip install -r .\requirements.txt
python -m pip install -e .\tortoise-tts\
python -m pip install -r .\requirements.txt
xcopy .\dlas\bitsandbytes_windows\* .\venv\Lib\site-packages\bitsandbytes\. /Y
xcopy .\dlas\bitsandbytes_windows\cuda_setup\* .\venv\Lib\site-packages\bitsandbytes\cuda_setup\. /Y
xcopy .\dlas\bitsandbytes_windows\nn\* .\venv\Lib\site-packages\bitsandbytes\nn\. /Y
deactivate
pause
deactivate

View File

@ -1,14 +1,17 @@
#!/bin/bash
# get local dependencies
git submodule init
git submodule update --remote
# setup venv
python3 -m venv venv
source ./venv/bin/activate
python3 -m pip install --upgrade pip
python3 -m pip install --upgrade pip # just to be safe
# CUDA
pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu116
python3 -m pip install -r ./dlas/requirements.txt
python3 -m pip install -r ./tortoise-tts/requirements.txt
python3 -m pip install -r ./requirements.txt
python3 -m pip install -e ./tortoise-tts/
pip3 install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu117
# install requirements
python3 -m pip install -r ./dlas/requirements.txt # instal DLAS requirements
python3 -m pip install -r ./tortoise-tts/requirements.txt # install TorToiSe requirements
python3 -m pip install -e ./tortoise-tts/ # install TorToiSe
python3 -m pip install -r ./requirements.txt # install local requirements
deactivate

View File

@ -4,10 +4,11 @@ git submodule update --remote
python -m venv venv
call .\venv\Scripts\activate.bat
python -m pip install --upgrade pip
python -m pip install torch torchvision torchaudio torch-directml==0.1.13.1.dev230119
python -m pip install torch torchvision torchaudio torch-directml
python -m pip install -r .\dlas\requirements.txt
python -m pip install -r .\tortoise-tts\requirements.txt
python -m pip install -r .\requirements.txt
python -m pip install -e .\tortoise-tts\
deactivate
python -m pip install -r .\requirements.txt
pause
deactivate

8
setup-rocm-bnb.sh Executable file
View File

@ -0,0 +1,8 @@
#!/bin/bash
source ./venv/bin/activate
git clone https://git.ecker.tech/mrq/bitsandbytes-rocm
cd bitsandbytes-rocm
make hip
CUDA_VERSION=gfx1030 python setup.py install # assumes you're using a 6XXX series card
python3 -m bitsandbytes # to validate it works
cd ..

View File

@ -1,14 +1,19 @@
#!/bin/bash
# get local dependencies
git submodule init
git submodule update --remote
# setup venv
python3 -m venv venv
source ./venv/bin/activate
python3 -m pip install --upgrade pip
python3 -m pip install --upgrade pip # just to be safe
# ROCM
pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/rocm5.1.1 # 5.2 does not work for me desu
python3 -m pip install -r ./dlas/requirements.txt
python3 -m pip install -r ./tortoise-tts/requirements.txt
python3 -m pip install -r ./requirements.txt
python3 -m pip install -e ./tortoise-tts/
pip3 install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/rocm5.1.1 # 5.2 does not work for me desu
# install requirements
python3 -m pip install -r ./dlas/requirements.txt # instal DLAS requirements
python3 -m pip install -r ./tortoise-tts/requirements.txt # install TorToiSe requirements
python3 -m pip install -e ./tortoise-tts/ # install TorToiSe
python3 -m pip install -r ./requirements.txt # install local requirements
# swap to ROCm version of BitsAndBytes
pip3 uninstall -y bitsandbytes
./setup-rocm-bnb.sh
deactivate

View File

@ -18,9 +18,12 @@ if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_vit_latent.yml', nargs='+') # ugh
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
parser.add_argument('--local_rank', type=int, help='Rank Number')
args = parser.parse_args()
args.opt = " ".join(args.opt) # absolutely disgusting
os.environ['LOCAL_RANK'] = str(args.local_rank)
with open(args.opt, 'r') as file:
opt_config = yaml.safe_load(file)
@ -71,7 +74,7 @@ def train(yaml, launcher='none'):
print('Disabled distributed training.')
else:
opt['dist'] = True
init_dist('nccl')
tr.init_dist('nccl')
trainer.world_size = torch.distributed.get_world_size()
trainer.rank = torch.distributed.get_rank()
torch.cuda.set_device(torch.distributed.get_rank())

View File

@ -97,7 +97,11 @@ def generate(
voice_samples, conditioning_latents = None, tts.get_random_conditioning_latents()
else:
progress(0, desc="Loading voice...")
voice_samples, conditioning_latents = load_voice(voice)
# nasty check for users that, for whatever reason, updated the web UI but not mrq/tortoise-tts
if hasattr(tts, 'autoregressive_model_hash'):
voice_samples, conditioning_latents = load_voice(voice, model_hash=tts.autoregressive_model_hash)
else:
voice_samples, conditioning_latents = load_voice(voice)
if voice_samples and len(voice_samples) > 0:
sample_voice = torch.cat(voice_samples, dim=-1).squeeze().cpu()
@ -107,7 +111,10 @@ def generate(
conditioning_latents = (conditioning_latents[0], conditioning_latents[1], conditioning_latents[2], None)
if voice != "microphone":
torch.save(conditioning_latents, f'{get_voice_dir()}/{voice}/cond_latents.pth')
if hasattr(tts, 'autoregressive_model_hash'):
torch.save(conditioning_latents, f'{get_voice_dir()}/{voice}/cond_latents_{tts.autoregressive_model_hash[:8]}.pth')
else:
torch.save(conditioning_latents, f'{get_voice_dir()}/{voice}/cond_latents.pth')
voice_samples = None
else:
if conditioning_latents is not None:
@ -371,8 +378,16 @@ def generate(
if voice and voice != "random" and conditioning_latents is not None:
with open(f'{get_voice_dir()}/{voice}/cond_latents.pth', 'rb') as f:
info['latents'] = base64.b64encode(f.read()).decode("ascii")
latents_path = f'{get_voice_dir()}/{voice}/cond_latents.pth'
if hasattr(tts, 'autoregressive_model_hash'):
latents_path = f'{get_voice_dir()}/{voice}/cond_latents_{tts.autoregressive_model_hash[:8]}.pth'
try:
with open(latents_path, 'rb') as f:
info['latents'] = base64.b64encode(f.read()).decode("ascii")
except Exception as e:
pass
if args.embed_output_metadata:
for name in progress.tqdm(audio_cache, desc="Embedding metadata..."):
@ -413,6 +428,32 @@ def cancel_generate():
import tortoise.api
tortoise.api.STOP_SIGNAL = True
def hash_file(path, algo="md5", buffer_size=0):
import hashlib
hash = None
if algo == "md5":
hash = hashlib.md5()
elif algo == "sha1":
hash = hashlib.sha1()
else:
raise Exception(f'Unknown hash algorithm specified: {algo}')
if not os.path.exists(path):
raise Exception(f'Path not found: {path}')
with open(path, 'rb') as f:
if buffer_size > 0:
while True:
data = f.read(buffer_size)
if not data:
break
hash.update(data)
else:
hash.update(f.read())
return "{0}".format(hash.hexdigest())
def compute_latents(voice, voice_latents_chunks, progress=gr.Progress(track_tqdm=True)):
global tts
global args
@ -435,15 +476,16 @@ def compute_latents(voice, voice_latents_chunks, progress=gr.Progress(track_tqdm
if len(conditioning_latents) == 4:
conditioning_latents = (conditioning_latents[0], conditioning_latents[1], conditioning_latents[2], None)
torch.save(conditioning_latents, f'{get_voice_dir()}/{voice}/cond_latents.pth')
if hasattr(tts, 'autoregressive_model_hash'):
torch.save(conditioning_latents, f'{get_voice_dir()}/{voice}/cond_latents_{tts.autoregressive_model_hash[:8]}.pth')
else:
torch.save(conditioning_latents, f'{get_voice_dir()}/{voice}/cond_latents.pth')
return voice
# superfluous, but it cleans up some things
class TrainingState():
def __init__(self, config_path, keep_x_past_datasets=0):
self.cmd = ['train.bat', config_path] if os.name == "nt" else ['bash', './train.sh', config_path]
def __init__(self, config_path, keep_x_past_datasets=0, start=True, gpus=1):
# parse config to get its iteration
with open(config_path, 'r') as file:
self.config = yaml.safe_load(file)
@ -487,17 +529,22 @@ class TrainingState():
self.eta = "?"
self.eta_hhmmss = "?"
self.last_info_check_at = 0
self.losses = []
self.load_losses()
self.cleanup_old(keep=keep_x_past_datasets)
self.spawn_process()
if keep_x_past_datasets > 0:
self.cleanup_old(keep=keep_x_past_datasets)
if start:
self.spawn_process(config_path=config_path, gpus=gpus)
def spawn_process(self, config_path, gpus=1):
self.cmd = ['train.bat', config_path] if os.name == "nt" else ['bash', './train.sh', str(int(gpus)), config_path]
def spawn_process(self):
print("Spawning process: ", " ".join(self.cmd))
self.process = subprocess.Popen(self.cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, universal_newlines=True)
def load_losses(self):
def load_losses(self, update=False):
if not os.path.isdir(f'{self.dataset_dir}/tb_logger/'):
return
try:
@ -506,18 +553,26 @@ class TrainingState():
except Exception as e:
use_tensorboard = False
keys = ['loss_text_ce', 'loss_mel_ce', 'loss_gpt_total']
infos = {}
highest_step = self.last_info_check_at
if use_tensorboard:
logs = sorted([f'{self.dataset_dir}/tb_logger/{d}' for d in os.listdir(f'{self.dataset_dir}/tb_logger/') if d[:6] == "events" ])
infos = {}
if update:
logs = [logs[-1]]
for log in logs:
try:
ea = event_accumulator.EventAccumulator(log, size_guidance={event_accumulator.SCALARS: 0})
ea.Reload()
keys = ['loss_text_ce', 'loss_mel_ce', 'loss_gpt_total']
for key in keys:
scalar = ea.Scalars(key)
for s in scalar:
if update and s.step <= self.last_info_check_at:
continue
highest_step = max( highest_step, s.step )
self.losses.append( { "step": s.step, "value": s.value, "type": key } )
except Exception as e:
print("Failed to parse event log:", log)
@ -525,7 +580,9 @@ class TrainingState():
else:
logs = sorted([f'{self.dataset_dir}/{d}' for d in os.listdir(self.dataset_dir) if d[-4:] == ".log" ])
infos = {}
if update:
logs = [logs[-1]]
for log in logs:
with open(log, 'r', encoding="utf-8") as f:
lines = f.readlines()
@ -546,9 +603,13 @@ class TrainingState():
for k in infos:
if 'loss_gpt_total' in infos[k]:
self.losses.append({ "step": int(k), "value": infos[k]['loss_text_ce'], "type": "text_ce" })
self.losses.append({ "step": int(k), "value": infos[k]['loss_mel_ce'], "type": "mel_ce" })
self.losses.append({ "step": int(k), "value": infos[k]['loss_gpt_total'], "type": "gpt_total" })
for key in keys:
if update and int(k) <= self.last_info_check_at:
continue
highest_step = max( highest_step, s.step )
self.losses.append({ "step": int(k), "value": infos[k][key], "type": key })
self.last_info_check_at = highest_step
def cleanup_old(self, keep=2):
if keep <= 0:
@ -581,6 +642,7 @@ class TrainingState():
if line.find('Start training from epoch') >= 0:
self.epoch_time_start = time.time()
self.training_started = True # could just leverage the above variable, but this is python, and there's no point in these aggressive microoptimizations
should_return = True
match = re.findall(r'epoch: ([\d,]+)', line)
if match and len(match) > 0:
@ -662,12 +724,15 @@ class TrainingState():
if 'loss_gpt_total' in self.info:
self.status = f"Total loss at epoch {self.epoch}: {self.info['loss_gpt_total']}"
self.losses.append({ "step": self.it, "value": self.info['loss_text_ce'], "type": "text_ce" })
self.losses.append({ "step": self.it, "value": self.info['loss_mel_ce'], "type": "mel_ce" })
self.losses.append({ "step": self.it, "value": self.info['loss_gpt_total'], "type": "gpt_total" })
"""
self.losses.append({ "step": self.it, "value": self.info['loss_text_ce'], "type": "loss_text_ce" })
self.losses.append({ "step": self.it, "value": self.info['loss_mel_ce'], "type": "loss_mel_ce" })
self.losses.append({ "step": self.it, "value": self.info['loss_gpt_total'], "type": "loss_gpt_total" })
"""
should_return = True
self.load_losses(update=True)
elif line.find('Saving models and training states') >= 0:
self.checkpoint = self.checkpoint + 1
@ -688,7 +753,7 @@ class TrainingState():
if should_return:
return "".join(self.buffer)
def run_training(config_path, verbose=False, buffer_size=8, keep_x_past_datasets=0, progress=gr.Progress(track_tqdm=True)):
def run_training(config_path, verbose=False, gpus=1, buffer_size=8, keep_x_past_datasets=0, progress=gr.Progress(track_tqdm=True)):
global training_state
if training_state and training_state.process:
return "Training already in progress"
@ -700,7 +765,7 @@ def run_training(config_path, verbose=False, buffer_size=8, keep_x_past_datasets
unload_whisper()
unload_voicefixer()
training_state = TrainingState(config_path=config_path, keep_x_past_datasets=keep_x_past_datasets)
training_state = TrainingState(config_path=config_path, keep_x_past_datasets=keep_x_past_datasets, gpus=gpus)
for line in iter(training_state.process.stdout.readline, ""):
@ -723,11 +788,21 @@ def get_training_losses():
return
return pd.DataFrame(training_state.losses)
def update_training_dataplot():
def update_training_dataplot(config_path=None):
global training_state
if not training_state or not training_state.losses:
return
return gr.LinePlot.update(value=pd.DataFrame(training_state.losses))
update = None
if not training_state:
if config_path:
training_state = TrainingState(config_path=config_path, start=False)
if training_state.losses:
update = gr.LinePlot.update(value=pd.DataFrame(training_state.losses))
del training_state
training_state = None
elif training_state.losses:
update = gr.LinePlot.update(value=pd.DataFrame(training_state.losses))
return update
def reconnect_training(verbose=False, buffer_size=8, progress=gr.Progress(track_tqdm=True)):
global training_state
@ -823,8 +898,11 @@ def prepare_dataset( files, outdir, language=None, progress=None ):
torchaudio.save(f"{outdir}/{sliced_name}", sliced_waveform, sampling_rate)
transcription.append(f"{sliced_name}|{segment['text'].strip()}")
idx = idx + 1
line = f"{sliced_name}|{segment['text'].strip()}"
transcription.append(line)
with open(f'{outdir}/train.txt', 'a', encoding="utf-8") as f:
f.write(f'{line}\n')
with open(f'{outdir}/whisper.json', 'w', encoding="utf-8") as f:
f.write(json.dumps(results, indent='\t'))
@ -1035,32 +1113,6 @@ def get_voice_list(dir=get_voice_dir(), append_defaults=False):
res = res + ["random", "microphone"]
return res
def hash_file(path, algo="md5", buffer_size=0):
import hashlib
hash = None
if algo == "md5":
hash = hashlib.md5()
elif algo == "sha1":
hash = hashlib.sha1()
else:
raise Exception(f'Unknown hash algorithm specified: {algo}')
if not os.path.exists(path):
raise Exception(f'Path not found: {path}')
with open(path, 'rb') as f:
if buffer_size > 0:
while True:
data = f.read(buffer_size)
if not data:
break
hash.update(data)
else:
hash.update(f.read())
return "{0}".format(hash.hexdigest())
def get_autoregressive_models(dir="./models/finetunes/", prefixed=False):
os.makedirs(dir, exist_ok=True)
base = [get_model_path('autoregressive.pth')]
@ -1190,6 +1242,7 @@ def setup_args():
'defer-tts-load': False,
'device-override': None,
'prune-nonfinal-outputs': True,
'use-bigvgan-vocoder': True,
'concurrency-count': 2,
'output-sample-rate': 44100,
'output-volume': 1,
@ -1225,6 +1278,7 @@ def setup_args():
parser.add_argument("--force-cpu-for-conditioning-latents", default=default_arguments['force-cpu-for-conditioning-latents'], action='store_true', help="Forces computing conditional latents to be done on the CPU (if you constantyl OOM on low chunk counts)")
parser.add_argument("--defer-tts-load", default=default_arguments['defer-tts-load'], action='store_true', help="Defers loading TTS model")
parser.add_argument("--prune-nonfinal-outputs", default=default_arguments['prune-nonfinal-outputs'], action='store_true', help="Deletes non-final output files on completing a generation")
parser.add_argument("--use-bigvgan-vocoder", default=default_arguments['use-bigvgan-vocoder'], action='store_true', help="Uses BigVGAN in place of the default vocoder")
parser.add_argument("--device-override", default=default_arguments['device-override'], help="A device string to override pass through Torch")
parser.add_argument("--sample-batch-size", default=default_arguments['sample-batch-size'], type=int, help="Sets how many batches to use during the autoregressive samples pass")
parser.add_argument("--concurrency-count", type=int, default=default_arguments['concurrency-count'], help="How many Gradio events to process at once")
@ -1246,12 +1300,13 @@ def setup_args():
if not args.device_override:
set_device_name(args.device_override)
args.listen_host = None
args.listen_port = None
args.listen_path = None
if args.listen:
try:
match = re.findall(r"^(?:(.+?):(\d+))?(\/.+?)?$", args.listen)[0]
match = re.findall(r"^(?:(.+?):(\d+))?(\/.*?)?$", args.listen)[0]
args.listen_host = match[0] if match[0] != "" else "127.0.0.1"
args.listen_port = match[1] if match[1] != "" else None
@ -1264,7 +1319,7 @@ def setup_args():
return args
def update_args( listen, share, check_for_updates, models_from_local_only, low_vram, embed_output_metadata, latents_lean_and_mean, voice_fixer, voice_fixer_use_cuda, force_cpu_for_conditioning_latents, defer_tts_load, prune_nonfinal_outputs, device_override, sample_batch_size, concurrency_count, output_sample_rate, output_volume, autoregressive_model, whisper_model, whisper_cpp, training_default_halfp, training_default_bnb ):
def update_args( listen, share, check_for_updates, models_from_local_only, low_vram, embed_output_metadata, latents_lean_and_mean, voice_fixer, voice_fixer_use_cuda, force_cpu_for_conditioning_latents, defer_tts_load, prune_nonfinal_outputs, use_bigvgan_vocoder, device_override, sample_batch_size, concurrency_count, output_sample_rate, output_volume, autoregressive_model, whisper_model, whisper_cpp, training_default_halfp, training_default_bnb ):
global args
args.listen = listen
@ -1275,6 +1330,7 @@ def update_args( listen, share, check_for_updates, models_from_local_only, low_v
args.force_cpu_for_conditioning_latents = force_cpu_for_conditioning_latents
args.defer_tts_load = defer_tts_load
args.prune_nonfinal_outputs = prune_nonfinal_outputs
args.use_bigvgan_vocoder = use_bigvgan_vocoder
args.device_override = device_override
args.sample_batch_size = sample_batch_size
args.embed_output_metadata = embed_output_metadata
@ -1297,7 +1353,7 @@ def update_args( listen, share, check_for_updates, models_from_local_only, low_v
def save_args_settings():
global args
settings = {
'listen': None if args.listen else args.listen,
'listen': None if not args.listen else args.listen,
'share': args.share,
'low-vram':args.low_vram,
'check-for-updates':args.check_for_updates,
@ -1305,6 +1361,7 @@ def save_args_settings():
'force-cpu-for-conditioning-latents': args.force_cpu_for_conditioning_latents,
'defer-tts-load': args.defer_tts_load,
'prune-nonfinal-outputs': args.prune_nonfinal_outputs,
'use-bigvgan-vocoder': args.use_bigvgan_vocoder,
'device-override': args.device_override,
'sample-batch-size': args.sample_batch_size,
'embed-output-metadata': args.embed_output_metadata,
@ -1419,7 +1476,7 @@ def load_tts( restart=False, model=None ):
tts_loading = True
try:
tts = TextToSpeech(minor_optimizations=not args.low_vram, autoregressive_model_path=args.autoregressive_model)
tts = TextToSpeech(minor_optimizations=not args.low_vram, autoregressive_model_path=args.autoregressive_model, use_bigvgan=args.use_bigvgan_vocoder)
except Exception as e:
tts = TextToSpeech(minor_optimizations=not args.low_vram)
load_autoregressive_model(args.autoregressive_model)
@ -1533,7 +1590,7 @@ def load_whisper_model(name=None, progress=None, language=b'en'):
notify_progress(f"Loading Whisper model: {args.whisper_model}", progress)
if args.whisper_cpp:
from whispercpp import Whisper
whisper_model = Whisper(name, models_dir='./models/', language=language)
whisper_model = Whisper(name, models_dir='./models/', language=language.encode('ascii'))
else:
import whisper
whisper_model = whisper.load_model(args.whisper_model)

View File

@ -527,16 +527,8 @@ def setup_gradio():
with gr.Row():
with gr.Column():
training_configs = gr.Dropdown(label="Training Configuration", choices=get_training_list())
refresh_configs = gr.Button(value="Refresh Configurations")
with gr.Row():
start_training_button = gr.Button(value="Train")
stop_training_button = gr.Button(value="Stop")
reconnect_training_button = gr.Button(value="Reconnect")
with gr.Column():
training_output = gr.TextArea(label="Console Output", interactive=False, max_lines=8)
verbose_training = gr.Checkbox(label="Verbose Console Output", value=True)
training_buffer_size = gr.Slider(label="Console Buffer Size", minimum=4, maximum=32, value=8)
training_keep_x_past_datasets = gr.Slider(label="Keep X Previous States", minimum=0, maximum=8, value=0, step=1)
refresh_configs = gr.Button(value="Refresh Configurations")
training_loss_graph = gr.LinePlot(label="Training Metrics",
x="step",
@ -545,8 +537,20 @@ def setup_gradio():
color="type",
tooltip=['step', 'value', 'type'],
width=600,
height=350
height=350,
)
view_losses = gr.Button(value="View Losses")
with gr.Column():
training_output = gr.TextArea(label="Console Output", interactive=False, max_lines=8)
verbose_training = gr.Checkbox(label="Verbose Console Output", value=True)
training_buffer_size = gr.Slider(label="Console Buffer Size", minimum=4, maximum=32, value=8)
training_keep_x_past_datasets = gr.Slider(label="Keep X Previous States", minimum=0, maximum=8, value=0, step=1)
training_gpu_count = gr.Number(label="GPUs", value=1)
with gr.Row():
start_training_button = gr.Button(value="Train")
stop_training_button = gr.Button(value="Stop")
reconnect_training_button = gr.Button(value="Reconnect")
with gr.Tab("Settings"):
with gr.Row():
exec_inputs = []
@ -564,6 +568,7 @@ def setup_gradio():
gr.Checkbox(label="Force CPU for Conditioning Latents", value=args.force_cpu_for_conditioning_latents),
gr.Checkbox(label="Do Not Load TTS On Startup", value=args.defer_tts_load),
gr.Checkbox(label="Delete Non-Final Output", value=args.prune_nonfinal_outputs),
gr.Checkbox(label="Use BigVGAN Vocoder", value=args.use_bigvgan_vocoder),
gr.Textbox(label="Device Override", value=args.device_override),
]
with gr.Column():
@ -748,6 +753,7 @@ def setup_gradio():
inputs=[
training_configs,
verbose_training,
training_gpu_count,
training_buffer_size,
training_keep_x_past_datasets,
],
@ -763,6 +769,17 @@ def setup_gradio():
],
show_progress=False,
)
view_losses.click(
fn=update_training_dataplot,
inputs=[
training_configs
],
outputs=[
training_loss_graph,
],
)
stop_training_button.click(stop_training,
inputs=None,
outputs=training_output #console_output

View File

@ -1,5 +1,4 @@
call .\venv\Scripts\activate.bat
set PATH=.\bin\;%PATH%
python .\src\main.py %*
deactivate
pause

@ -1 +1 @@
Subproject commit 7cc0250a1a559da90965812fdefcba0d54a59c41
Subproject commit aca32a71f798ebd8487c113d41d1b4e9ee15c315

View File

@ -1,4 +1,4 @@
call .\venv\Scripts\activate.bat
python ./src/train.py -opt "%1"
deactivate
pause
deactivate

View File

@ -1,4 +1,13 @@
#!/bin/bash
source ./venv/bin/activate
python3 ./src/train.py -opt "$1"
GPUS=$1
CONFIG=$2
PORT=1234
if (( $GPUS > 1 )); then
python3 -m torch.distributed.launch --nproc_per_node=$GPUS --master_port=$PORT ./src/train.py -opt "$CONFIG" --launcher=pytorch
else
python3 ./src/train.py -opt "$CONFIG"
fi
deactivate

View File

@ -1,3 +1,15 @@
git fetch --all
git reset --hard origin/master
call .\update.bat
python -m venv venv
call .\venv\Scripts\activate.bat
python -m pip install --upgrade pip
python -m pip install -U -r .\dlas\requirements.txt
python -m pip install -U -r .\tortoise-tts\requirements.txt
python -m pip install -U -e .\tortoise-tts
python -m pip install -U -r .\requirements.txt
pause
deactivate

View File

@ -1,4 +1,17 @@
#!/bin/bash
git fetch --all
git reset --hard origin/master
./update.sh
# force install requirements
python3 -m venv venv
source ./venv/bin/activate
python3 -m pip install --upgrade pip
python3 -m pip install -r ./dlas/requirements.txt
python3 -m pip install -r ./tortoise-tts/requirements.txt
python3 -m pip install -e ./tortoise-tts
python3 -m pip install -r ./requirements.txt
deactivate

View File

@ -1,14 +1,2 @@
git pull
git submodule update --remote
python -m venv venv
call .\venv\Scripts\activate.bat
python -m pip install --upgrade pip
python -m pip install -r .\dlas\requirements.txt
python -m pip install -r .\tortoise-tts\requirements.txt
python -m pip install -e .\tortoise-tts
python -m pip install -r .\requirements.txt
deactivate
pause

View File

@ -2,14 +2,6 @@
git pull
git submodule update --remote
python3 -m venv venv
source ./venv/bin/activate
python3 -m pip install --upgrade pip
python3 -m pip install -r ./dlas/requirements.txt
python3 -m pip install -r ./tortoise-tts/requirements.txt
python3 -m pip install -e ./tortoise-tts
python3 -m pip install -r ./requirements.txt
if python -m pip show whispercpp &>/dev/null; then python -m pip install -U git+https://git.ecker.tech/lightmare/whispercpp.py; fi
deactivate