forked from mrq/ai-voice-cloning
;)
This commit is contained in:
parent
c85e32ff53
commit
7b16b3e88a
|
@ -8,6 +8,9 @@ spkr_name_getter: "lambda p: p.parts[-3]"
|
|||
model: ${model_name}
|
||||
batch_size: ${batch_size}
|
||||
eval_batch_size: ${validation_batch_size}
|
||||
|
||||
max_iter: ${iterations}
|
||||
save_ckpt_every: ${save_rate}
|
||||
eval_every: ${validation_rate}
|
||||
|
||||
sampling_temperature: 1.0
|
89
src/train.py
89
src/train.py
|
@ -2,68 +2,48 @@ import os
|
|||
import sys
|
||||
import argparse
|
||||
import yaml
|
||||
import datetime
|
||||
|
||||
"""
|
||||
if 'BITSANDBYTES_OVERRIDE_LINEAR' not in os.environ:
|
||||
os.environ['BITSANDBYTES_OVERRIDE_LINEAR'] = '0'
|
||||
if 'BITSANDBYTES_OVERRIDE_EMBEDDING' not in os.environ:
|
||||
os.environ['BITSANDBYTES_OVERRIDE_EMBEDDING'] = '1'
|
||||
if 'BITSANDBYTES_OVERRIDE_ADAM' not in os.environ:
|
||||
os.environ['BITSANDBYTES_OVERRIDE_ADAM'] = '1'
|
||||
if 'BITSANDBYTES_OVERRIDE_ADAMW' not in os.environ:
|
||||
os.environ['BITSANDBYTES_OVERRIDE_ADAMW'] = '1'
|
||||
"""
|
||||
from torch.distributed.run import main as torchrun
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_vit_latent.yml', nargs='+') # ugh
|
||||
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
|
||||
parser.add_argument('--mode', type=str, default='none', help='mode')
|
||||
args = parser.parse_args()
|
||||
args.opt = " ".join(args.opt) # absolutely disgusting
|
||||
# I don't want this invoked from an import
|
||||
if __name__ != "__main__":
|
||||
raise Exception("Do not invoke this from an import")
|
||||
|
||||
with open(args.opt, 'r') as file:
|
||||
opt_config = yaml.safe_load(file)
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--yaml', type=str, help='Path to training configuration file.', default='./training/voice/train.yml', nargs='+') # ugh
|
||||
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='Job launcher')
|
||||
args = parser.parse_args()
|
||||
args.yaml = " ".join(args.yaml) # absolutely disgusting
|
||||
config_path = args.yaml
|
||||
|
||||
if "ext" in opt_config and "bitsandbytes" in opt_config["ext"] and not opt_config["ext"]["bitsandbytes"]:
|
||||
os.environ['BITSANDBYTES_OVERRIDE_LINEAR'] = '0'
|
||||
os.environ['BITSANDBYTES_OVERRIDE_EMBEDDING'] = '0'
|
||||
os.environ['BITSANDBYTES_OVERRIDE_ADAM'] = '0'
|
||||
os.environ['BITSANDBYTES_OVERRIDE_ADAMW'] = '0'
|
||||
|
||||
# this is some massive kludge that only works if it's called from a shell and not an import/PIP package
|
||||
# it's smart-yet-irritating module-model loader breaks when trying to load something specifically when not from a shell
|
||||
with open(config_path, 'r') as file:
|
||||
opt_config = yaml.safe_load(file)
|
||||
|
||||
# it'd be downright sugoi if I was able to install DLAS as a pip package
|
||||
sys.path.insert(0, './modules/dlas/codes/')
|
||||
# this is also because DLAS is not written as a package in mind
|
||||
# it'll gripe when it wants to import from train.py
|
||||
sys.path.insert(0, './modules/dlas/')
|
||||
|
||||
# for PIP, replace it with:
|
||||
# sys.path.insert(0, os.path.dirname(os.path.realpath(dlas.__file__)))
|
||||
# sys.path.insert(0, f"{os.path.dirname(os.path.realpath(dlas.__file__))}/../")
|
||||
|
||||
# don't even really bother trying to get DLAS PIP'd
|
||||
# without kludge, it'll have to be accessible as `codes` and not `dlas`
|
||||
# yucky override
|
||||
if "bitsandbytes" in opt_config and not opt_config["bitsandbytes"]:
|
||||
os.environ['BITSANDBYTES_OVERRIDE_LINEAR'] = '0'
|
||||
os.environ['BITSANDBYTES_OVERRIDE_EMBEDDING'] = '0'
|
||||
os.environ['BITSANDBYTES_OVERRIDE_ADAM'] = '0'
|
||||
os.environ['BITSANDBYTES_OVERRIDE_ADAMW'] = '0'
|
||||
|
||||
import torch
|
||||
import datetime
|
||||
from codes import train as tr
|
||||
from utils import util, options as option
|
||||
from torch.distributed.run import main
|
||||
|
||||
# this is effectively just copy pasted and cleaned up from the __main__ section of training.py
|
||||
# I'll clean it up better
|
||||
|
||||
def train(yaml, launcher='none'):
|
||||
opt = option.parse(yaml, is_train=True)
|
||||
def train(config_path, launcher='none'):
|
||||
opt = option.parse(config_path, is_train=True)
|
||||
|
||||
if launcher == 'none' and opt['gpus'] > 1:
|
||||
return main([f"--nproc_per_node={opt['gpus']}", "--master_port=1234", "./src/train.py", "-opt", yaml, "--launcher=pytorch"])
|
||||
return torchrun([f"--nproc_per_node={opt['gpus']}", "./src/train.py", "--yaml", config_path, "--launcher=pytorch"])
|
||||
|
||||
trainer = tr.Trainer()
|
||||
#### distributed training settings
|
||||
if launcher == 'none': # disabled distributed training
|
||||
if launcher == 'none':
|
||||
opt['dist'] = False
|
||||
trainer.rank = -1
|
||||
if len(opt['gpu_ids']) == 1:
|
||||
|
@ -76,17 +56,16 @@ def train(yaml, launcher='none'):
|
|||
trainer.rank = torch.distributed.get_rank()
|
||||
torch.cuda.set_device(torch.distributed.get_rank())
|
||||
|
||||
trainer.init(yaml, opt, launcher, '')
|
||||
trainer.init(config_path, opt, launcher, '')
|
||||
trainer.do_training()
|
||||
|
||||
if __name__ == "__main__":
|
||||
try:
|
||||
import torch_intermediary
|
||||
if torch_intermediary.OVERRIDE_ADAM:
|
||||
print("Using BitsAndBytes optimizations")
|
||||
else:
|
||||
print("NOT using BitsAndBytes optimizations")
|
||||
except Exception as e:
|
||||
pass
|
||||
try:
|
||||
import torch_intermediary
|
||||
if torch_intermediary.OVERRIDE_ADAM:
|
||||
print("Using BitsAndBytes optimizations")
|
||||
else:
|
||||
print("NOT using BitsAndBytes optimizations")
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
train(args.opt, args.launcher)
|
||||
train(config_path, args.launcher)
|
101
src/utils.py
101
src/utils.py
|
@ -47,7 +47,7 @@ WHISPER_MODELS = ["tiny", "base", "small", "medium", "large"]
|
|||
WHISPER_SPECIALIZED_MODELS = ["tiny.en", "base.en", "small.en", "medium.en"]
|
||||
WHISPER_BACKENDS = ["openai/whisper", "lightmare/whispercpp"]
|
||||
VOCODERS = ['univnet', 'bigvgan_base_24khz_100band', 'bigvgan_24khz_100band']
|
||||
TTSES = ['tortoise'] # + ['vall-e']
|
||||
TTSES = ['tortoise']
|
||||
|
||||
GENERATE_SETTINGS_ARGS = None
|
||||
|
||||
|
@ -69,6 +69,9 @@ try:
|
|||
except Exception as e:
|
||||
pass
|
||||
|
||||
if VALLE_ENABLED:
|
||||
TTSES.append('vall-e')
|
||||
|
||||
args = None
|
||||
tts = None
|
||||
tts_loading = False
|
||||
|
@ -613,28 +616,41 @@ class TrainingState():
|
|||
with open(config_path, 'r') as file:
|
||||
self.config = yaml.safe_load(file)
|
||||
|
||||
gpus = self.config["gpus"]
|
||||
|
||||
self.killed = False
|
||||
|
||||
self.dataset_dir = f"./training/{self.config['name']}/finetune/"
|
||||
self.batch_size = self.config['datasets']['train']['batch_size']
|
||||
self.dataset_path = self.config['datasets']['train']['path']
|
||||
self.it = 0
|
||||
self.step = 0
|
||||
self.epoch = 0
|
||||
self.checkpoint = 0
|
||||
|
||||
if args.tts_backend == "tortoise":
|
||||
gpus = self.config["gpus"]
|
||||
|
||||
self.dataset_dir = f"./training/{self.config['name']}/finetune/"
|
||||
self.batch_size = self.config['datasets']['train']['batch_size']
|
||||
self.dataset_path = self.config['datasets']['train']['path']
|
||||
|
||||
self.its = self.config['train']['niter']
|
||||
self.steps = 1
|
||||
self.epochs = 1 # int(self.its*self.batch_size/self.dataset_size)
|
||||
self.checkpoints = int(self.its / self.config['logger']['save_checkpoint_freq'])
|
||||
elif args.tts_backend == "vall-e":
|
||||
self.batch_size = self.config['batch_size']
|
||||
self.dataset_dir = f".{self.config['data_root']}/finetune/"
|
||||
self.dataset_path = f"{self.config['data_root']}/train.txt"
|
||||
|
||||
self.its = 1
|
||||
self.steps = 1
|
||||
self.epochs = 1
|
||||
self.checkpoints = 1
|
||||
|
||||
self.json_config = json.load(open(f"{self.config['data_root']}/train.json", 'r', encoding="utf-8"))
|
||||
gpus = self.json_config['gpus']
|
||||
|
||||
with open(self.dataset_path, 'r', encoding="utf-8") as f:
|
||||
self.dataset_size = len(f.readlines())
|
||||
|
||||
self.it = 0
|
||||
self.its = self.config['train']['niter']
|
||||
|
||||
self.step = 0
|
||||
self.steps = 1
|
||||
|
||||
self.epoch = 0
|
||||
self.epochs = int(self.its*self.batch_size/self.dataset_size)
|
||||
|
||||
self.checkpoint = 0
|
||||
self.checkpoints = int(self.its / self.config['logger']['save_checkpoint_freq'])
|
||||
|
||||
self.buffer = []
|
||||
|
||||
self.open_state = False
|
||||
|
@ -672,7 +688,10 @@ class TrainingState():
|
|||
self.spawn_process(config_path=config_path, gpus=gpus)
|
||||
|
||||
def spawn_process(self, config_path, gpus=1):
|
||||
self.cmd = ['train.bat', config_path] if os.name == "nt" else ['./train.sh', config_path]
|
||||
if args.tts_backend == "vall-e":
|
||||
self.cmd = ['torchrun', '--nproc_per_node', f'{gpus}', '-m', 'vall_e.train', f'yaml="{config_path}"']
|
||||
else:
|
||||
self.cmd = ['train.bat', config_path] if os.name == "nt" else ['./train.sh', config_path]
|
||||
|
||||
print("Spawning process: ", " ".join(self.cmd))
|
||||
self.process = subprocess.Popen(self.cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, universal_newlines=True)
|
||||
|
@ -1221,6 +1240,8 @@ def prepare_dataset( voice, use_segments, text_length, audio_length, normalize=T
|
|||
lines = {
|
||||
'training': [],
|
||||
'validation': [],
|
||||
'recordings': [],
|
||||
'supervisions': [],
|
||||
}
|
||||
|
||||
normalizer = EnglishTextNormalizer() if normalize else None
|
||||
|
@ -1310,28 +1331,22 @@ def prepare_dataset( voice, use_segments, text_length, audio_length, normalize=T
|
|||
|
||||
lines['training' if not culled else 'validation'].append(line)
|
||||
|
||||
if culled or not VALLE_ENABLED:
|
||||
if culled or args.tts_backend != "vall-e":
|
||||
continue
|
||||
|
||||
# VALL-E dataset
|
||||
os.makedirs(f'{indir}/valle/', exist_ok=True)
|
||||
|
||||
try:
|
||||
from vall_e.emb.qnt import encode as quantize
|
||||
from vall_e.emb.g2p import encode as phonemize
|
||||
from vall_e.emb.qnt import encode as quantize
|
||||
from vall_e.emb.g2p import encode as phonemize
|
||||
|
||||
if waveform.shape[0] == 2:
|
||||
waveform = wav[:1]
|
||||
if waveform.shape[0] == 2:
|
||||
waveform = wav[:1]
|
||||
|
||||
quantized = quantize( waveform, sample_rate ).cpu()
|
||||
torch.save(quantized, f'{indir}/valle/{file.replace(".wav",".qnt.pt")}')
|
||||
quantized = quantize( waveform, sample_rate ).cpu()
|
||||
torch.save(quantized, f'{indir}/valle/{file.replace(".wav",".qnt.pt")}')
|
||||
|
||||
phonemes = phonemize(normalized_text)
|
||||
open(f'{indir}/valle/{file.replace(".wav",".phn.txt")}', 'w', encoding='utf-8').write(" ".join(phonemes))
|
||||
|
||||
except Exception as e:
|
||||
print(e)
|
||||
pass
|
||||
phonemes = phonemize(normalized_text)
|
||||
open(f'{indir}/valle/{file.replace(".wav",".phn.txt")}', 'w', encoding='utf-8').write(" ".join(phonemes))
|
||||
|
||||
training_joined = "\n".join(lines['training'])
|
||||
validation_joined = "\n".join(lines['validation'])
|
||||
|
@ -1588,12 +1603,13 @@ def save_training_settings( **kwargs ):
|
|||
with open(out, 'w', encoding="utf-8") as f:
|
||||
f.write(yaml)
|
||||
|
||||
use_template(f'./models/.template.dlas.yaml', f'./training/{settings["voice"]}/train.yaml')
|
||||
|
||||
settings['model_name'] = "ar"
|
||||
use_template(f'./models/.template.valle.yaml', f'./training/{settings["voice"]}/ar.yaml')
|
||||
settings['model_name'] = "nar"
|
||||
use_template(f'./models/.template.valle.yaml', f'./training/{settings["voice"]}/nar.yaml')
|
||||
if args.tts_backend == "tortoise":
|
||||
use_template(f'./models/.template.dlas.yaml', f'./training/{settings["voice"]}/train.yaml')
|
||||
elif args.tts_backend == "vall-e":
|
||||
settings['model_name'] = "ar"
|
||||
use_template(f'./models/.template.valle.yaml', f'./training/{settings["voice"]}/ar.yaml')
|
||||
settings['model_name'] = "nar"
|
||||
use_template(f'./models/.template.valle.yaml', f'./training/{settings["voice"]}/nar.yaml')
|
||||
|
||||
messages.append(f"Saved training output")
|
||||
return settings, messages
|
||||
|
@ -1692,7 +1708,12 @@ def get_dataset_list(dir="./training/"):
|
|||
return sorted([d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d)) and "train.txt" in os.listdir(os.path.join(dir, d)) ])
|
||||
|
||||
def get_training_list(dir="./training/"):
|
||||
return sorted([f'./training/{d}/train.yaml' for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d)) and "train.yaml" in os.listdir(os.path.join(dir, d)) ])
|
||||
if args.tts_backend == "tortoise":
|
||||
return sorted([f'./training/{d}/train.yaml' for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d)) and "train.yaml" in os.listdir(os.path.join(dir, d)) ])
|
||||
|
||||
ars = sorted([f'./training/{d}/ar.yaml' for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d)) and "ar.yaml" in os.listdir(os.path.join(dir, d)) ])
|
||||
nars = sorted([f'./training/{d}/nar.yaml' for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d)) and "nar.yaml" in os.listdir(os.path.join(dir, d)) ])
|
||||
return ars + nars
|
||||
|
||||
def pad(num, zeroes):
|
||||
return str(num).zfill(zeroes+1)
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
call .\venv\Scripts\activate.bat
|
||||
set PYTHONUTF8=1
|
||||
python ./src/train.py -opt "%1"
|
||||
python ./src/train.py --yaml "%1"
|
||||
pause
|
||||
deactivate
|
Loading…
Reference in New Issue
Block a user