forked from mrq/ai-voice-cloning
;)
This commit is contained in:
parent
c85e32ff53
commit
7b16b3e88a
|
@ -8,6 +8,9 @@ spkr_name_getter: "lambda p: p.parts[-3]"
|
||||||
model: ${model_name}
|
model: ${model_name}
|
||||||
batch_size: ${batch_size}
|
batch_size: ${batch_size}
|
||||||
eval_batch_size: ${validation_batch_size}
|
eval_batch_size: ${validation_batch_size}
|
||||||
|
|
||||||
|
max_iter: ${iterations}
|
||||||
|
save_ckpt_every: ${save_rate}
|
||||||
eval_every: ${validation_rate}
|
eval_every: ${validation_rate}
|
||||||
|
|
||||||
sampling_temperature: 1.0
|
sampling_temperature: 1.0
|
67
src/train.py
67
src/train.py
|
@ -2,68 +2,48 @@ import os
|
||||||
import sys
|
import sys
|
||||||
import argparse
|
import argparse
|
||||||
import yaml
|
import yaml
|
||||||
|
import datetime
|
||||||
|
|
||||||
"""
|
from torch.distributed.run import main as torchrun
|
||||||
if 'BITSANDBYTES_OVERRIDE_LINEAR' not in os.environ:
|
|
||||||
os.environ['BITSANDBYTES_OVERRIDE_LINEAR'] = '0'
|
# I don't want this invoked from an import
|
||||||
if 'BITSANDBYTES_OVERRIDE_EMBEDDING' not in os.environ:
|
if __name__ != "__main__":
|
||||||
os.environ['BITSANDBYTES_OVERRIDE_EMBEDDING'] = '1'
|
raise Exception("Do not invoke this from an import")
|
||||||
if 'BITSANDBYTES_OVERRIDE_ADAM' not in os.environ:
|
|
||||||
os.environ['BITSANDBYTES_OVERRIDE_ADAM'] = '1'
|
|
||||||
if 'BITSANDBYTES_OVERRIDE_ADAMW' not in os.environ:
|
|
||||||
os.environ['BITSANDBYTES_OVERRIDE_ADAMW'] = '1'
|
|
||||||
"""
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_vit_latent.yml', nargs='+') # ugh
|
parser.add_argument('--yaml', type=str, help='Path to training configuration file.', default='./training/voice/train.yml', nargs='+') # ugh
|
||||||
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
|
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='Job launcher')
|
||||||
parser.add_argument('--mode', type=str, default='none', help='mode')
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
args.opt = " ".join(args.opt) # absolutely disgusting
|
args.yaml = " ".join(args.yaml) # absolutely disgusting
|
||||||
|
config_path = args.yaml
|
||||||
|
|
||||||
with open(args.opt, 'r') as file:
|
with open(config_path, 'r') as file:
|
||||||
opt_config = yaml.safe_load(file)
|
opt_config = yaml.safe_load(file)
|
||||||
|
|
||||||
if "ext" in opt_config and "bitsandbytes" in opt_config["ext"] and not opt_config["ext"]["bitsandbytes"]:
|
# it'd be downright sugoi if I was able to install DLAS as a pip package
|
||||||
|
sys.path.insert(0, './modules/dlas/codes/')
|
||||||
|
sys.path.insert(0, './modules/dlas/')
|
||||||
|
|
||||||
|
# yucky override
|
||||||
|
if "bitsandbytes" in opt_config and not opt_config["bitsandbytes"]:
|
||||||
os.environ['BITSANDBYTES_OVERRIDE_LINEAR'] = '0'
|
os.environ['BITSANDBYTES_OVERRIDE_LINEAR'] = '0'
|
||||||
os.environ['BITSANDBYTES_OVERRIDE_EMBEDDING'] = '0'
|
os.environ['BITSANDBYTES_OVERRIDE_EMBEDDING'] = '0'
|
||||||
os.environ['BITSANDBYTES_OVERRIDE_ADAM'] = '0'
|
os.environ['BITSANDBYTES_OVERRIDE_ADAM'] = '0'
|
||||||
os.environ['BITSANDBYTES_OVERRIDE_ADAMW'] = '0'
|
os.environ['BITSANDBYTES_OVERRIDE_ADAMW'] = '0'
|
||||||
|
|
||||||
# this is some massive kludge that only works if it's called from a shell and not an import/PIP package
|
|
||||||
# it's smart-yet-irritating module-model loader breaks when trying to load something specifically when not from a shell
|
|
||||||
|
|
||||||
sys.path.insert(0, './modules/dlas/codes/')
|
|
||||||
# this is also because DLAS is not written as a package in mind
|
|
||||||
# it'll gripe when it wants to import from train.py
|
|
||||||
sys.path.insert(0, './modules/dlas/')
|
|
||||||
|
|
||||||
# for PIP, replace it with:
|
|
||||||
# sys.path.insert(0, os.path.dirname(os.path.realpath(dlas.__file__)))
|
|
||||||
# sys.path.insert(0, f"{os.path.dirname(os.path.realpath(dlas.__file__))}/../")
|
|
||||||
|
|
||||||
# don't even really bother trying to get DLAS PIP'd
|
|
||||||
# without kludge, it'll have to be accessible as `codes` and not `dlas`
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import datetime
|
|
||||||
from codes import train as tr
|
from codes import train as tr
|
||||||
from utils import util, options as option
|
from utils import util, options as option
|
||||||
from torch.distributed.run import main
|
|
||||||
|
|
||||||
# this is effectively just copy pasted and cleaned up from the __main__ section of training.py
|
# this is effectively just copy pasted and cleaned up from the __main__ section of training.py
|
||||||
# I'll clean it up better
|
def train(config_path, launcher='none'):
|
||||||
|
opt = option.parse(config_path, is_train=True)
|
||||||
def train(yaml, launcher='none'):
|
|
||||||
opt = option.parse(yaml, is_train=True)
|
|
||||||
|
|
||||||
if launcher == 'none' and opt['gpus'] > 1:
|
if launcher == 'none' and opt['gpus'] > 1:
|
||||||
return main([f"--nproc_per_node={opt['gpus']}", "--master_port=1234", "./src/train.py", "-opt", yaml, "--launcher=pytorch"])
|
return torchrun([f"--nproc_per_node={opt['gpus']}", "./src/train.py", "--yaml", config_path, "--launcher=pytorch"])
|
||||||
|
|
||||||
trainer = tr.Trainer()
|
trainer = tr.Trainer()
|
||||||
#### distributed training settings
|
if launcher == 'none':
|
||||||
if launcher == 'none': # disabled distributed training
|
|
||||||
opt['dist'] = False
|
opt['dist'] = False
|
||||||
trainer.rank = -1
|
trainer.rank = -1
|
||||||
if len(opt['gpu_ids']) == 1:
|
if len(opt['gpu_ids']) == 1:
|
||||||
|
@ -76,10 +56,9 @@ def train(yaml, launcher='none'):
|
||||||
trainer.rank = torch.distributed.get_rank()
|
trainer.rank = torch.distributed.get_rank()
|
||||||
torch.cuda.set_device(torch.distributed.get_rank())
|
torch.cuda.set_device(torch.distributed.get_rank())
|
||||||
|
|
||||||
trainer.init(yaml, opt, launcher, '')
|
trainer.init(config_path, opt, launcher, '')
|
||||||
trainer.do_training()
|
trainer.do_training()
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
try:
|
try:
|
||||||
import torch_intermediary
|
import torch_intermediary
|
||||||
if torch_intermediary.OVERRIDE_ADAM:
|
if torch_intermediary.OVERRIDE_ADAM:
|
||||||
|
@ -89,4 +68,4 @@ if __name__ == "__main__":
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
train(args.opt, args.launcher)
|
train(config_path, args.launcher)
|
65
src/utils.py
65
src/utils.py
|
@ -47,7 +47,7 @@ WHISPER_MODELS = ["tiny", "base", "small", "medium", "large"]
|
||||||
WHISPER_SPECIALIZED_MODELS = ["tiny.en", "base.en", "small.en", "medium.en"]
|
WHISPER_SPECIALIZED_MODELS = ["tiny.en", "base.en", "small.en", "medium.en"]
|
||||||
WHISPER_BACKENDS = ["openai/whisper", "lightmare/whispercpp"]
|
WHISPER_BACKENDS = ["openai/whisper", "lightmare/whispercpp"]
|
||||||
VOCODERS = ['univnet', 'bigvgan_base_24khz_100band', 'bigvgan_24khz_100band']
|
VOCODERS = ['univnet', 'bigvgan_base_24khz_100band', 'bigvgan_24khz_100band']
|
||||||
TTSES = ['tortoise'] # + ['vall-e']
|
TTSES = ['tortoise']
|
||||||
|
|
||||||
GENERATE_SETTINGS_ARGS = None
|
GENERATE_SETTINGS_ARGS = None
|
||||||
|
|
||||||
|
@ -69,6 +69,9 @@ try:
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
if VALLE_ENABLED:
|
||||||
|
TTSES.append('vall-e')
|
||||||
|
|
||||||
args = None
|
args = None
|
||||||
tts = None
|
tts = None
|
||||||
tts_loading = False
|
tts_loading = False
|
||||||
|
@ -613,28 +616,41 @@ class TrainingState():
|
||||||
with open(config_path, 'r') as file:
|
with open(config_path, 'r') as file:
|
||||||
self.config = yaml.safe_load(file)
|
self.config = yaml.safe_load(file)
|
||||||
|
|
||||||
gpus = self.config["gpus"]
|
|
||||||
|
|
||||||
self.killed = False
|
self.killed = False
|
||||||
|
|
||||||
|
self.it = 0
|
||||||
|
self.step = 0
|
||||||
|
self.epoch = 0
|
||||||
|
self.checkpoint = 0
|
||||||
|
|
||||||
|
if args.tts_backend == "tortoise":
|
||||||
|
gpus = self.config["gpus"]
|
||||||
|
|
||||||
self.dataset_dir = f"./training/{self.config['name']}/finetune/"
|
self.dataset_dir = f"./training/{self.config['name']}/finetune/"
|
||||||
self.batch_size = self.config['datasets']['train']['batch_size']
|
self.batch_size = self.config['datasets']['train']['batch_size']
|
||||||
self.dataset_path = self.config['datasets']['train']['path']
|
self.dataset_path = self.config['datasets']['train']['path']
|
||||||
|
|
||||||
|
self.its = self.config['train']['niter']
|
||||||
|
self.steps = 1
|
||||||
|
self.epochs = 1 # int(self.its*self.batch_size/self.dataset_size)
|
||||||
|
self.checkpoints = int(self.its / self.config['logger']['save_checkpoint_freq'])
|
||||||
|
elif args.tts_backend == "vall-e":
|
||||||
|
self.batch_size = self.config['batch_size']
|
||||||
|
self.dataset_dir = f".{self.config['data_root']}/finetune/"
|
||||||
|
self.dataset_path = f"{self.config['data_root']}/train.txt"
|
||||||
|
|
||||||
|
self.its = 1
|
||||||
|
self.steps = 1
|
||||||
|
self.epochs = 1
|
||||||
|
self.checkpoints = 1
|
||||||
|
|
||||||
|
self.json_config = json.load(open(f"{self.config['data_root']}/train.json", 'r', encoding="utf-8"))
|
||||||
|
gpus = self.json_config['gpus']
|
||||||
|
|
||||||
with open(self.dataset_path, 'r', encoding="utf-8") as f:
|
with open(self.dataset_path, 'r', encoding="utf-8") as f:
|
||||||
self.dataset_size = len(f.readlines())
|
self.dataset_size = len(f.readlines())
|
||||||
|
|
||||||
self.it = 0
|
|
||||||
self.its = self.config['train']['niter']
|
|
||||||
|
|
||||||
self.step = 0
|
|
||||||
self.steps = 1
|
|
||||||
|
|
||||||
self.epoch = 0
|
|
||||||
self.epochs = int(self.its*self.batch_size/self.dataset_size)
|
|
||||||
|
|
||||||
self.checkpoint = 0
|
|
||||||
self.checkpoints = int(self.its / self.config['logger']['save_checkpoint_freq'])
|
|
||||||
|
|
||||||
self.buffer = []
|
self.buffer = []
|
||||||
|
|
||||||
self.open_state = False
|
self.open_state = False
|
||||||
|
@ -672,6 +688,9 @@ class TrainingState():
|
||||||
self.spawn_process(config_path=config_path, gpus=gpus)
|
self.spawn_process(config_path=config_path, gpus=gpus)
|
||||||
|
|
||||||
def spawn_process(self, config_path, gpus=1):
|
def spawn_process(self, config_path, gpus=1):
|
||||||
|
if args.tts_backend == "vall-e":
|
||||||
|
self.cmd = ['torchrun', '--nproc_per_node', f'{gpus}', '-m', 'vall_e.train', f'yaml="{config_path}"']
|
||||||
|
else:
|
||||||
self.cmd = ['train.bat', config_path] if os.name == "nt" else ['./train.sh', config_path]
|
self.cmd = ['train.bat', config_path] if os.name == "nt" else ['./train.sh', config_path]
|
||||||
|
|
||||||
print("Spawning process: ", " ".join(self.cmd))
|
print("Spawning process: ", " ".join(self.cmd))
|
||||||
|
@ -1221,6 +1240,8 @@ def prepare_dataset( voice, use_segments, text_length, audio_length, normalize=T
|
||||||
lines = {
|
lines = {
|
||||||
'training': [],
|
'training': [],
|
||||||
'validation': [],
|
'validation': [],
|
||||||
|
'recordings': [],
|
||||||
|
'supervisions': [],
|
||||||
}
|
}
|
||||||
|
|
||||||
normalizer = EnglishTextNormalizer() if normalize else None
|
normalizer = EnglishTextNormalizer() if normalize else None
|
||||||
|
@ -1310,13 +1331,11 @@ def prepare_dataset( voice, use_segments, text_length, audio_length, normalize=T
|
||||||
|
|
||||||
lines['training' if not culled else 'validation'].append(line)
|
lines['training' if not culled else 'validation'].append(line)
|
||||||
|
|
||||||
if culled or not VALLE_ENABLED:
|
if culled or args.tts_backend != "vall-e":
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# VALL-E dataset
|
|
||||||
os.makedirs(f'{indir}/valle/', exist_ok=True)
|
os.makedirs(f'{indir}/valle/', exist_ok=True)
|
||||||
|
|
||||||
try:
|
|
||||||
from vall_e.emb.qnt import encode as quantize
|
from vall_e.emb.qnt import encode as quantize
|
||||||
from vall_e.emb.g2p import encode as phonemize
|
from vall_e.emb.g2p import encode as phonemize
|
||||||
|
|
||||||
|
@ -1329,10 +1348,6 @@ def prepare_dataset( voice, use_segments, text_length, audio_length, normalize=T
|
||||||
phonemes = phonemize(normalized_text)
|
phonemes = phonemize(normalized_text)
|
||||||
open(f'{indir}/valle/{file.replace(".wav",".phn.txt")}', 'w', encoding='utf-8').write(" ".join(phonemes))
|
open(f'{indir}/valle/{file.replace(".wav",".phn.txt")}', 'w', encoding='utf-8').write(" ".join(phonemes))
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(e)
|
|
||||||
pass
|
|
||||||
|
|
||||||
training_joined = "\n".join(lines['training'])
|
training_joined = "\n".join(lines['training'])
|
||||||
validation_joined = "\n".join(lines['validation'])
|
validation_joined = "\n".join(lines['validation'])
|
||||||
|
|
||||||
|
@ -1588,8 +1603,9 @@ def save_training_settings( **kwargs ):
|
||||||
with open(out, 'w', encoding="utf-8") as f:
|
with open(out, 'w', encoding="utf-8") as f:
|
||||||
f.write(yaml)
|
f.write(yaml)
|
||||||
|
|
||||||
|
if args.tts_backend == "tortoise":
|
||||||
use_template(f'./models/.template.dlas.yaml', f'./training/{settings["voice"]}/train.yaml')
|
use_template(f'./models/.template.dlas.yaml', f'./training/{settings["voice"]}/train.yaml')
|
||||||
|
elif args.tts_backend == "vall-e":
|
||||||
settings['model_name'] = "ar"
|
settings['model_name'] = "ar"
|
||||||
use_template(f'./models/.template.valle.yaml', f'./training/{settings["voice"]}/ar.yaml')
|
use_template(f'./models/.template.valle.yaml', f'./training/{settings["voice"]}/ar.yaml')
|
||||||
settings['model_name'] = "nar"
|
settings['model_name'] = "nar"
|
||||||
|
@ -1692,8 +1708,13 @@ def get_dataset_list(dir="./training/"):
|
||||||
return sorted([d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d)) and "train.txt" in os.listdir(os.path.join(dir, d)) ])
|
return sorted([d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d)) and "train.txt" in os.listdir(os.path.join(dir, d)) ])
|
||||||
|
|
||||||
def get_training_list(dir="./training/"):
|
def get_training_list(dir="./training/"):
|
||||||
|
if args.tts_backend == "tortoise":
|
||||||
return sorted([f'./training/{d}/train.yaml' for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d)) and "train.yaml" in os.listdir(os.path.join(dir, d)) ])
|
return sorted([f'./training/{d}/train.yaml' for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d)) and "train.yaml" in os.listdir(os.path.join(dir, d)) ])
|
||||||
|
|
||||||
|
ars = sorted([f'./training/{d}/ar.yaml' for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d)) and "ar.yaml" in os.listdir(os.path.join(dir, d)) ])
|
||||||
|
nars = sorted([f'./training/{d}/nar.yaml' for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d)) and "nar.yaml" in os.listdir(os.path.join(dir, d)) ])
|
||||||
|
return ars + nars
|
||||||
|
|
||||||
def pad(num, zeroes):
|
def pad(num, zeroes):
|
||||||
return str(num).zfill(zeroes+1)
|
return str(num).zfill(zeroes+1)
|
||||||
|
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
call .\venv\Scripts\activate.bat
|
call .\venv\Scripts\activate.bat
|
||||||
set PYTHONUTF8=1
|
set PYTHONUTF8=1
|
||||||
python ./src/train.py -opt "%1"
|
python ./src/train.py --yaml "%1"
|
||||||
pause
|
pause
|
||||||
deactivate
|
deactivate
|
Loading…
Reference in New Issue
Block a user