cleanups, realigning vall-e training
This commit is contained in:
parent
909325bb5a
commit
f822c87344
|
@ -5,15 +5,13 @@ log_root: ./training/${voice}/finetune/logs/
|
||||||
data_dirs: [./training/${voice}/valle/]
|
data_dirs: [./training/${voice}/valle/]
|
||||||
spkr_name_getter: "lambda p: p.parts[-3]" # "lambda p: p.parts[-1].split('-')[0]"
|
spkr_name_getter: "lambda p: p.parts[-3]" # "lambda p: p.parts[-1].split('-')[0]"
|
||||||
|
|
||||||
model: ${model_name}
|
max_phones: 72
|
||||||
|
|
||||||
|
models: '${models}'
|
||||||
batch_size: ${batch_size}
|
batch_size: ${batch_size}
|
||||||
gradient_accumulation_steps: ${gradient_accumulation_size}
|
gradient_accumulation_steps: ${gradient_accumulation_size}
|
||||||
eval_batch_size: ${batch_size}
|
eval_batch_size: ${batch_size}
|
||||||
|
|
||||||
max_iter: ${iterations}
|
max_iter: ${iterations}
|
||||||
save_ckpt_every: ${save_rate}
|
save_ckpt_every: ${save_rate}
|
||||||
eval_every: ${validation_rate}
|
eval_every: ${validation_rate}
|
||||||
|
|
||||||
max_phones: 256
|
|
||||||
|
|
||||||
sampling_temperature: 1.0
|
|
152
src/utils.py
152
src/utils.py
|
@ -642,7 +642,6 @@ class TrainingState():
|
||||||
self.yaml_config = yaml.safe_load(file)
|
self.yaml_config = yaml.safe_load(file)
|
||||||
|
|
||||||
self.json_config = json.load(open(f"{self.training_dir}/train.json", 'r', encoding="utf-8"))
|
self.json_config = json.load(open(f"{self.training_dir}/train.json", 'r', encoding="utf-8"))
|
||||||
self.dataset_dir = f"{self.training_dir}/finetune/"
|
|
||||||
self.dataset_path = f"{self.training_dir}/train.txt"
|
self.dataset_path = f"{self.training_dir}/train.txt"
|
||||||
with open(self.dataset_path, 'r', encoding="utf-8") as f:
|
with open(self.dataset_path, 'r', encoding="utf-8") as f:
|
||||||
self.dataset_size = len(f.readlines())
|
self.dataset_size = len(f.readlines())
|
||||||
|
@ -690,9 +689,6 @@ class TrainingState():
|
||||||
'loss': "",
|
'loss': "",
|
||||||
}
|
}
|
||||||
|
|
||||||
self.buffer_json = None
|
|
||||||
self.json_buffer = []
|
|
||||||
|
|
||||||
self.loss_milestones = [ 1.0, 0.15, 0.05 ]
|
self.loss_milestones = [ 1.0, 0.15, 0.05 ]
|
||||||
|
|
||||||
if keep_x_past_checkpoints > 0:
|
if keep_x_past_checkpoints > 0:
|
||||||
|
@ -704,18 +700,18 @@ class TrainingState():
|
||||||
if args.tts_backend == "vall-e":
|
if args.tts_backend == "vall-e":
|
||||||
self.cmd = ['deepspeed', f'--num_gpus={gpus}', '--module', 'vall_e.train', f'yaml="{config_path}"']
|
self.cmd = ['deepspeed', f'--num_gpus={gpus}', '--module', 'vall_e.train', f'yaml="{config_path}"']
|
||||||
else:
|
else:
|
||||||
self.cmd = ['train.bat', config_path] if os.name == "nt" else ['./train.sh', config_path]
|
self.cmd = [f'train.{"bat" if os.name == "nt" else "sh"}', config_path]
|
||||||
|
|
||||||
print("Spawning process: ", " ".join(self.cmd))
|
print("Spawning process: ", " ".join(self.cmd))
|
||||||
self.process = subprocess.Popen(self.cmd, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, universal_newlines=True)
|
self.process = subprocess.Popen(self.cmd, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, universal_newlines=True)
|
||||||
|
|
||||||
def parse_metrics(self, data):
|
def parse_metrics(self, data):
|
||||||
if isinstance(data, str):
|
if isinstance(data, str):
|
||||||
if line.find('INFO: Training Metrics:') >= 0:
|
if line.find('Training Metrics:') >= 0:
|
||||||
data = json.loads(line.split("INFO: Training Metrics:")[-1])
|
data = json.loads(line.split("Training Metrics:")[-1])
|
||||||
data['mode'] = "training"
|
data['mode'] = "training"
|
||||||
elif line.find('INFO: Validation Metrics:') >= 0:
|
elif line.find('Validation Metrics:') >= 0:
|
||||||
data = json.loads(line.split("INFO: Validation Metrics:")[-1])
|
data = json.loads(line.split("Validation Metrics:")[-1])
|
||||||
data['mode'] = "validation"
|
data['mode'] = "validation"
|
||||||
else:
|
else:
|
||||||
return
|
return
|
||||||
|
@ -755,22 +751,20 @@ class TrainingState():
|
||||||
self.metrics['step'] = ", ".join(self.metrics['step'])
|
self.metrics['step'] = ", ".join(self.metrics['step'])
|
||||||
|
|
||||||
epoch = self.epoch + (self.step / self.steps)
|
epoch = self.epoch + (self.step / self.steps)
|
||||||
if 'lr' in self.info:
|
|
||||||
self.statistics['lr'].append({'epoch': epoch, 'it': self.it, 'value': self.info['lr'], 'type': 'learning_rate'})
|
|
||||||
|
|
||||||
if args.tts_backend == "tortoise":
|
|
||||||
for k in ['loss_text_ce', 'loss_mel_ce', 'loss_gpt_total']:
|
|
||||||
if k not in self.info:
|
|
||||||
continue
|
|
||||||
|
|
||||||
if k == "loss_gpt_total":
|
for k in ['lr'] if args.tts_backend == "tortoise" else ['ar.lr', 'nar.lr', 'aar-half.lr', 'nar-half.lr', 'ar-quarter.lr', 'nar-quarter.lr']:
|
||||||
self.losses.append( self.statistics['loss'][-1] )
|
if k not in self.info:
|
||||||
else:
|
continue
|
||||||
self.statistics['loss'].append({'epoch': epoch, 'it': self.it, 'value': self.info[k], 'type': f'{"val_" if data["mode"] == "validation" else ""}{k}' })
|
|
||||||
else:
|
self.statistics['lr'].append({'epoch': epoch, 'it': self.it, 'value': self.info[k], 'type': k})
|
||||||
k = "loss"
|
|
||||||
|
for k in ['loss_text_ce', 'loss_mel_ce'] if args.tts_backend == "tortoise" else ['ar.loss', 'nar.loss', 'aar-half.loss', 'nar-half.loss', 'ar-quarter.loss', 'nar-quarter.loss']:
|
||||||
|
if k not in self.info:
|
||||||
|
continue
|
||||||
|
|
||||||
self.statistics['loss'].append({'epoch': epoch, 'it': self.it, 'value': self.info[k], 'type': f'{"val_" if data["mode"] == "validation" else ""}{k}' })
|
self.statistics['loss'].append({'epoch': epoch, 'it': self.it, 'value': self.info[k], 'type': f'{"val_" if data["mode"] == "validation" else ""}{k}' })
|
||||||
self.losses.append( self.statistics['loss'][-1] )
|
|
||||||
|
self.losses.append( self.statistics['loss'][-1] )
|
||||||
|
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
@ -846,9 +840,17 @@ class TrainingState():
|
||||||
return message
|
return message
|
||||||
|
|
||||||
def load_statistics(self, update=False):
|
def load_statistics(self, update=False):
|
||||||
if not os.path.isdir(f'{self.dataset_dir}/'):
|
if not os.path.isdir(self.training_dir):
|
||||||
return
|
return
|
||||||
|
|
||||||
|
if args.tts_backend == "tortoise":
|
||||||
|
logs = sorted([f'{self.training_dir}/finetune/{d}' for d in os.listdir(f'{self.training_dir}/finetune/') if d[-4:] == ".log" ])
|
||||||
|
else:
|
||||||
|
logs = sorted([f'{self.training_dir}/logs/{d}/log.txt' for d in os.listdir(f'{self.training_dir}/logs/') ])
|
||||||
|
|
||||||
|
if update:
|
||||||
|
logs = [logs[-1]]
|
||||||
|
|
||||||
infos = {}
|
infos = {}
|
||||||
highest_step = self.last_info_check_at
|
highest_step = self.last_info_check_at
|
||||||
|
|
||||||
|
@ -857,28 +859,28 @@ class TrainingState():
|
||||||
self.statistics['lr'] = []
|
self.statistics['lr'] = []
|
||||||
self.it_rates = 0
|
self.it_rates = 0
|
||||||
|
|
||||||
logs = sorted([f'{self.dataset_dir}/{d}' for d in os.listdir(self.dataset_dir) if d[-4:] == ".log" ])
|
|
||||||
if update:
|
|
||||||
logs = [logs[-1]]
|
|
||||||
|
|
||||||
for log in logs:
|
for log in logs:
|
||||||
with open(log, 'r', encoding="utf-8") as f:
|
with open(log, 'r', encoding="utf-8") as f:
|
||||||
lines = f.readlines()
|
lines = f.readlines()
|
||||||
|
|
||||||
for line in lines:
|
for line in lines:
|
||||||
if line.find('INFO: Training Metrics:') >= 0:
|
if line.find('Training Metrics:') >= 0:
|
||||||
data = json.loads(line.split("INFO: Training Metrics:")[-1])
|
data = json.loads(line.split("Training Metrics:")[-1])
|
||||||
data['mode'] = "training"
|
data['mode'] = "training"
|
||||||
elif line.find('INFO: Validation Metrics:') >= 0:
|
elif line.find('Validation Metrics:') >= 0:
|
||||||
data = json.loads(line.split("INFO: Validation Metrics:")[-1])
|
data = json.loads(line.split("Validation Metrics:")[-1])
|
||||||
data['mode'] = "validation"
|
data['mode'] = "validation"
|
||||||
else:
|
else:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if "it" not in data:
|
if args.tts_backend == "tortoise":
|
||||||
continue
|
if "it" not in data:
|
||||||
|
continue
|
||||||
it = data['it']
|
it = data['it']
|
||||||
|
else:
|
||||||
|
if "global_step" not in data:
|
||||||
|
continue
|
||||||
|
it = data['global_step']
|
||||||
|
|
||||||
if update and it <= self.last_info_check_at:
|
if update and it <= self.last_info_check_at:
|
||||||
continue
|
continue
|
||||||
|
@ -891,20 +893,23 @@ class TrainingState():
|
||||||
if keep <= 0:
|
if keep <= 0:
|
||||||
return
|
return
|
||||||
|
|
||||||
if not os.path.isdir(self.dataset_dir):
|
if args.tts_backend == "vall-e":
|
||||||
|
return
|
||||||
|
|
||||||
|
if not os.path.isdir(f'{self.training_dir}/finetune/'):
|
||||||
return
|
return
|
||||||
|
|
||||||
models = sorted([ int(d[:-8]) for d in os.listdir(f'{self.dataset_dir}/models/') if d[-8:] == "_gpt.pth" ])
|
models = sorted([ int(d[:-8]) for d in os.listdir(f'{self.training_dir}/finetune/models/') if d[-8:] == "_gpt.pth" ])
|
||||||
states = sorted([ int(d[:-6]) for d in os.listdir(f'{self.dataset_dir}/training_state/') if d[-6:] == ".state" ])
|
states = sorted([ int(d[:-6]) for d in os.listdir(f'{self.training_dir}/finetune/training_state/') if d[-6:] == ".state" ])
|
||||||
remove_models = models[:-keep]
|
remove_models = models[:-keep]
|
||||||
remove_states = states[:-keep]
|
remove_states = states[:-keep]
|
||||||
|
|
||||||
for d in remove_models:
|
for d in remove_models:
|
||||||
path = f'{self.dataset_dir}/models/{d}_gpt.pth'
|
path = f'{self.training_dir}/finetune/models/{d}_gpt.pth'
|
||||||
print("Removing", path)
|
print("Removing", path)
|
||||||
os.remove(path)
|
os.remove(path)
|
||||||
for d in remove_states:
|
for d in remove_states:
|
||||||
path = f'{self.dataset_dir}/training_state/{d}.state'
|
path = f'{self.training_dir}/finetune/training_state/{d}.state'
|
||||||
print("Removing", path)
|
print("Removing", path)
|
||||||
os.remove(path)
|
os.remove(path)
|
||||||
|
|
||||||
|
@ -930,34 +935,10 @@ class TrainingState():
|
||||||
|
|
||||||
MESSAGE_START = 'Start training from epoch'
|
MESSAGE_START = 'Start training from epoch'
|
||||||
MESSAGE_FINSIHED = 'Finished training'
|
MESSAGE_FINSIHED = 'Finished training'
|
||||||
MESSAGE_SAVING = 'INFO: Saving models and training states.'
|
MESSAGE_SAVING = 'Saving models and training states.'
|
||||||
|
|
||||||
MESSAGE_METRICS_TRAINING = 'INFO: Training Metrics:'
|
MESSAGE_METRICS_TRAINING = 'Training Metrics:'
|
||||||
MESSAGE_METRICS_VALIDATION = 'INFO: Validation Metrics:'
|
MESSAGE_METRICS_VALIDATION = 'Validation Metrics:'
|
||||||
|
|
||||||
if args.tts_backend == "vall-e":
|
|
||||||
|
|
||||||
if self.buffer_json:
|
|
||||||
self.json_buffer.append(line)
|
|
||||||
|
|
||||||
if line.find("{") == 0 and not self.buffer_json:
|
|
||||||
self.buffer_json = True
|
|
||||||
self.json_buffer = [line]
|
|
||||||
if line.find("}") == 0 and self.buffer_json:
|
|
||||||
try:
|
|
||||||
data = json.loads("\n".join(self.json_buffer))
|
|
||||||
except Exception as e:
|
|
||||||
print(str(e))
|
|
||||||
|
|
||||||
if data and 'model.loss' in data:
|
|
||||||
self.training_started = True
|
|
||||||
data = self.parse_valle_metrics( data )
|
|
||||||
print("Training JSON:", data)
|
|
||||||
else:
|
|
||||||
data = None
|
|
||||||
|
|
||||||
self.buffer_json = None
|
|
||||||
self.json_buffer = []
|
|
||||||
|
|
||||||
if line.find(MESSAGE_FINSIHED) >= 0:
|
if line.find(MESSAGE_FINSIHED) >= 0:
|
||||||
self.killed = True
|
self.killed = True
|
||||||
|
@ -1469,6 +1450,13 @@ def prepare_dataset( voice, use_segments=False, text_length=0, audio_length=0, p
|
||||||
result = segments[file]
|
result = segments[file]
|
||||||
path = f'{indir}/audio/{file}'
|
path = f'{indir}/audio/{file}'
|
||||||
|
|
||||||
|
if not os.path.exists(path):
|
||||||
|
message = f"Missing segment, skipping... {file}"
|
||||||
|
print(message)
|
||||||
|
messages.append(message)
|
||||||
|
errored += 1
|
||||||
|
continue
|
||||||
|
|
||||||
text = result['text']
|
text = result['text']
|
||||||
lang = result['lang']
|
lang = result['lang']
|
||||||
language = result['language']
|
language = result['language']
|
||||||
|
@ -1479,6 +1467,8 @@ def prepare_dataset( voice, use_segments=False, text_length=0, audio_length=0, p
|
||||||
if phonemize:
|
if phonemize:
|
||||||
text = phonemes
|
text = phonemes
|
||||||
|
|
||||||
|
normalized = normalizer(text) if normalize else text
|
||||||
|
|
||||||
if len(text) > 200:
|
if len(text) > 200:
|
||||||
message = f"Text length too long (200 < {len(text)}), skipping... {file}"
|
message = f"Text length too long (200 < {len(text)}), skipping... {file}"
|
||||||
print(message)
|
print(message)
|
||||||
|
@ -1511,18 +1501,16 @@ def prepare_dataset( voice, use_segments=False, text_length=0, audio_length=0, p
|
||||||
|
|
||||||
os.makedirs(f'{indir}/valle/', exist_ok=True)
|
os.makedirs(f'{indir}/valle/', exist_ok=True)
|
||||||
|
|
||||||
from vall_e.emb.qnt import encode as quantize
|
if not os.path.exists(f'{indir}/valle/{file.replace(".wav",".qnt.pt")}'):
|
||||||
# from vall_e.emb.g2p import encode as phonemize
|
from vall_e.emb.qnt import encode as quantize
|
||||||
|
quantized = quantize( waveform, sample_rate ).cpu()
|
||||||
|
torch.save(quantized, f'{indir}/valle/{file.replace(".wav",".qnt.pt")}')
|
||||||
|
print("Quantized:", file)
|
||||||
|
|
||||||
quantized = quantize( waveform, sample_rate ).cpu()
|
if not os.path.exists(f'{indir}/valle/{file.replace(".wav",".phn.txt")}'):
|
||||||
torch.save(quantized, f'{indir}/valle/{file.replace(".wav",".qnt.pt")}')
|
from vall_e.emb.g2p import encode as phonemize
|
||||||
print("Quantized:", file)
|
phonemized = phonemize( normalized )
|
||||||
|
open(f'{indir}/valle/{file.replace(".wav",".phn.txt")}', 'w', encoding='utf-8').write(" ".join(phonemized))
|
||||||
tokens = tokenize_text(text, config="./models/tokenizers/ipa.json", stringed=False, skip_specials=True)
|
|
||||||
tokenized = " ".join( tokens )
|
|
||||||
tokenized = tokenized.replace(" \u02C8", "\u02C8")
|
|
||||||
tokenized = tokenized.replace(" \u02CC", "\u02CC")
|
|
||||||
open(f'{indir}/valle/{file.replace(".wav",".phn.txt")}', 'w', encoding='utf-8').write(tokenized)
|
|
||||||
|
|
||||||
training_joined = "\n".join(lines['training'])
|
training_joined = "\n".join(lines['training'])
|
||||||
validation_joined = "\n".join(lines['validation'])
|
validation_joined = "\n".join(lines['validation'])
|
||||||
|
@ -1786,10 +1774,8 @@ def save_training_settings( **kwargs ):
|
||||||
if args.tts_backend == "tortoise":
|
if args.tts_backend == "tortoise":
|
||||||
use_template(f'./models/.template.dlas.yaml', f'./training/{settings["voice"]}/train.yaml')
|
use_template(f'./models/.template.dlas.yaml', f'./training/{settings["voice"]}/train.yaml')
|
||||||
elif args.tts_backend == "vall-e":
|
elif args.tts_backend == "vall-e":
|
||||||
settings['model_name'] = "ar"
|
settings['model_name'] = "[ 'ar-quarter', 'nar-quarter' ]"
|
||||||
use_template(f'./models/.template.valle.yaml', f'./training/{settings["voice"]}/ar.yaml')
|
use_template(f'./models/.template.valle.yaml', f'./training/{settings["voice"]}/config.yaml')
|
||||||
settings['model_name'] = "nar"
|
|
||||||
use_template(f'./models/.template.valle.yaml', f'./training/{settings["voice"]}/nar.yaml')
|
|
||||||
|
|
||||||
messages.append(f"Saved training output")
|
messages.append(f"Saved training output")
|
||||||
return settings, messages
|
return settings, messages
|
||||||
|
|
Loading…
Reference in New Issue
Block a user