From 7b16b3e88a9a81c84270e620d9f872067d053368 Mon Sep 17 00:00:00 2001 From: mrq Date: Tue, 14 Mar 2023 15:48:09 +0000 Subject: [PATCH] ;) --- models/.template.valle.yaml | 3 ++ src/train.py | 89 ++++++++++++------------------ src/utils.py | 105 +++++++++++++++++++++--------------- train.bat | 2 +- train.sh | 2 +- 5 files changed, 102 insertions(+), 99 deletions(-) diff --git a/models/.template.valle.yaml b/models/.template.valle.yaml index 2a3f57f..c9389ed 100755 --- a/models/.template.valle.yaml +++ b/models/.template.valle.yaml @@ -8,6 +8,9 @@ spkr_name_getter: "lambda p: p.parts[-3]" model: ${model_name} batch_size: ${batch_size} eval_batch_size: ${validation_batch_size} + +max_iter: ${iterations} +save_ckpt_every: ${save_rate} eval_every: ${validation_rate} sampling_temperature: 1.0 \ No newline at end of file diff --git a/src/train.py b/src/train.py index 39dee03..e93b76f 100755 --- a/src/train.py +++ b/src/train.py @@ -2,68 +2,48 @@ import os import sys import argparse import yaml +import datetime -""" -if 'BITSANDBYTES_OVERRIDE_LINEAR' not in os.environ: - os.environ['BITSANDBYTES_OVERRIDE_LINEAR'] = '0' -if 'BITSANDBYTES_OVERRIDE_EMBEDDING' not in os.environ: - os.environ['BITSANDBYTES_OVERRIDE_EMBEDDING'] = '1' -if 'BITSANDBYTES_OVERRIDE_ADAM' not in os.environ: - os.environ['BITSANDBYTES_OVERRIDE_ADAM'] = '1' -if 'BITSANDBYTES_OVERRIDE_ADAMW' not in os.environ: - os.environ['BITSANDBYTES_OVERRIDE_ADAMW'] = '1' -""" +from torch.distributed.run import main as torchrun -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_vit_latent.yml', nargs='+') # ugh - parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher') - parser.add_argument('--mode', type=str, default='none', help='mode') - args = parser.parse_args() - args.opt = " ".join(args.opt) # absolutely disgusting - - with open(args.opt, 'r') as file: - opt_config = yaml.safe_load(file) +# I don't want this invoked from an import +if __name__ != "__main__": + raise Exception("Do not invoke this from an import") - if "ext" in opt_config and "bitsandbytes" in opt_config["ext"] and not opt_config["ext"]["bitsandbytes"]: - os.environ['BITSANDBYTES_OVERRIDE_LINEAR'] = '0' - os.environ['BITSANDBYTES_OVERRIDE_EMBEDDING'] = '0' - os.environ['BITSANDBYTES_OVERRIDE_ADAM'] = '0' - os.environ['BITSANDBYTES_OVERRIDE_ADAMW'] = '0' +parser = argparse.ArgumentParser() +parser.add_argument('--yaml', type=str, help='Path to training configuration file.', default='./training/voice/train.yml', nargs='+') # ugh +parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='Job launcher') +args = parser.parse_args() +args.yaml = " ".join(args.yaml) # absolutely disgusting +config_path = args.yaml -# this is some massive kludge that only works if it's called from a shell and not an import/PIP package -# it's smart-yet-irritating module-model loader breaks when trying to load something specifically when not from a shell +with open(config_path, 'r') as file: + opt_config = yaml.safe_load(file) +# it'd be downright sugoi if I was able to install DLAS as a pip package sys.path.insert(0, './modules/dlas/codes/') -# this is also because DLAS is not written as a package in mind -# it'll gripe when it wants to import from train.py sys.path.insert(0, './modules/dlas/') -# for PIP, replace it with: -# sys.path.insert(0, os.path.dirname(os.path.realpath(dlas.__file__))) -# sys.path.insert(0, f"{os.path.dirname(os.path.realpath(dlas.__file__))}/../") - -# don't even really bother trying to get DLAS PIP'd -# without kludge, it'll have to be accessible as `codes` and not `dlas` +# yucky override +if "bitsandbytes" in opt_config and not opt_config["bitsandbytes"]: + os.environ['BITSANDBYTES_OVERRIDE_LINEAR'] = '0' + os.environ['BITSANDBYTES_OVERRIDE_EMBEDDING'] = '0' + os.environ['BITSANDBYTES_OVERRIDE_ADAM'] = '0' + os.environ['BITSANDBYTES_OVERRIDE_ADAMW'] = '0' import torch -import datetime from codes import train as tr from utils import util, options as option -from torch.distributed.run import main # this is effectively just copy pasted and cleaned up from the __main__ section of training.py -# I'll clean it up better - -def train(yaml, launcher='none'): - opt = option.parse(yaml, is_train=True) +def train(config_path, launcher='none'): + opt = option.parse(config_path, is_train=True) if launcher == 'none' and opt['gpus'] > 1: - return main([f"--nproc_per_node={opt['gpus']}", "--master_port=1234", "./src/train.py", "-opt", yaml, "--launcher=pytorch"]) + return torchrun([f"--nproc_per_node={opt['gpus']}", "./src/train.py", "--yaml", config_path, "--launcher=pytorch"]) trainer = tr.Trainer() - #### distributed training settings - if launcher == 'none': # disabled distributed training + if launcher == 'none': opt['dist'] = False trainer.rank = -1 if len(opt['gpu_ids']) == 1: @@ -76,17 +56,16 @@ def train(yaml, launcher='none'): trainer.rank = torch.distributed.get_rank() torch.cuda.set_device(torch.distributed.get_rank()) - trainer.init(yaml, opt, launcher, '') + trainer.init(config_path, opt, launcher, '') trainer.do_training() -if __name__ == "__main__": - try: - import torch_intermediary - if torch_intermediary.OVERRIDE_ADAM: - print("Using BitsAndBytes optimizations") - else: - print("NOT using BitsAndBytes optimizations") - except Exception as e: - pass +try: + import torch_intermediary + if torch_intermediary.OVERRIDE_ADAM: + print("Using BitsAndBytes optimizations") + else: + print("NOT using BitsAndBytes optimizations") +except Exception as e: + pass - train(args.opt, args.launcher) \ No newline at end of file +train(config_path, args.launcher) \ No newline at end of file diff --git a/src/utils.py b/src/utils.py index d20b6b5..f7e7577 100755 --- a/src/utils.py +++ b/src/utils.py @@ -47,7 +47,7 @@ WHISPER_MODELS = ["tiny", "base", "small", "medium", "large"] WHISPER_SPECIALIZED_MODELS = ["tiny.en", "base.en", "small.en", "medium.en"] WHISPER_BACKENDS = ["openai/whisper", "lightmare/whispercpp"] VOCODERS = ['univnet', 'bigvgan_base_24khz_100band', 'bigvgan_24khz_100band'] -TTSES = ['tortoise'] # + ['vall-e'] +TTSES = ['tortoise'] GENERATE_SETTINGS_ARGS = None @@ -69,6 +69,9 @@ try: except Exception as e: pass +if VALLE_ENABLED: + TTSES.append('vall-e') + args = None tts = None tts_loading = False @@ -613,28 +616,41 @@ class TrainingState(): with open(config_path, 'r') as file: self.config = yaml.safe_load(file) - gpus = self.config["gpus"] self.killed = False + + self.it = 0 + self.step = 0 + self.epoch = 0 + self.checkpoint = 0 + + if args.tts_backend == "tortoise": + gpus = self.config["gpus"] + + self.dataset_dir = f"./training/{self.config['name']}/finetune/" + self.batch_size = self.config['datasets']['train']['batch_size'] + self.dataset_path = self.config['datasets']['train']['path'] + + self.its = self.config['train']['niter'] + self.steps = 1 + self.epochs = 1 # int(self.its*self.batch_size/self.dataset_size) + self.checkpoints = int(self.its / self.config['logger']['save_checkpoint_freq']) + elif args.tts_backend == "vall-e": + self.batch_size = self.config['batch_size'] + self.dataset_dir = f".{self.config['data_root']}/finetune/" + self.dataset_path = f"{self.config['data_root']}/train.txt" + + self.its = 1 + self.steps = 1 + self.epochs = 1 + self.checkpoints = 1 + + self.json_config = json.load(open(f"{self.config['data_root']}/train.json", 'r', encoding="utf-8")) + gpus = self.json_config['gpus'] - self.dataset_dir = f"./training/{self.config['name']}/finetune/" - self.batch_size = self.config['datasets']['train']['batch_size'] - self.dataset_path = self.config['datasets']['train']['path'] with open(self.dataset_path, 'r', encoding="utf-8") as f: self.dataset_size = len(f.readlines()) - self.it = 0 - self.its = self.config['train']['niter'] - - self.step = 0 - self.steps = 1 - - self.epoch = 0 - self.epochs = int(self.its*self.batch_size/self.dataset_size) - - self.checkpoint = 0 - self.checkpoints = int(self.its / self.config['logger']['save_checkpoint_freq']) - self.buffer = [] self.open_state = False @@ -672,7 +688,10 @@ class TrainingState(): self.spawn_process(config_path=config_path, gpus=gpus) def spawn_process(self, config_path, gpus=1): - self.cmd = ['train.bat', config_path] if os.name == "nt" else ['./train.sh', config_path] + if args.tts_backend == "vall-e": + self.cmd = ['torchrun', '--nproc_per_node', f'{gpus}', '-m', 'vall_e.train', f'yaml="{config_path}"'] + else: + self.cmd = ['train.bat', config_path] if os.name == "nt" else ['./train.sh', config_path] print("Spawning process: ", " ".join(self.cmd)) self.process = subprocess.Popen(self.cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, universal_newlines=True) @@ -1221,6 +1240,8 @@ def prepare_dataset( voice, use_segments, text_length, audio_length, normalize=T lines = { 'training': [], 'validation': [], + 'recordings': [], + 'supervisions': [], } normalizer = EnglishTextNormalizer() if normalize else None @@ -1310,28 +1331,22 @@ def prepare_dataset( voice, use_segments, text_length, audio_length, normalize=T lines['training' if not culled else 'validation'].append(line) - if culled or not VALLE_ENABLED: + if culled or args.tts_backend != "vall-e": continue - # VALL-E dataset os.makedirs(f'{indir}/valle/', exist_ok=True) - try: - from vall_e.emb.qnt import encode as quantize - from vall_e.emb.g2p import encode as phonemize - - if waveform.shape[0] == 2: - waveform = wav[:1] + from vall_e.emb.qnt import encode as quantize + from vall_e.emb.g2p import encode as phonemize + + if waveform.shape[0] == 2: + waveform = wav[:1] - quantized = quantize( waveform, sample_rate ).cpu() - torch.save(quantized, f'{indir}/valle/{file.replace(".wav",".qnt.pt")}') - - phonemes = phonemize(normalized_text) - open(f'{indir}/valle/{file.replace(".wav",".phn.txt")}', 'w', encoding='utf-8').write(" ".join(phonemes)) - - except Exception as e: - print(e) - pass + quantized = quantize( waveform, sample_rate ).cpu() + torch.save(quantized, f'{indir}/valle/{file.replace(".wav",".qnt.pt")}') + + phonemes = phonemize(normalized_text) + open(f'{indir}/valle/{file.replace(".wav",".phn.txt")}', 'w', encoding='utf-8').write(" ".join(phonemes)) training_joined = "\n".join(lines['training']) validation_joined = "\n".join(lines['validation']) @@ -1588,12 +1603,13 @@ def save_training_settings( **kwargs ): with open(out, 'w', encoding="utf-8") as f: f.write(yaml) - use_template(f'./models/.template.dlas.yaml', f'./training/{settings["voice"]}/train.yaml') - - settings['model_name'] = "ar" - use_template(f'./models/.template.valle.yaml', f'./training/{settings["voice"]}/ar.yaml') - settings['model_name'] = "nar" - use_template(f'./models/.template.valle.yaml', f'./training/{settings["voice"]}/nar.yaml') + if args.tts_backend == "tortoise": + use_template(f'./models/.template.dlas.yaml', f'./training/{settings["voice"]}/train.yaml') + elif args.tts_backend == "vall-e": + settings['model_name'] = "ar" + use_template(f'./models/.template.valle.yaml', f'./training/{settings["voice"]}/ar.yaml') + settings['model_name'] = "nar" + use_template(f'./models/.template.valle.yaml', f'./training/{settings["voice"]}/nar.yaml') messages.append(f"Saved training output") return settings, messages @@ -1692,7 +1708,12 @@ def get_dataset_list(dir="./training/"): return sorted([d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d)) and "train.txt" in os.listdir(os.path.join(dir, d)) ]) def get_training_list(dir="./training/"): - return sorted([f'./training/{d}/train.yaml' for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d)) and "train.yaml" in os.listdir(os.path.join(dir, d)) ]) + if args.tts_backend == "tortoise": + return sorted([f'./training/{d}/train.yaml' for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d)) and "train.yaml" in os.listdir(os.path.join(dir, d)) ]) + + ars = sorted([f'./training/{d}/ar.yaml' for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d)) and "ar.yaml" in os.listdir(os.path.join(dir, d)) ]) + nars = sorted([f'./training/{d}/nar.yaml' for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d)) and "nar.yaml" in os.listdir(os.path.join(dir, d)) ]) + return ars + nars def pad(num, zeroes): return str(num).zfill(zeroes+1) diff --git a/train.bat b/train.bat index ada6bc2..70a3d3c 100755 --- a/train.bat +++ b/train.bat @@ -1,5 +1,5 @@ call .\venv\Scripts\activate.bat set PYTHONUTF8=1 -python ./src/train.py -opt "%1" +python ./src/train.py --yaml "%1" pause deactivate \ No newline at end of file diff --git a/train.sh b/train.sh index 5e83e27..b112e45 100755 --- a/train.sh +++ b/train.sh @@ -1,4 +1,4 @@ #!/bin/bash source ./venv/bin/activate -python3 ./src/train.py -opt "$1" +python3 ./src/train.py --yaml "$1" deactivate