From 4585824cd3c4f766788cbf1fe79df46515029f17 Mon Sep 17 00:00:00 2001
From: mrq
Date: Wed, 23 Aug 2023 16:43:03 -0500
Subject: [PATCH] tweaks, including exporting on save/quit

---
 vall_e/config.py        |  4 ++++
 vall_e/data.py          |  6 +++---
 vall_e/emb/qnt.py       | 12 ++++++++++++
 vall_e/inference.py     | 16 ++++++++++++----
 vall_e/utils/trainer.py | 11 +++++++++++
 5 files changed, 42 insertions(+), 7 deletions(-)

diff --git a/vall_e/config.py b/vall_e/config.py
index da95ce6..4a48d55 100755
--- a/vall_e/config.py
+++ b/vall_e/config.py
@@ -372,6 +372,10 @@ class Trainer:
 
 	save_on_oom: bool = True
 	save_on_quit: bool = True
+
+	export_on_save: bool = False
+	export_on_quit: bool = False
+
 	save_frequency: int = 100
 
 	keep_last_checkpoints: int = 0
diff --git a/vall_e/data.py b/vall_e/data.py
index 20498a5..84684c5 100755
--- a/vall_e/data.py
+++ b/vall_e/data.py
@@ -10,7 +10,7 @@ import random
 import torch
 
 from .config import cfg
-from .emb.qnt import trim_random, repeat_extend_audio, merge_audio, decode_to_file
+from .emb.qnt import trim, trim_random, repeat_extend_audio, merge_audio, decode_to_file
 
 from collections import defaultdict
 from functools import cache, cached_property
@@ -253,7 +253,7 @@ class Dataset(_Dataset):
 				qnt = _load_quants(path)
 
 				if cfg.dataset.prompt_duration > 0 and trim_length < qnt.shape[0]:
-					qnt = trim_random( qnt, trim_length )
+					qnt = trim( qnt, trim_length )
 
 				prom_list.append(qnt)
 				prom_length += qnt.shape[0]
@@ -264,7 +264,7 @@ class Dataset(_Dataset):
 
 		prom = torch.cat(prom_list)
 
 		if cfg.dataset.prompt_duration > 0 and trim_length < prom.shape[0]:
-			prom = trim_random( prom, trim_length )
+			prom = trim( prom, trim_length )
 
 		return prom
diff --git a/vall_e/emb/qnt.py b/vall_e/emb/qnt.py
index 7985b85..14f0fe8 100755
--- a/vall_e/emb/qnt.py
+++ b/vall_e/emb/qnt.py
@@ -180,7 +180,19 @@ def encode_from_file(path, device="cuda"):
 
 # Helper Functions
 
+# trims from the start, up to `target`
+def trim( qnt, target ):
+	length = max( qnt.shape[0], qnt.shape[1] )
+	start = 0
+	end = start + target
+	if end >= length:
+		start = length - target
+		end = length
+
+	return qnt[start:end] if qnt.shape[0] > qnt.shape[1] else qnt[:, start:end]
+
 # trims a random piece of audio, up to `target`
+# to-do: try and align to EnCodec window
 def trim_random( qnt, target ):
 	length = max( qnt.shape[0], qnt.shape[1] )
 	start = int(length * random.random())
diff --git a/vall_e/inference.py b/vall_e/inference.py
index 3b5d423..dc754cc 100755
--- a/vall_e/inference.py
+++ b/vall_e/inference.py
@@ -7,7 +7,7 @@ from einops import rearrange
 from pathlib import Path
 
 from .emb import g2p, qnt
-from .emb.qnt import trim_random
+from .emb.qnt import trim, trim_random
 from .utils import to_device
 
 from .config import cfg
@@ -25,12 +25,14 @@ class TTS():
 
 		if config:
 			cfg.load_yaml( config )
+			cfg.dataset.use_hdf5 = False # could use cfg.load_hdf5(), but why would it ever need to be loaded for inferencing
 
 		try:
 			cfg.format()
 		except Exception as e:
 			pass
 
+		self.symmap = None
 		if ar_ckpt and nar_ckpt:
 			self.ar_ckpt = ar_ckpt
 			self.nar_ckpt = nar_ckpt
@@ -40,6 +42,8 @@ class TTS():
 			if name.startswith("ar"):
 				self.ar = model
 				state = torch.load(self.ar_ckpt)
+				if "symmap" in state:
+					self.symmap = state['symmap']
 				if "module" in state:
 					state = state['module']
 				self.ar.load_state_dict(state)
@@ -47,6 +51,8 @@ class TTS():
 			elif name.startswith("nar"):
 				self.nar = model
 				state = torch.load(self.nar_ckpt)
+				if "symmap" in state:
+					self.symmap = state['symmap']
 				if "module" in state:
 					state = state['module']
 				self.nar.load_state_dict(state)
@@ -54,7 +60,9 @@ class TTS():
 		else:
 			self.load_models()
 
-		self.symmap = get_phone_symmap()
+		if self.symmap is None:
+			self.symmap = get_phone_symmap()
+
 		self.ar.eval()
 		self.nar.eval()
 
@@ -78,7 +86,7 @@ class TTS():
 		phones = [ " " if not p else p for p in content ]
 		return torch.tensor([ 1 ] + [*map(self.symmap.get, phones)] + [ 2 ])
 
-	def encode_audio( self, paths, trim=True ):
+	def encode_audio( self, paths, should_trim=True ):
 		# already a tensor, return it
 		if isinstance( paths, Tensor ):
 			return paths
@@ -90,7 +98,7 @@ class TTS():
 		# merge inputs
 		res = torch.cat([qnt.encode_from_file( path )[0].t().to(torch.int16) for path in paths])
 
-		if trim:
+		if should_trim:
 			res = trim_random( res, int( 75 * cfg.dataset.prompt_duration ) )
 
 		return res
diff --git a/vall_e/utils/trainer.py b/vall_e/utils/trainer.py
index 4b795eb..6f2ea26 100755
--- a/vall_e/utils/trainer.py
+++ b/vall_e/utils/trainer.py
@@ -324,14 +324,24 @@ def train(
 	save_ckpt_every = cfg.trainer.save_frequency or cfg.evaluation.frequency
 
 	saving_commands = ["save"]
+	export_commands = ["export"]
 
 	if cfg.trainer.save_on_quit:
 		saving_commands.append("quit")
 
+	if cfg.trainer.export_on_quit:
+		export_commands.append("quit")
+
+	if cfg.trainer.export_on_save:
+		export_commands.append("save")
+
 	if engines.global_step != last_save_step:
 		if engines.global_step % save_ckpt_every == 0 or command in saving_commands:
 			engines.save_checkpoint()
 			last_save_step = engines.global_step
+
+			if command in export_commands and is_global_leader():
+				engines.export(userdata={"symmap": get_phone_symmap()})
 
 	if engines.global_step != last_eval_step:
 		if engines.global_step % cfg.evaluation.frequency == 0 or command in ["eval"]:
@@ -343,4 +353,5 @@
 			last_eval_step = engines.global_step
 
 	if command in ["quit"]:
+		#if cfg.trainer.export_on_quit:
 		return
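
Note on the `trim_random` → `trim` swap in `data.py`: prompts are now clipped deterministically from the head of the clip rather than from a random offset (inference's `encode_audio` still uses `trim_random`). Below is a minimal, self-contained sketch of both helpers run against a synthetic quantized tensor; the logic mirrors the patched `vall_e/emb/qnt.py`, while the 75 frames-per-second rate (the same factor `encode_audio()` already uses), the clip length, and the 8-codebook shape are illustrative assumptions:

```python
import random
import torch

def trim(qnt, target):
	# head-trim, per the patch: keep the first `target` frames
	length = max(qnt.shape[0], qnt.shape[1])
	start = 0
	end = start + target
	if end >= length:
		start = length - target  # note: goes negative when target > length
		end = length
	return qnt[start:end] if qnt.shape[0] > qnt.shape[1] else qnt[:, start:end]

def trim_random(qnt, target):
	# random-window trim, per the existing helper: pick a random start offset,
	# then clamp the window to the end of the clip if it overruns
	length = max(qnt.shape[0], qnt.shape[1])
	start = int(length * random.random())
	end = start + target
	if end >= length:
		start = length - target
		end = length
	return qnt[start:end] if qnt.shape[0] > qnt.shape[1] else qnt[:, start:end]

qnt = torch.zeros(600, 8, dtype=torch.int16)  # ~8s of EnCodec codes at 75 fps
print(trim(qnt, 225).shape)         # torch.Size([225, 8]): always the first 3s
print(trim_random(qnt, 225).shape)  # torch.Size([225, 8]): 3s from a random offset
```

On the trainer side, setting `export_on_save: True` or `export_on_quit: True` under the `trainer` section of the YAML config makes each checkpoint save (and the save triggered by the `quit` command, given `save_on_quit`) also export the weights with the phoneme symmap bundled in as userdata; `TTS.__init__` then prefers that bundled `state['symmap']` over a freshly built `get_phone_symmap()` when loading checkpoints for inference.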