cleanups, realigning vall-e training

This commit is contained in:
mrq 2023-03-22 17:47:23 +00:00
parent 909325bb5a
commit f822c87344
2 changed files with 73 additions and 89 deletions

View File

@ -5,7 +5,9 @@ log_root: ./training/${voice}/finetune/logs/
data_dirs: [./training/${voice}/valle/]
spkr_name_getter: "lambda p: p.parts[-3]" # "lambda p: p.parts[-1].split('-')[0]"
model: ${model_name}
max_phones: 72
models: '${models}'
batch_size: ${batch_size}
gradient_accumulation_steps: ${gradient_accumulation_size}
eval_batch_size: ${batch_size}
@ -13,7 +15,3 @@ eval_batch_size: ${batch_size}
max_iter: ${iterations}
save_ckpt_every: ${save_rate}
eval_every: ${validation_rate}
max_phones: 256
sampling_temperature: 1.0

View File

@ -642,7 +642,6 @@ class TrainingState():
self.yaml_config = yaml.safe_load(file)
self.json_config = json.load(open(f"{self.training_dir}/train.json", 'r', encoding="utf-8"))
self.dataset_dir = f"{self.training_dir}/finetune/"
self.dataset_path = f"{self.training_dir}/train.txt"
with open(self.dataset_path, 'r', encoding="utf-8") as f:
self.dataset_size = len(f.readlines())
@ -690,9 +689,6 @@ class TrainingState():
'loss': "",
}
self.buffer_json = None
self.json_buffer = []
self.loss_milestones = [ 1.0, 0.15, 0.05 ]
if keep_x_past_checkpoints > 0:
@ -704,18 +700,18 @@ class TrainingState():
if args.tts_backend == "vall-e":
self.cmd = ['deepspeed', f'--num_gpus={gpus}', '--module', 'vall_e.train', f'yaml="{config_path}"']
else:
self.cmd = ['train.bat', config_path] if os.name == "nt" else ['./train.sh', config_path]
self.cmd = [f'train.{"bat" if os.name == "nt" else "sh"}', config_path]
print("Spawning process: ", " ".join(self.cmd))
self.process = subprocess.Popen(self.cmd, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, universal_newlines=True)
def parse_metrics(self, data):
if isinstance(data, str):
if line.find('INFO: Training Metrics:') >= 0:
data = json.loads(line.split("INFO: Training Metrics:")[-1])
if line.find('Training Metrics:') >= 0:
data = json.loads(line.split("Training Metrics:")[-1])
data['mode'] = "training"
elif line.find('INFO: Validation Metrics:') >= 0:
data = json.loads(line.split("INFO: Validation Metrics:")[-1])
elif line.find('Validation Metrics:') >= 0:
data = json.loads(line.split("Validation Metrics:")[-1])
data['mode'] = "validation"
else:
return
@ -755,21 +751,19 @@ class TrainingState():
self.metrics['step'] = ", ".join(self.metrics['step'])
epoch = self.epoch + (self.step / self.steps)
if 'lr' in self.info:
self.statistics['lr'].append({'epoch': epoch, 'it': self.it, 'value': self.info['lr'], 'type': 'learning_rate'})
if args.tts_backend == "tortoise":
for k in ['loss_text_ce', 'loss_mel_ce', 'loss_gpt_total']:
for k in ['lr'] if args.tts_backend == "tortoise" else ['ar.lr', 'nar.lr', 'aar-half.lr', 'nar-half.lr', 'ar-quarter.lr', 'nar-quarter.lr']:
if k not in self.info:
continue
self.statistics['lr'].append({'epoch': epoch, 'it': self.it, 'value': self.info[k], 'type': k})
for k in ['loss_text_ce', 'loss_mel_ce'] if args.tts_backend == "tortoise" else ['ar.loss', 'nar.loss', 'aar-half.loss', 'nar-half.loss', 'ar-quarter.loss', 'nar-quarter.loss']:
if k not in self.info:
continue
if k == "loss_gpt_total":
self.losses.append( self.statistics['loss'][-1] )
else:
self.statistics['loss'].append({'epoch': epoch, 'it': self.it, 'value': self.info[k], 'type': f'{"val_" if data["mode"] == "validation" else ""}{k}' })
else:
k = "loss"
self.statistics['loss'].append({'epoch': epoch, 'it': self.it, 'value': self.info[k], 'type': f'{"val_" if data["mode"] == "validation" else ""}{k}' })
self.losses.append( self.statistics['loss'][-1] )
return data
@ -846,9 +840,17 @@ class TrainingState():
return message
def load_statistics(self, update=False):
if not os.path.isdir(f'{self.dataset_dir}/'):
if not os.path.isdir(self.training_dir):
return
if args.tts_backend == "tortoise":
logs = sorted([f'{self.training_dir}/finetune/{d}' for d in os.listdir(f'{self.training_dir}/finetune/') if d[-4:] == ".log" ])
else:
logs = sorted([f'{self.training_dir}/logs/{d}/log.txt' for d in os.listdir(f'{self.training_dir}/logs/') ])
if update:
logs = [logs[-1]]
infos = {}
highest_step = self.last_info_check_at
@ -857,28 +859,28 @@ class TrainingState():
self.statistics['lr'] = []
self.it_rates = 0
logs = sorted([f'{self.dataset_dir}/{d}' for d in os.listdir(self.dataset_dir) if d[-4:] == ".log" ])
if update:
logs = [logs[-1]]
for log in logs:
with open(log, 'r', encoding="utf-8") as f:
lines = f.readlines()
for line in lines:
if line.find('INFO: Training Metrics:') >= 0:
data = json.loads(line.split("INFO: Training Metrics:")[-1])
if line.find('Training Metrics:') >= 0:
data = json.loads(line.split("Training Metrics:")[-1])
data['mode'] = "training"
elif line.find('INFO: Validation Metrics:') >= 0:
data = json.loads(line.split("INFO: Validation Metrics:")[-1])
elif line.find('Validation Metrics:') >= 0:
data = json.loads(line.split("Validation Metrics:")[-1])
data['mode'] = "validation"
else:
continue
if args.tts_backend == "tortoise":
if "it" not in data:
continue
it = data['it']
else:
if "global_step" not in data:
continue
it = data['global_step']
if update and it <= self.last_info_check_at:
continue
@ -891,20 +893,23 @@ class TrainingState():
if keep <= 0:
return
if not os.path.isdir(self.dataset_dir):
if args.tts_backend == "vall-e":
return
models = sorted([ int(d[:-8]) for d in os.listdir(f'{self.dataset_dir}/models/') if d[-8:] == "_gpt.pth" ])
states = sorted([ int(d[:-6]) for d in os.listdir(f'{self.dataset_dir}/training_state/') if d[-6:] == ".state" ])
if not os.path.isdir(f'{self.training_dir}/finetune/'):
return
models = sorted([ int(d[:-8]) for d in os.listdir(f'{self.training_dir}/finetune/models/') if d[-8:] == "_gpt.pth" ])
states = sorted([ int(d[:-6]) for d in os.listdir(f'{self.training_dir}/finetune/training_state/') if d[-6:] == ".state" ])
remove_models = models[:-keep]
remove_states = states[:-keep]
for d in remove_models:
path = f'{self.dataset_dir}/models/{d}_gpt.pth'
path = f'{self.training_dir}/finetune/models/{d}_gpt.pth'
print("Removing", path)
os.remove(path)
for d in remove_states:
path = f'{self.dataset_dir}/training_state/{d}.state'
path = f'{self.training_dir}/finetune/training_state/{d}.state'
print("Removing", path)
os.remove(path)
@ -930,34 +935,10 @@ class TrainingState():
MESSAGE_START = 'Start training from epoch'
MESSAGE_FINSIHED = 'Finished training'
MESSAGE_SAVING = 'INFO: Saving models and training states.'
MESSAGE_SAVING = 'Saving models and training states.'
MESSAGE_METRICS_TRAINING = 'INFO: Training Metrics:'
MESSAGE_METRICS_VALIDATION = 'INFO: Validation Metrics:'
if args.tts_backend == "vall-e":
if self.buffer_json:
self.json_buffer.append(line)
if line.find("{") == 0 and not self.buffer_json:
self.buffer_json = True
self.json_buffer = [line]
if line.find("}") == 0 and self.buffer_json:
try:
data = json.loads("\n".join(self.json_buffer))
except Exception as e:
print(str(e))
if data and 'model.loss' in data:
self.training_started = True
data = self.parse_valle_metrics( data )
print("Training JSON:", data)
else:
data = None
self.buffer_json = None
self.json_buffer = []
MESSAGE_METRICS_TRAINING = 'Training Metrics:'
MESSAGE_METRICS_VALIDATION = 'Validation Metrics:'
if line.find(MESSAGE_FINSIHED) >= 0:
self.killed = True
@ -1469,6 +1450,13 @@ def prepare_dataset( voice, use_segments=False, text_length=0, audio_length=0, p
result = segments[file]
path = f'{indir}/audio/{file}'
if not os.path.exists(path):
message = f"Missing segment, skipping... {file}"
print(message)
messages.append(message)
errored += 1
continue
text = result['text']
lang = result['lang']
language = result['language']
@ -1479,6 +1467,8 @@ def prepare_dataset( voice, use_segments=False, text_length=0, audio_length=0, p
if phonemize:
text = phonemes
normalized = normalizer(text) if normalize else text
if len(text) > 200:
message = f"Text length too long (200 < {len(text)}), skipping... {file}"
print(message)
@ -1511,18 +1501,16 @@ def prepare_dataset( voice, use_segments=False, text_length=0, audio_length=0, p
os.makedirs(f'{indir}/valle/', exist_ok=True)
if not os.path.exists(f'{indir}/valle/{file.replace(".wav",".qnt.pt")}'):
from vall_e.emb.qnt import encode as quantize
# from vall_e.emb.g2p import encode as phonemize
quantized = quantize( waveform, sample_rate ).cpu()
torch.save(quantized, f'{indir}/valle/{file.replace(".wav",".qnt.pt")}')
print("Quantized:", file)
tokens = tokenize_text(text, config="./models/tokenizers/ipa.json", stringed=False, skip_specials=True)
tokenized = " ".join( tokens )
tokenized = tokenized.replace(" \u02C8", "\u02C8")
tokenized = tokenized.replace(" \u02CC", "\u02CC")
open(f'{indir}/valle/{file.replace(".wav",".phn.txt")}', 'w', encoding='utf-8').write(tokenized)
if not os.path.exists(f'{indir}/valle/{file.replace(".wav",".phn.txt")}'):
from vall_e.emb.g2p import encode as phonemize
phonemized = phonemize( normalized )
open(f'{indir}/valle/{file.replace(".wav",".phn.txt")}', 'w', encoding='utf-8').write(" ".join(phonemized))
training_joined = "\n".join(lines['training'])
validation_joined = "\n".join(lines['validation'])
@ -1786,10 +1774,8 @@ def save_training_settings( **kwargs ):
if args.tts_backend == "tortoise":
use_template(f'./models/.template.dlas.yaml', f'./training/{settings["voice"]}/train.yaml')
elif args.tts_backend == "vall-e":
settings['model_name'] = "ar"
use_template(f'./models/.template.valle.yaml', f'./training/{settings["voice"]}/ar.yaml')
settings['model_name'] = "nar"
use_template(f'./models/.template.valle.yaml', f'./training/{settings["voice"]}/nar.yaml')
settings['model_name'] = "[ 'ar-quarter', 'nar-quarter' ]"
use_template(f'./models/.template.valle.yaml', f'./training/{settings["voice"]}/config.yaml')
messages.append(f"Saved training output")
return settings, messages