This commit is contained in:
mrq 2023-03-14 15:48:09 +00:00
parent c85e32ff53
commit 7b16b3e88a
5 changed files with 102 additions and 99 deletions

View File

@ -8,6 +8,9 @@ spkr_name_getter: "lambda p: p.parts[-3]"
model: ${model_name}
batch_size: ${batch_size}
eval_batch_size: ${validation_batch_size}
max_iter: ${iterations}
save_ckpt_every: ${save_rate}
eval_every: ${validation_rate}
sampling_temperature: 1.0

View File

@ -2,68 +2,48 @@ import os
import sys
import argparse
import yaml
import datetime
"""
if 'BITSANDBYTES_OVERRIDE_LINEAR' not in os.environ:
os.environ['BITSANDBYTES_OVERRIDE_LINEAR'] = '0'
if 'BITSANDBYTES_OVERRIDE_EMBEDDING' not in os.environ:
os.environ['BITSANDBYTES_OVERRIDE_EMBEDDING'] = '1'
if 'BITSANDBYTES_OVERRIDE_ADAM' not in os.environ:
os.environ['BITSANDBYTES_OVERRIDE_ADAM'] = '1'
if 'BITSANDBYTES_OVERRIDE_ADAMW' not in os.environ:
os.environ['BITSANDBYTES_OVERRIDE_ADAMW'] = '1'
"""
from torch.distributed.run import main as torchrun
# I don't want this invoked from an import
if __name__ != "__main__":
raise Exception("Do not invoke this from an import")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_vit_latent.yml', nargs='+') # ugh
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
parser.add_argument('--mode', type=str, default='none', help='mode')
parser.add_argument('--yaml', type=str, help='Path to training configuration file.', default='./training/voice/train.yml', nargs='+') # ugh
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='Job launcher')
args = parser.parse_args()
args.opt = " ".join(args.opt) # absolutely disgusting
args.yaml = " ".join(args.yaml) # absolutely disgusting
config_path = args.yaml
with open(args.opt, 'r') as file:
with open(config_path, 'r') as file:
opt_config = yaml.safe_load(file)
if "ext" in opt_config and "bitsandbytes" in opt_config["ext"] and not opt_config["ext"]["bitsandbytes"]:
# it'd be downright sugoi if I was able to install DLAS as a pip package
sys.path.insert(0, './modules/dlas/codes/')
sys.path.insert(0, './modules/dlas/')
# yucky override
if "bitsandbytes" in opt_config and not opt_config["bitsandbytes"]:
os.environ['BITSANDBYTES_OVERRIDE_LINEAR'] = '0'
os.environ['BITSANDBYTES_OVERRIDE_EMBEDDING'] = '0'
os.environ['BITSANDBYTES_OVERRIDE_ADAM'] = '0'
os.environ['BITSANDBYTES_OVERRIDE_ADAMW'] = '0'
# this is some massive kludge that only works if it's called from a shell and not an import/PIP package
# it's smart-yet-irritating module-model loader breaks when trying to load something specifically when not from a shell
sys.path.insert(0, './modules/dlas/codes/')
# this is also because DLAS is not written as a package in mind
# it'll gripe when it wants to import from train.py
sys.path.insert(0, './modules/dlas/')
# for PIP, replace it with:
# sys.path.insert(0, os.path.dirname(os.path.realpath(dlas.__file__)))
# sys.path.insert(0, f"{os.path.dirname(os.path.realpath(dlas.__file__))}/../")
# don't even really bother trying to get DLAS PIP'd
# without kludge, it'll have to be accessible as `codes` and not `dlas`
import torch
import datetime
from codes import train as tr
from utils import util, options as option
from torch.distributed.run import main
# this is effectively just copy pasted and cleaned up from the __main__ section of training.py
# I'll clean it up better
def train(yaml, launcher='none'):
opt = option.parse(yaml, is_train=True)
def train(config_path, launcher='none'):
opt = option.parse(config_path, is_train=True)
if launcher == 'none' and opt['gpus'] > 1:
return main([f"--nproc_per_node={opt['gpus']}", "--master_port=1234", "./src/train.py", "-opt", yaml, "--launcher=pytorch"])
return torchrun([f"--nproc_per_node={opt['gpus']}", "./src/train.py", "--yaml", config_path, "--launcher=pytorch"])
trainer = tr.Trainer()
#### distributed training settings
if launcher == 'none': # disabled distributed training
if launcher == 'none':
opt['dist'] = False
trainer.rank = -1
if len(opt['gpu_ids']) == 1:
@ -76,10 +56,9 @@ def train(yaml, launcher='none'):
trainer.rank = torch.distributed.get_rank()
torch.cuda.set_device(torch.distributed.get_rank())
trainer.init(yaml, opt, launcher, '')
trainer.init(config_path, opt, launcher, '')
trainer.do_training()
if __name__ == "__main__":
try:
import torch_intermediary
if torch_intermediary.OVERRIDE_ADAM:
@ -89,4 +68,4 @@ if __name__ == "__main__":
except Exception as e:
pass
train(args.opt, args.launcher)
train(config_path, args.launcher)

View File

@ -47,7 +47,7 @@ WHISPER_MODELS = ["tiny", "base", "small", "medium", "large"]
WHISPER_SPECIALIZED_MODELS = ["tiny.en", "base.en", "small.en", "medium.en"]
WHISPER_BACKENDS = ["openai/whisper", "lightmare/whispercpp"]
VOCODERS = ['univnet', 'bigvgan_base_24khz_100band', 'bigvgan_24khz_100band']
TTSES = ['tortoise'] # + ['vall-e']
TTSES = ['tortoise']
GENERATE_SETTINGS_ARGS = None
@ -69,6 +69,9 @@ try:
except Exception as e:
pass
if VALLE_ENABLED:
TTSES.append('vall-e')
args = None
tts = None
tts_loading = False
@ -613,28 +616,41 @@ class TrainingState():
with open(config_path, 'r') as file:
self.config = yaml.safe_load(file)
gpus = self.config["gpus"]
self.killed = False
self.it = 0
self.step = 0
self.epoch = 0
self.checkpoint = 0
if args.tts_backend == "tortoise":
gpus = self.config["gpus"]
self.dataset_dir = f"./training/{self.config['name']}/finetune/"
self.batch_size = self.config['datasets']['train']['batch_size']
self.dataset_path = self.config['datasets']['train']['path']
self.its = self.config['train']['niter']
self.steps = 1
self.epochs = 1 # int(self.its*self.batch_size/self.dataset_size)
self.checkpoints = int(self.its / self.config['logger']['save_checkpoint_freq'])
elif args.tts_backend == "vall-e":
self.batch_size = self.config['batch_size']
self.dataset_dir = f".{self.config['data_root']}/finetune/"
self.dataset_path = f"{self.config['data_root']}/train.txt"
self.its = 1
self.steps = 1
self.epochs = 1
self.checkpoints = 1
self.json_config = json.load(open(f"{self.config['data_root']}/train.json", 'r', encoding="utf-8"))
gpus = self.json_config['gpus']
with open(self.dataset_path, 'r', encoding="utf-8") as f:
self.dataset_size = len(f.readlines())
self.it = 0
self.its = self.config['train']['niter']
self.step = 0
self.steps = 1
self.epoch = 0
self.epochs = int(self.its*self.batch_size/self.dataset_size)
self.checkpoint = 0
self.checkpoints = int(self.its / self.config['logger']['save_checkpoint_freq'])
self.buffer = []
self.open_state = False
@ -672,6 +688,9 @@ class TrainingState():
self.spawn_process(config_path=config_path, gpus=gpus)
def spawn_process(self, config_path, gpus=1):
if args.tts_backend == "vall-e":
self.cmd = ['torchrun', '--nproc_per_node', f'{gpus}', '-m', 'vall_e.train', f'yaml="{config_path}"']
else:
self.cmd = ['train.bat', config_path] if os.name == "nt" else ['./train.sh', config_path]
print("Spawning process: ", " ".join(self.cmd))
@ -1221,6 +1240,8 @@ def prepare_dataset( voice, use_segments, text_length, audio_length, normalize=T
lines = {
'training': [],
'validation': [],
'recordings': [],
'supervisions': [],
}
normalizer = EnglishTextNormalizer() if normalize else None
@ -1310,13 +1331,11 @@ def prepare_dataset( voice, use_segments, text_length, audio_length, normalize=T
lines['training' if not culled else 'validation'].append(line)
if culled or not VALLE_ENABLED:
if culled or args.tts_backend != "vall-e":
continue
# VALL-E dataset
os.makedirs(f'{indir}/valle/', exist_ok=True)
try:
from vall_e.emb.qnt import encode as quantize
from vall_e.emb.g2p import encode as phonemize
@ -1329,10 +1348,6 @@ def prepare_dataset( voice, use_segments, text_length, audio_length, normalize=T
phonemes = phonemize(normalized_text)
open(f'{indir}/valle/{file.replace(".wav",".phn.txt")}', 'w', encoding='utf-8').write(" ".join(phonemes))
except Exception as e:
print(e)
pass
training_joined = "\n".join(lines['training'])
validation_joined = "\n".join(lines['validation'])
@ -1588,8 +1603,9 @@ def save_training_settings( **kwargs ):
with open(out, 'w', encoding="utf-8") as f:
f.write(yaml)
if args.tts_backend == "tortoise":
use_template(f'./models/.template.dlas.yaml', f'./training/{settings["voice"]}/train.yaml')
elif args.tts_backend == "vall-e":
settings['model_name'] = "ar"
use_template(f'./models/.template.valle.yaml', f'./training/{settings["voice"]}/ar.yaml')
settings['model_name'] = "nar"
@ -1692,8 +1708,13 @@ def get_dataset_list(dir="./training/"):
return sorted([d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d)) and "train.txt" in os.listdir(os.path.join(dir, d)) ])
def get_training_list(dir="./training/"):
if args.tts_backend == "tortoise":
return sorted([f'./training/{d}/train.yaml' for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d)) and "train.yaml" in os.listdir(os.path.join(dir, d)) ])
ars = sorted([f'./training/{d}/ar.yaml' for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d)) and "ar.yaml" in os.listdir(os.path.join(dir, d)) ])
nars = sorted([f'./training/{d}/nar.yaml' for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d)) and "nar.yaml" in os.listdir(os.path.join(dir, d)) ])
return ars + nars
def pad(num, zeroes):
return str(num).zfill(zeroes+1)

View File

@ -1,5 +1,5 @@
call .\venv\Scripts\activate.bat
set PYTHONUTF8=1
python ./src/train.py -opt "%1"
python ./src/train.py --yaml "%1"
pause
deactivate

View File

@ -1,4 +1,4 @@
#!/bin/bash
source ./venv/bin/activate
python3 ./src/train.py -opt "$1"
python3 ./src/train.py --yaml "$1"
deactivate