From f822c87344d2ce01bd077b48d9db42740c51ea04 Mon Sep 17 00:00:00 2001
From: mrq
Date: Wed, 22 Mar 2023 17:47:23 +0000
Subject: [PATCH] cleanups, realigning vall-e training

---
 models/.template.valle.yaml |  10 +--
 src/utils.py                | 152 ++++++++++++++++--------------------
 2 files changed, 73 insertions(+), 89 deletions(-)

diff --git a/models/.template.valle.yaml b/models/.template.valle.yaml
index d3458f2..d6982d2 100755
--- a/models/.template.valle.yaml
+++ b/models/.template.valle.yaml
@@ -5,15 +5,13 @@ log_root: ./training/${voice}/finetune/logs/
 data_dirs: [./training/${voice}/valle/]
 spkr_name_getter: "lambda p: p.parts[-3]" # "lambda p: p.parts[-1].split('-')[0]"
 
-model: ${model_name}
+max_phones: 72
+
+models: '${models}'
 
 batch_size: ${batch_size}
 gradient_accumulation_steps: ${gradient_accumulation_size}
 eval_batch_size: ${batch_size}
 max_iter: ${iterations}
 save_ckpt_every: ${save_rate}
-eval_every: ${validation_rate}
-
-max_phones: 256
-
-sampling_temperature: 1.0
\ No newline at end of file
+eval_every: ${validation_rate}
\ No newline at end of file
diff --git a/src/utils.py b/src/utils.py
index bb4058e..b39aafa 100755
--- a/src/utils.py
+++ b/src/utils.py
@@ -642,7 +642,6 @@ class TrainingState():
 			self.yaml_config = yaml.safe_load(file)
 
 		self.json_config = json.load(open(f"{self.training_dir}/train.json", 'r', encoding="utf-8"))
-		self.dataset_dir = f"{self.training_dir}/finetune/"
 		self.dataset_path = f"{self.training_dir}/train.txt"
 		with open(self.dataset_path, 'r', encoding="utf-8") as f:
 			self.dataset_size = len(f.readlines())
@@ -690,9 +689,6 @@ class TrainingState():
 			'loss': "",
 		}
 
-		self.buffer_json = None
-		self.json_buffer = []
-
 		self.loss_milestones = [ 1.0, 0.15, 0.05 ]
 
 		if keep_x_past_checkpoints > 0:
@@ -704,18 +700,18 @@ class TrainingState():
 		if args.tts_backend == "vall-e":
 			self.cmd = ['deepspeed', f'--num_gpus={gpus}', '--module', 'vall_e.train', f'yaml="{config_path}"']
 		else:
-			self.cmd = ['train.bat', config_path] if os.name == "nt" else ['./train.sh', config_path]
+			self.cmd = [f'train.{"bat" if os.name == "nt" else "sh"}', config_path]
 
 		print("Spawning process: ", " ".join(self.cmd))
 		self.process = subprocess.Popen(self.cmd, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, universal_newlines=True)
 
 	def parse_metrics(self, data):
 		if isinstance(data, str):
-			if line.find('INFO: Training Metrics:') >= 0:
-				data = json.loads(line.split("INFO: Training Metrics:")[-1])
+			if data.find('Training Metrics:') >= 0:
+				data = json.loads(data.split("Training Metrics:")[-1])
 				data['mode'] = "training"
-			elif line.find('INFO: Validation Metrics:') >= 0:
-				data = json.loads(line.split("INFO: Validation Metrics:")[-1])
+			elif data.find('Validation Metrics:') >= 0:
+				data = json.loads(data.split("Validation Metrics:")[-1])
 				data['mode'] = "validation"
 			else:
 				return
@@ -755,22 +751,20 @@ class TrainingState():
 			self.metrics['step'] = ", ".join(self.metrics['step'])
 
 		epoch = self.epoch + (self.step / self.steps)
-		if 'lr' in self.info:
-			self.statistics['lr'].append({'epoch': epoch, 'it': self.it, 'value': self.info['lr'], 'type': 'learning_rate'})
-
-		if args.tts_backend == "tortoise":
-			for k in ['loss_text_ce', 'loss_mel_ce', 'loss_gpt_total']:
-				if k not in self.info:
-					continue
-				if k == "loss_gpt_total":
-					self.losses.append( self.statistics['loss'][-1] )
-				else:
-					self.statistics['loss'].append({'epoch': epoch, 'it': self.it, 'value': self.info[k], 'type': f'{"val_" if data["mode"] == "validation" else ""}{k}' })
-		else:
-			k = "loss"
+
+		for k in ['lr'] if args.tts_backend == "tortoise" else ['ar.lr', 'nar.lr', 'ar-half.lr', 'nar-half.lr', 'ar-quarter.lr', 'nar-quarter.lr']:
+			if k not in self.info:
+				continue
+
+			self.statistics['lr'].append({'epoch': epoch, 'it': self.it, 'value': self.info[k], 'type': k})
+
+		for k in ['loss_text_ce', 'loss_mel_ce'] if args.tts_backend == "tortoise" else ['ar.loss', 'nar.loss', 'ar-half.loss', 'nar-half.loss', 'ar-quarter.loss', 'nar-quarter.loss']:
+			if k not in self.info:
+				continue
 
 			self.statistics['loss'].append({'epoch': epoch, 'it': self.it, 'value': self.info[k], 'type': f'{"val_" if data["mode"] == "validation" else ""}{k}' })
-			self.losses.append( self.statistics['loss'][-1] )
+
+			self.losses.append( self.statistics['loss'][-1] )
 
 		return data
 
@@ -846,9 +840,17 @@ class TrainingState():
 		return message
 
 	def load_statistics(self, update=False):
-		if not os.path.isdir(f'{self.dataset_dir}/'):
+		if not os.path.isdir(self.training_dir):
 			return
 
+		if args.tts_backend == "tortoise":
+			logs = sorted([f'{self.training_dir}/finetune/{d}' for d in os.listdir(f'{self.training_dir}/finetune/') if d[-4:] == ".log" ])
+		else:
+			logs = sorted([f'{self.training_dir}/logs/{d}/log.txt' for d in os.listdir(f'{self.training_dir}/logs/') ])
+
+		if update:
+			logs = [logs[-1]]
+
 		infos = {}
 		highest_step = self.last_info_check_at
 
@@ -857,28 +859,28 @@ class TrainingState():
 		if update:
 			self.statistics['loss'] = []
 			self.statistics['lr'] = []
 			self.it_rates = 0
 
-		logs = sorted([f'{self.dataset_dir}/{d}' for d in os.listdir(self.dataset_dir) if d[-4:] == ".log" ])
-		if update:
-			logs = [logs[-1]]
-
 		for log in logs:
 			with open(log, 'r', encoding="utf-8") as f:
 				lines = f.readlines()
 
 			for line in lines:
-				if line.find('INFO: Training Metrics:') >= 0:
-					data = json.loads(line.split("INFO: Training Metrics:")[-1])
+				if line.find('Training Metrics:') >= 0:
+					data = json.loads(line.split("Training Metrics:")[-1])
 					data['mode'] = "training"
-				elif line.find('INFO: Validation Metrics:') >= 0:
-					data = json.loads(line.split("INFO: Validation Metrics:")[-1])
+				elif line.find('Validation Metrics:') >= 0:
+					data = json.loads(line.split("Validation Metrics:")[-1])
 					data['mode'] = "validation"
 				else:
 					continue
 
-				if "it" not in data:
-					continue
-
-				it = data['it']
+				if args.tts_backend == "tortoise":
+					if "it" not in data:
+						continue
+					it = data['it']
+				else:
+					if "global_step" not in data:
+						continue
+					it = data['global_step']
 
 				if update and it <= self.last_info_check_at:
 					continue
@@ -891,20 +893,23 @@
 		if keep <= 0:
 			return
 
-		if not os.path.isdir(self.dataset_dir):
+		if args.tts_backend == "vall-e":
+			return
+
+		if not os.path.isdir(f'{self.training_dir}/finetune/'):
 			return
 
-		models = sorted([ int(d[:-8]) for d in os.listdir(f'{self.dataset_dir}/models/') if d[-8:] == "_gpt.pth" ])
-		states = sorted([ int(d[:-6]) for d in os.listdir(f'{self.dataset_dir}/training_state/') if d[-6:] == ".state" ])
+		models = sorted([ int(d[:-8]) for d in os.listdir(f'{self.training_dir}/finetune/models/') if d[-8:] == "_gpt.pth" ])
+		states = sorted([ int(d[:-6]) for d in os.listdir(f'{self.training_dir}/finetune/training_state/') if d[-6:] == ".state" ])
 
 		remove_models = models[:-keep]
 		remove_states = states[:-keep]
 
 		for d in remove_models:
-			path = f'{self.dataset_dir}/models/{d}_gpt.pth'
+			path = f'{self.training_dir}/finetune/models/{d}_gpt.pth'
 			print("Removing", path)
 			os.remove(path)
 
 		for d in remove_states:
-			path = f'{self.dataset_dir}/training_state/{d}.state'
+			path = f'{self.training_dir}/finetune/training_state/{d}.state'
 			print("Removing", path)
 			os.remove(path)
@@ -930,34 +935,10 @@ class TrainingState():
 		MESSAGE_START = 'Start training from epoch'
 		MESSAGE_FINSIHED = 'Finished training'
-		MESSAGE_SAVING = 'INFO: Saving models and training states.'
+		MESSAGE_SAVING = 'Saving models and training states.'
 
-		MESSAGE_METRICS_TRAINING = 'INFO: Training Metrics:'
-		MESSAGE_METRICS_VALIDATION = 'INFO: Validation Metrics:'
-
-		if args.tts_backend == "vall-e":
-
-			if self.buffer_json:
-				self.json_buffer.append(line)
-
-			if line.find("{") == 0 and not self.buffer_json:
-				self.buffer_json = True
-				self.json_buffer = [line]
-			if line.find("}") == 0 and self.buffer_json:
-				try:
-					data = json.loads("\n".join(self.json_buffer))
-				except Exception as e:
-					print(str(e))
-
-				if data and 'model.loss' in data:
-					self.training_started = True
-					data = self.parse_valle_metrics( data )
-					print("Training JSON:", data)
-				else:
-					data = None
-
-				self.buffer_json = None
-				self.json_buffer = []
+		MESSAGE_METRICS_TRAINING = 'Training Metrics:'
+		MESSAGE_METRICS_VALIDATION = 'Validation Metrics:'
 
 		if line.find(MESSAGE_FINSIHED) >= 0:
 			self.killed = True
@@ -1469,6 +1450,13 @@ def prepare_dataset( voice, use_segments=False, text_length=0, audio_length=0, p
 			result = segments[file]
 			path = f'{indir}/audio/{file}'
 
+			if not os.path.exists(path):
+				message = f"Missing segment, skipping... {file}"
+				print(message)
+				messages.append(message)
+				errored += 1
+				continue
+
 			text = result['text']
 			lang = result['lang']
 			language = result['language']
@@ -1479,6 +1467,8 @@ def prepare_dataset( voice, use_segments=False, text_length=0, audio_length=0, p
 		if phonemize:
 			text = phonemes
 
+		normalized = normalizer(text) if normalize else text
+
 		if len(text) > 200:
 			message = f"Text length too long (200 < {len(text)}), skipping... {file}"
 			print(message)
@@ -1511,18 +1501,16 @@ def prepare_dataset( voice, use_segments=False, text_length=0, audio_length=0, p
 
 			os.makedirs(f'{indir}/valle/', exist_ok=True)
 
-			from vall_e.emb.qnt import encode as quantize
-			# from vall_e.emb.g2p import encode as phonemize
+			if not os.path.exists(f'{indir}/valle/{file.replace(".wav",".qnt.pt")}'):
+				from vall_e.emb.qnt import encode as quantize
+				quantized = quantize( waveform, sample_rate ).cpu()
+				torch.save(quantized, f'{indir}/valle/{file.replace(".wav",".qnt.pt")}')
+				print("Quantized:", file)
 
-			quantized = quantize( waveform, sample_rate ).cpu()
-			torch.save(quantized, f'{indir}/valle/{file.replace(".wav",".qnt.pt")}')
-			print("Quantized:", file)
-
-			tokens = tokenize_text(text, config="./models/tokenizers/ipa.json", stringed=False, skip_specials=True)
-			tokenized = " ".join( tokens )
-			tokenized = tokenized.replace(" \u02C8", "\u02C8")
-			tokenized = tokenized.replace(" \u02CC", "\u02CC")
-			open(f'{indir}/valle/{file.replace(".wav",".phn.txt")}', 'w', encoding='utf-8').write(tokenized)
+			if not os.path.exists(f'{indir}/valle/{file.replace(".wav",".phn.txt")}'):
+				from vall_e.emb.g2p import encode as phonemize
+				phonemized = phonemize( normalized )
+				open(f'{indir}/valle/{file.replace(".wav",".phn.txt")}', 'w', encoding='utf-8').write(" ".join(phonemized))
 
 	training_joined = "\n".join(lines['training'])
 	validation_joined = "\n".join(lines['validation'])
@@ -1786,10 +1774,8 @@ def save_training_settings( **kwargs ):
 	if args.tts_backend == "tortoise":
 		use_template(f'./models/.template.dlas.yaml', f'./training/{settings["voice"]}/train.yaml')
 	elif args.tts_backend == "vall-e":
-		settings['model_name'] = "ar"
-		use_template(f'./models/.template.valle.yaml', f'./training/{settings["voice"]}/ar.yaml')
-		settings['model_name'] = "nar"
-		use_template(f'./models/.template.valle.yaml', f'./training/{settings["voice"]}/nar.yaml')
+		settings['models'] = "[ 'ar-quarter', 'nar-quarter' ]"
+		use_template(f'./models/.template.valle.yaml', f'./training/{settings["voice"]}/config.yaml')
 
 	messages.append(f"Saved training output")
 	return settings, messages