This commit is contained in:
mrq 2023-03-14 15:48:09 +00:00
parent c85e32ff53
commit 7b16b3e88a
5 changed files with 102 additions and 99 deletions

View File

@ -8,6 +8,9 @@ spkr_name_getter: "lambda p: p.parts[-3]"
model: ${model_name} model: ${model_name}
batch_size: ${batch_size} batch_size: ${batch_size}
eval_batch_size: ${validation_batch_size} eval_batch_size: ${validation_batch_size}
max_iter: ${iterations}
save_ckpt_every: ${save_rate}
eval_every: ${validation_rate} eval_every: ${validation_rate}
sampling_temperature: 1.0 sampling_temperature: 1.0

View File

@ -2,68 +2,48 @@ import os
import sys import sys
import argparse import argparse
import yaml import yaml
import datetime
""" from torch.distributed.run import main as torchrun
if 'BITSANDBYTES_OVERRIDE_LINEAR' not in os.environ:
os.environ['BITSANDBYTES_OVERRIDE_LINEAR'] = '0' # I don't want this invoked from an import
if 'BITSANDBYTES_OVERRIDE_EMBEDDING' not in os.environ: if __name__ != "__main__":
os.environ['BITSANDBYTES_OVERRIDE_EMBEDDING'] = '1' raise Exception("Do not invoke this from an import")
if 'BITSANDBYTES_OVERRIDE_ADAM' not in os.environ:
os.environ['BITSANDBYTES_OVERRIDE_ADAM'] = '1'
if 'BITSANDBYTES_OVERRIDE_ADAMW' not in os.environ:
os.environ['BITSANDBYTES_OVERRIDE_ADAMW'] = '1'
"""
if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_vit_latent.yml', nargs='+') # ugh parser.add_argument('--yaml', type=str, help='Path to training configuration file.', default='./training/voice/train.yml', nargs='+') # ugh
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher') parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='Job launcher')
parser.add_argument('--mode', type=str, default='none', help='mode')
args = parser.parse_args() args = parser.parse_args()
args.opt = " ".join(args.opt) # absolutely disgusting args.yaml = " ".join(args.yaml) # absolutely disgusting
config_path = args.yaml
with open(args.opt, 'r') as file: with open(config_path, 'r') as file:
opt_config = yaml.safe_load(file) opt_config = yaml.safe_load(file)
if "ext" in opt_config and "bitsandbytes" in opt_config["ext"] and not opt_config["ext"]["bitsandbytes"]: # it'd be downright sugoi if I was able to install DLAS as a pip package
sys.path.insert(0, './modules/dlas/codes/')
sys.path.insert(0, './modules/dlas/')
# yucky override
if "bitsandbytes" in opt_config and not opt_config["bitsandbytes"]:
os.environ['BITSANDBYTES_OVERRIDE_LINEAR'] = '0' os.environ['BITSANDBYTES_OVERRIDE_LINEAR'] = '0'
os.environ['BITSANDBYTES_OVERRIDE_EMBEDDING'] = '0' os.environ['BITSANDBYTES_OVERRIDE_EMBEDDING'] = '0'
os.environ['BITSANDBYTES_OVERRIDE_ADAM'] = '0' os.environ['BITSANDBYTES_OVERRIDE_ADAM'] = '0'
os.environ['BITSANDBYTES_OVERRIDE_ADAMW'] = '0' os.environ['BITSANDBYTES_OVERRIDE_ADAMW'] = '0'
# this is some massive kludge that only works if it's called from a shell and not an import/PIP package
# it's smart-yet-irritating module-model loader breaks when trying to load something specifically when not from a shell
sys.path.insert(0, './modules/dlas/codes/')
# this is also because DLAS is not written as a package in mind
# it'll gripe when it wants to import from train.py
sys.path.insert(0, './modules/dlas/')
# for PIP, replace it with:
# sys.path.insert(0, os.path.dirname(os.path.realpath(dlas.__file__)))
# sys.path.insert(0, f"{os.path.dirname(os.path.realpath(dlas.__file__))}/../")
# don't even really bother trying to get DLAS PIP'd
# without kludge, it'll have to be accessible as `codes` and not `dlas`
import torch import torch
import datetime
from codes import train as tr from codes import train as tr
from utils import util, options as option from utils import util, options as option
from torch.distributed.run import main
# this is effectively just copy pasted and cleaned up from the __main__ section of training.py # this is effectively just copy pasted and cleaned up from the __main__ section of training.py
# I'll clean it up better def train(config_path, launcher='none'):
opt = option.parse(config_path, is_train=True)
def train(yaml, launcher='none'):
opt = option.parse(yaml, is_train=True)
if launcher == 'none' and opt['gpus'] > 1: if launcher == 'none' and opt['gpus'] > 1:
return main([f"--nproc_per_node={opt['gpus']}", "--master_port=1234", "./src/train.py", "-opt", yaml, "--launcher=pytorch"]) return torchrun([f"--nproc_per_node={opt['gpus']}", "./src/train.py", "--yaml", config_path, "--launcher=pytorch"])
trainer = tr.Trainer() trainer = tr.Trainer()
#### distributed training settings if launcher == 'none':
if launcher == 'none': # disabled distributed training
opt['dist'] = False opt['dist'] = False
trainer.rank = -1 trainer.rank = -1
if len(opt['gpu_ids']) == 1: if len(opt['gpu_ids']) == 1:
@ -76,10 +56,9 @@ def train(yaml, launcher='none'):
trainer.rank = torch.distributed.get_rank() trainer.rank = torch.distributed.get_rank()
torch.cuda.set_device(torch.distributed.get_rank()) torch.cuda.set_device(torch.distributed.get_rank())
trainer.init(yaml, opt, launcher, '') trainer.init(config_path, opt, launcher, '')
trainer.do_training() trainer.do_training()
if __name__ == "__main__":
try: try:
import torch_intermediary import torch_intermediary
if torch_intermediary.OVERRIDE_ADAM: if torch_intermediary.OVERRIDE_ADAM:
@ -89,4 +68,4 @@ if __name__ == "__main__":
except Exception as e: except Exception as e:
pass pass
train(args.opt, args.launcher) train(config_path, args.launcher)

View File

@ -47,7 +47,7 @@ WHISPER_MODELS = ["tiny", "base", "small", "medium", "large"]
WHISPER_SPECIALIZED_MODELS = ["tiny.en", "base.en", "small.en", "medium.en"] WHISPER_SPECIALIZED_MODELS = ["tiny.en", "base.en", "small.en", "medium.en"]
WHISPER_BACKENDS = ["openai/whisper", "lightmare/whispercpp"] WHISPER_BACKENDS = ["openai/whisper", "lightmare/whispercpp"]
VOCODERS = ['univnet', 'bigvgan_base_24khz_100band', 'bigvgan_24khz_100band'] VOCODERS = ['univnet', 'bigvgan_base_24khz_100band', 'bigvgan_24khz_100band']
TTSES = ['tortoise'] # + ['vall-e'] TTSES = ['tortoise']
GENERATE_SETTINGS_ARGS = None GENERATE_SETTINGS_ARGS = None
@ -69,6 +69,9 @@ try:
except Exception as e: except Exception as e:
pass pass
if VALLE_ENABLED:
TTSES.append('vall-e')
args = None args = None
tts = None tts = None
tts_loading = False tts_loading = False
@ -613,28 +616,41 @@ class TrainingState():
with open(config_path, 'r') as file: with open(config_path, 'r') as file:
self.config = yaml.safe_load(file) self.config = yaml.safe_load(file)
gpus = self.config["gpus"]
self.killed = False self.killed = False
self.it = 0
self.step = 0
self.epoch = 0
self.checkpoint = 0
if args.tts_backend == "tortoise":
gpus = self.config["gpus"]
self.dataset_dir = f"./training/{self.config['name']}/finetune/" self.dataset_dir = f"./training/{self.config['name']}/finetune/"
self.batch_size = self.config['datasets']['train']['batch_size'] self.batch_size = self.config['datasets']['train']['batch_size']
self.dataset_path = self.config['datasets']['train']['path'] self.dataset_path = self.config['datasets']['train']['path']
self.its = self.config['train']['niter']
self.steps = 1
self.epochs = 1 # int(self.its*self.batch_size/self.dataset_size)
self.checkpoints = int(self.its / self.config['logger']['save_checkpoint_freq'])
elif args.tts_backend == "vall-e":
self.batch_size = self.config['batch_size']
self.dataset_dir = f".{self.config['data_root']}/finetune/"
self.dataset_path = f"{self.config['data_root']}/train.txt"
self.its = 1
self.steps = 1
self.epochs = 1
self.checkpoints = 1
self.json_config = json.load(open(f"{self.config['data_root']}/train.json", 'r', encoding="utf-8"))
gpus = self.json_config['gpus']
with open(self.dataset_path, 'r', encoding="utf-8") as f: with open(self.dataset_path, 'r', encoding="utf-8") as f:
self.dataset_size = len(f.readlines()) self.dataset_size = len(f.readlines())
self.it = 0
self.its = self.config['train']['niter']
self.step = 0
self.steps = 1
self.epoch = 0
self.epochs = int(self.its*self.batch_size/self.dataset_size)
self.checkpoint = 0
self.checkpoints = int(self.its / self.config['logger']['save_checkpoint_freq'])
self.buffer = [] self.buffer = []
self.open_state = False self.open_state = False
@ -672,6 +688,9 @@ class TrainingState():
self.spawn_process(config_path=config_path, gpus=gpus) self.spawn_process(config_path=config_path, gpus=gpus)
def spawn_process(self, config_path, gpus=1): def spawn_process(self, config_path, gpus=1):
if args.tts_backend == "vall-e":
self.cmd = ['torchrun', '--nproc_per_node', f'{gpus}', '-m', 'vall_e.train', f'yaml="{config_path}"']
else:
self.cmd = ['train.bat', config_path] if os.name == "nt" else ['./train.sh', config_path] self.cmd = ['train.bat', config_path] if os.name == "nt" else ['./train.sh', config_path]
print("Spawning process: ", " ".join(self.cmd)) print("Spawning process: ", " ".join(self.cmd))
@ -1221,6 +1240,8 @@ def prepare_dataset( voice, use_segments, text_length, audio_length, normalize=T
lines = { lines = {
'training': [], 'training': [],
'validation': [], 'validation': [],
'recordings': [],
'supervisions': [],
} }
normalizer = EnglishTextNormalizer() if normalize else None normalizer = EnglishTextNormalizer() if normalize else None
@ -1310,13 +1331,11 @@ def prepare_dataset( voice, use_segments, text_length, audio_length, normalize=T
lines['training' if not culled else 'validation'].append(line) lines['training' if not culled else 'validation'].append(line)
if culled or not VALLE_ENABLED: if culled or args.tts_backend != "vall-e":
continue continue
# VALL-E dataset
os.makedirs(f'{indir}/valle/', exist_ok=True) os.makedirs(f'{indir}/valle/', exist_ok=True)
try:
from vall_e.emb.qnt import encode as quantize from vall_e.emb.qnt import encode as quantize
from vall_e.emb.g2p import encode as phonemize from vall_e.emb.g2p import encode as phonemize
@ -1329,10 +1348,6 @@ def prepare_dataset( voice, use_segments, text_length, audio_length, normalize=T
phonemes = phonemize(normalized_text) phonemes = phonemize(normalized_text)
open(f'{indir}/valle/{file.replace(".wav",".phn.txt")}', 'w', encoding='utf-8').write(" ".join(phonemes)) open(f'{indir}/valle/{file.replace(".wav",".phn.txt")}', 'w', encoding='utf-8').write(" ".join(phonemes))
except Exception as e:
print(e)
pass
training_joined = "\n".join(lines['training']) training_joined = "\n".join(lines['training'])
validation_joined = "\n".join(lines['validation']) validation_joined = "\n".join(lines['validation'])
@ -1588,8 +1603,9 @@ def save_training_settings( **kwargs ):
with open(out, 'w', encoding="utf-8") as f: with open(out, 'w', encoding="utf-8") as f:
f.write(yaml) f.write(yaml)
if args.tts_backend == "tortoise":
use_template(f'./models/.template.dlas.yaml', f'./training/{settings["voice"]}/train.yaml') use_template(f'./models/.template.dlas.yaml', f'./training/{settings["voice"]}/train.yaml')
elif args.tts_backend == "vall-e":
settings['model_name'] = "ar" settings['model_name'] = "ar"
use_template(f'./models/.template.valle.yaml', f'./training/{settings["voice"]}/ar.yaml') use_template(f'./models/.template.valle.yaml', f'./training/{settings["voice"]}/ar.yaml')
settings['model_name'] = "nar" settings['model_name'] = "nar"
@ -1692,8 +1708,13 @@ def get_dataset_list(dir="./training/"):
return sorted([d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d)) and "train.txt" in os.listdir(os.path.join(dir, d)) ]) return sorted([d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d)) and "train.txt" in os.listdir(os.path.join(dir, d)) ])
def get_training_list(dir="./training/"): def get_training_list(dir="./training/"):
if args.tts_backend == "tortoise":
return sorted([f'./training/{d}/train.yaml' for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d)) and "train.yaml" in os.listdir(os.path.join(dir, d)) ]) return sorted([f'./training/{d}/train.yaml' for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d)) and "train.yaml" in os.listdir(os.path.join(dir, d)) ])
ars = sorted([f'./training/{d}/ar.yaml' for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d)) and "ar.yaml" in os.listdir(os.path.join(dir, d)) ])
nars = sorted([f'./training/{d}/nar.yaml' for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d)) and "nar.yaml" in os.listdir(os.path.join(dir, d)) ])
return ars + nars
def pad(num, zeroes): def pad(num, zeroes):
return str(num).zfill(zeroes+1) return str(num).zfill(zeroes+1)

View File

@ -1,5 +1,5 @@
call .\venv\Scripts\activate.bat call .\venv\Scripts\activate.bat
set PYTHONUTF8=1 set PYTHONUTF8=1
python ./src/train.py -opt "%1" python ./src/train.py --yaml "%1"
pause pause
deactivate deactivate

View File

@ -1,4 +1,4 @@
#!/bin/bash #!/bin/bash
source ./venv/bin/activate source ./venv/bin/activate
python3 ./src/train.py -opt "$1" python3 ./src/train.py --yaml "$1"
deactivate deactivate