cleanups, realigning vall-e training
This commit is contained in:
parent
909325bb5a
commit
f822c87344
|
@ -5,15 +5,13 @@ log_root: ./training/${voice}/finetune/logs/
|
|||
data_dirs: [./training/${voice}/valle/]
|
||||
spkr_name_getter: "lambda p: p.parts[-3]" # "lambda p: p.parts[-1].split('-')[0]"
|
||||
|
||||
model: ${model_name}
|
||||
max_phones: 72
|
||||
|
||||
models: '${models}'
|
||||
batch_size: ${batch_size}
|
||||
gradient_accumulation_steps: ${gradient_accumulation_size}
|
||||
eval_batch_size: ${batch_size}
|
||||
|
||||
max_iter: ${iterations}
|
||||
save_ckpt_every: ${save_rate}
|
||||
eval_every: ${validation_rate}
|
||||
|
||||
max_phones: 256
|
||||
|
||||
sampling_temperature: 1.0
|
||||
eval_every: ${validation_rate}
|
152
src/utils.py
152
src/utils.py
|
@ -642,7 +642,6 @@ class TrainingState():
|
|||
self.yaml_config = yaml.safe_load(file)
|
||||
|
||||
self.json_config = json.load(open(f"{self.training_dir}/train.json", 'r', encoding="utf-8"))
|
||||
self.dataset_dir = f"{self.training_dir}/finetune/"
|
||||
self.dataset_path = f"{self.training_dir}/train.txt"
|
||||
with open(self.dataset_path, 'r', encoding="utf-8") as f:
|
||||
self.dataset_size = len(f.readlines())
|
||||
|
@ -690,9 +689,6 @@ class TrainingState():
|
|||
'loss': "",
|
||||
}
|
||||
|
||||
self.buffer_json = None
|
||||
self.json_buffer = []
|
||||
|
||||
self.loss_milestones = [ 1.0, 0.15, 0.05 ]
|
||||
|
||||
if keep_x_past_checkpoints > 0:
|
||||
|
@ -704,18 +700,18 @@ class TrainingState():
|
|||
if args.tts_backend == "vall-e":
|
||||
self.cmd = ['deepspeed', f'--num_gpus={gpus}', '--module', 'vall_e.train', f'yaml="{config_path}"']
|
||||
else:
|
||||
self.cmd = ['train.bat', config_path] if os.name == "nt" else ['./train.sh', config_path]
|
||||
self.cmd = [f'train.{"bat" if os.name == "nt" else "sh"}', config_path]
|
||||
|
||||
print("Spawning process: ", " ".join(self.cmd))
|
||||
self.process = subprocess.Popen(self.cmd, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, universal_newlines=True)
|
||||
|
||||
def parse_metrics(self, data):
|
||||
if isinstance(data, str):
|
||||
if line.find('INFO: Training Metrics:') >= 0:
|
||||
data = json.loads(line.split("INFO: Training Metrics:")[-1])
|
||||
if line.find('Training Metrics:') >= 0:
|
||||
data = json.loads(line.split("Training Metrics:")[-1])
|
||||
data['mode'] = "training"
|
||||
elif line.find('INFO: Validation Metrics:') >= 0:
|
||||
data = json.loads(line.split("INFO: Validation Metrics:")[-1])
|
||||
elif line.find('Validation Metrics:') >= 0:
|
||||
data = json.loads(line.split("Validation Metrics:")[-1])
|
||||
data['mode'] = "validation"
|
||||
else:
|
||||
return
|
||||
|
@ -755,22 +751,20 @@ class TrainingState():
|
|||
self.metrics['step'] = ", ".join(self.metrics['step'])
|
||||
|
||||
epoch = self.epoch + (self.step / self.steps)
|
||||
if 'lr' in self.info:
|
||||
self.statistics['lr'].append({'epoch': epoch, 'it': self.it, 'value': self.info['lr'], 'type': 'learning_rate'})
|
||||
|
||||
if args.tts_backend == "tortoise":
|
||||
for k in ['loss_text_ce', 'loss_mel_ce', 'loss_gpt_total']:
|
||||
if k not in self.info:
|
||||
continue
|
||||
|
||||
if k == "loss_gpt_total":
|
||||
self.losses.append( self.statistics['loss'][-1] )
|
||||
else:
|
||||
self.statistics['loss'].append({'epoch': epoch, 'it': self.it, 'value': self.info[k], 'type': f'{"val_" if data["mode"] == "validation" else ""}{k}' })
|
||||
else:
|
||||
k = "loss"
|
||||
for k in ['lr'] if args.tts_backend == "tortoise" else ['ar.lr', 'nar.lr', 'aar-half.lr', 'nar-half.lr', 'ar-quarter.lr', 'nar-quarter.lr']:
|
||||
if k not in self.info:
|
||||
continue
|
||||
|
||||
self.statistics['lr'].append({'epoch': epoch, 'it': self.it, 'value': self.info[k], 'type': k})
|
||||
|
||||
for k in ['loss_text_ce', 'loss_mel_ce'] if args.tts_backend == "tortoise" else ['ar.loss', 'nar.loss', 'aar-half.loss', 'nar-half.loss', 'ar-quarter.loss', 'nar-quarter.loss']:
|
||||
if k not in self.info:
|
||||
continue
|
||||
|
||||
self.statistics['loss'].append({'epoch': epoch, 'it': self.it, 'value': self.info[k], 'type': f'{"val_" if data["mode"] == "validation" else ""}{k}' })
|
||||
self.losses.append( self.statistics['loss'][-1] )
|
||||
|
||||
self.losses.append( self.statistics['loss'][-1] )
|
||||
|
||||
return data
|
||||
|
||||
|
@ -846,9 +840,17 @@ class TrainingState():
|
|||
return message
|
||||
|
||||
def load_statistics(self, update=False):
|
||||
if not os.path.isdir(f'{self.dataset_dir}/'):
|
||||
if not os.path.isdir(self.training_dir):
|
||||
return
|
||||
|
||||
if args.tts_backend == "tortoise":
|
||||
logs = sorted([f'{self.training_dir}/finetune/{d}' for d in os.listdir(f'{self.training_dir}/finetune/') if d[-4:] == ".log" ])
|
||||
else:
|
||||
logs = sorted([f'{self.training_dir}/logs/{d}/log.txt' for d in os.listdir(f'{self.training_dir}/logs/') ])
|
||||
|
||||
if update:
|
||||
logs = [logs[-1]]
|
||||
|
||||
infos = {}
|
||||
highest_step = self.last_info_check_at
|
||||
|
||||
|
@ -857,28 +859,28 @@ class TrainingState():
|
|||
self.statistics['lr'] = []
|
||||
self.it_rates = 0
|
||||
|
||||
logs = sorted([f'{self.dataset_dir}/{d}' for d in os.listdir(self.dataset_dir) if d[-4:] == ".log" ])
|
||||
if update:
|
||||
logs = [logs[-1]]
|
||||
|
||||
for log in logs:
|
||||
with open(log, 'r', encoding="utf-8") as f:
|
||||
lines = f.readlines()
|
||||
|
||||
for line in lines:
|
||||
if line.find('INFO: Training Metrics:') >= 0:
|
||||
data = json.loads(line.split("INFO: Training Metrics:")[-1])
|
||||
if line.find('Training Metrics:') >= 0:
|
||||
data = json.loads(line.split("Training Metrics:")[-1])
|
||||
data['mode'] = "training"
|
||||
elif line.find('INFO: Validation Metrics:') >= 0:
|
||||
data = json.loads(line.split("INFO: Validation Metrics:")[-1])
|
||||
elif line.find('Validation Metrics:') >= 0:
|
||||
data = json.loads(line.split("Validation Metrics:")[-1])
|
||||
data['mode'] = "validation"
|
||||
else:
|
||||
continue
|
||||
|
||||
if "it" not in data:
|
||||
continue
|
||||
|
||||
it = data['it']
|
||||
if args.tts_backend == "tortoise":
|
||||
if "it" not in data:
|
||||
continue
|
||||
it = data['it']
|
||||
else:
|
||||
if "global_step" not in data:
|
||||
continue
|
||||
it = data['global_step']
|
||||
|
||||
if update and it <= self.last_info_check_at:
|
||||
continue
|
||||
|
@ -891,20 +893,23 @@ class TrainingState():
|
|||
if keep <= 0:
|
||||
return
|
||||
|
||||
if not os.path.isdir(self.dataset_dir):
|
||||
if args.tts_backend == "vall-e":
|
||||
return
|
||||
|
||||
if not os.path.isdir(f'{self.training_dir}/finetune/'):
|
||||
return
|
||||
|
||||
models = sorted([ int(d[:-8]) for d in os.listdir(f'{self.dataset_dir}/models/') if d[-8:] == "_gpt.pth" ])
|
||||
states = sorted([ int(d[:-6]) for d in os.listdir(f'{self.dataset_dir}/training_state/') if d[-6:] == ".state" ])
|
||||
models = sorted([ int(d[:-8]) for d in os.listdir(f'{self.training_dir}/finetune/models/') if d[-8:] == "_gpt.pth" ])
|
||||
states = sorted([ int(d[:-6]) for d in os.listdir(f'{self.training_dir}/finetune/training_state/') if d[-6:] == ".state" ])
|
||||
remove_models = models[:-keep]
|
||||
remove_states = states[:-keep]
|
||||
|
||||
for d in remove_models:
|
||||
path = f'{self.dataset_dir}/models/{d}_gpt.pth'
|
||||
path = f'{self.training_dir}/finetune/models/{d}_gpt.pth'
|
||||
print("Removing", path)
|
||||
os.remove(path)
|
||||
for d in remove_states:
|
||||
path = f'{self.dataset_dir}/training_state/{d}.state'
|
||||
path = f'{self.training_dir}/finetune/training_state/{d}.state'
|
||||
print("Removing", path)
|
||||
os.remove(path)
|
||||
|
||||
|
@ -930,34 +935,10 @@ class TrainingState():
|
|||
|
||||
MESSAGE_START = 'Start training from epoch'
|
||||
MESSAGE_FINSIHED = 'Finished training'
|
||||
MESSAGE_SAVING = 'INFO: Saving models and training states.'
|
||||
MESSAGE_SAVING = 'Saving models and training states.'
|
||||
|
||||
MESSAGE_METRICS_TRAINING = 'INFO: Training Metrics:'
|
||||
MESSAGE_METRICS_VALIDATION = 'INFO: Validation Metrics:'
|
||||
|
||||
if args.tts_backend == "vall-e":
|
||||
|
||||
if self.buffer_json:
|
||||
self.json_buffer.append(line)
|
||||
|
||||
if line.find("{") == 0 and not self.buffer_json:
|
||||
self.buffer_json = True
|
||||
self.json_buffer = [line]
|
||||
if line.find("}") == 0 and self.buffer_json:
|
||||
try:
|
||||
data = json.loads("\n".join(self.json_buffer))
|
||||
except Exception as e:
|
||||
print(str(e))
|
||||
|
||||
if data and 'model.loss' in data:
|
||||
self.training_started = True
|
||||
data = self.parse_valle_metrics( data )
|
||||
print("Training JSON:", data)
|
||||
else:
|
||||
data = None
|
||||
|
||||
self.buffer_json = None
|
||||
self.json_buffer = []
|
||||
MESSAGE_METRICS_TRAINING = 'Training Metrics:'
|
||||
MESSAGE_METRICS_VALIDATION = 'Validation Metrics:'
|
||||
|
||||
if line.find(MESSAGE_FINSIHED) >= 0:
|
||||
self.killed = True
|
||||
|
@ -1469,6 +1450,13 @@ def prepare_dataset( voice, use_segments=False, text_length=0, audio_length=0, p
|
|||
result = segments[file]
|
||||
path = f'{indir}/audio/{file}'
|
||||
|
||||
if not os.path.exists(path):
|
||||
message = f"Missing segment, skipping... {file}"
|
||||
print(message)
|
||||
messages.append(message)
|
||||
errored += 1
|
||||
continue
|
||||
|
||||
text = result['text']
|
||||
lang = result['lang']
|
||||
language = result['language']
|
||||
|
@ -1479,6 +1467,8 @@ def prepare_dataset( voice, use_segments=False, text_length=0, audio_length=0, p
|
|||
if phonemize:
|
||||
text = phonemes
|
||||
|
||||
normalized = normalizer(text) if normalize else text
|
||||
|
||||
if len(text) > 200:
|
||||
message = f"Text length too long (200 < {len(text)}), skipping... {file}"
|
||||
print(message)
|
||||
|
@ -1511,18 +1501,16 @@ def prepare_dataset( voice, use_segments=False, text_length=0, audio_length=0, p
|
|||
|
||||
os.makedirs(f'{indir}/valle/', exist_ok=True)
|
||||
|
||||
from vall_e.emb.qnt import encode as quantize
|
||||
# from vall_e.emb.g2p import encode as phonemize
|
||||
if not os.path.exists(f'{indir}/valle/{file.replace(".wav",".qnt.pt")}'):
|
||||
from vall_e.emb.qnt import encode as quantize
|
||||
quantized = quantize( waveform, sample_rate ).cpu()
|
||||
torch.save(quantized, f'{indir}/valle/{file.replace(".wav",".qnt.pt")}')
|
||||
print("Quantized:", file)
|
||||
|
||||
quantized = quantize( waveform, sample_rate ).cpu()
|
||||
torch.save(quantized, f'{indir}/valle/{file.replace(".wav",".qnt.pt")}')
|
||||
print("Quantized:", file)
|
||||
|
||||
tokens = tokenize_text(text, config="./models/tokenizers/ipa.json", stringed=False, skip_specials=True)
|
||||
tokenized = " ".join( tokens )
|
||||
tokenized = tokenized.replace(" \u02C8", "\u02C8")
|
||||
tokenized = tokenized.replace(" \u02CC", "\u02CC")
|
||||
open(f'{indir}/valle/{file.replace(".wav",".phn.txt")}', 'w', encoding='utf-8').write(tokenized)
|
||||
if not os.path.exists(f'{indir}/valle/{file.replace(".wav",".phn.txt")}'):
|
||||
from vall_e.emb.g2p import encode as phonemize
|
||||
phonemized = phonemize( normalized )
|
||||
open(f'{indir}/valle/{file.replace(".wav",".phn.txt")}', 'w', encoding='utf-8').write(" ".join(phonemized))
|
||||
|
||||
training_joined = "\n".join(lines['training'])
|
||||
validation_joined = "\n".join(lines['validation'])
|
||||
|
@ -1786,10 +1774,8 @@ def save_training_settings( **kwargs ):
|
|||
if args.tts_backend == "tortoise":
|
||||
use_template(f'./models/.template.dlas.yaml', f'./training/{settings["voice"]}/train.yaml')
|
||||
elif args.tts_backend == "vall-e":
|
||||
settings['model_name'] = "ar"
|
||||
use_template(f'./models/.template.valle.yaml', f'./training/{settings["voice"]}/ar.yaml')
|
||||
settings['model_name'] = "nar"
|
||||
use_template(f'./models/.template.valle.yaml', f'./training/{settings["voice"]}/nar.yaml')
|
||||
settings['model_name'] = "[ 'ar-quarter', 'nar-quarter' ]"
|
||||
use_template(f'./models/.template.valle.yaml', f'./training/{settings["voice"]}/config.yaml')
|
||||
|
||||
messages.append(f"Saved training output")
|
||||
return settings, messages
|
||||
|
|
Loading…
Reference in New Issue
Block a user