From 6ee5f21ddce4c8128ec51b2ca68c6573e500bb5c Mon Sep 17 00:00:00 2001 From: mrq Date: Tue, 25 Jun 2024 13:40:39 -0500 Subject: [PATCH] oops, needed some fixes --- tortoise_tts/engines/__init__.py | 7 +++- tortoise_tts/engines/base.py | 4 +-- tortoise_tts/export.py | 56 +------------------------------- tortoise_tts/train.py | 2 +- tortoise_tts/utils/trainer.py | 4 --- 5 files changed, 10 insertions(+), 63 deletions(-) diff --git a/tortoise_tts/engines/__init__.py b/tortoise_tts/engines/__init__.py index 75a26c7..d356614 100755 --- a/tortoise_tts/engines/__init__.py +++ b/tortoise_tts/engines/__init__.py @@ -118,10 +118,15 @@ def load_engines(training=True): optimizer = None lr_scheduler = None + checkpoint_path = cfg.ckpt_dir / name / "latest" # automatically load from state dict if one is provided, but no DeepSpeed checkpoint is present load_path = cfg.ckpt_dir / name / "fp32.pth" - if not loads_state_dict and not (cfg.ckpt_dir / name / "latest").exists() and load_path.exists(): + # actually use the lora-specific checkpoint if available + if cfg.lora is not None: + checkpoint_path = cfg.ckpt_dir / lora.full_name / "latest" + + if not loads_state_dict and not checkpoint_path.exists() and load_path.exists(): print("Checkpoint missing, but weights found.") loads_state_dict = True diff --git a/tortoise_tts/engines/base.py b/tortoise_tts/engines/base.py index 47e658a..09bc8f9 100755 --- a/tortoise_tts/engines/base.py +++ b/tortoise_tts/engines/base.py @@ -176,7 +176,7 @@ class Engine(): self.micro_steps = state['stats']['micro_step'] if 'stats' in state else state['micro_step'] self.global_samples = state['stats']['global_samples'] if 'stats' in state else state['global_samples'] self.tokens_processed = state['stats']['tokens_processed'] if 'stats' in state else state['tokens_processed'] - self.module.load_state_dict(state['module']) + self.module.load_state_dict(state['module'], strict=cfg.trainer.strict_loading) load_optimizer_states = load_optimizer_states and self.optimizer is not None and 'optimizer' in state load_lr_scheduler_states = load_lr_scheduler_states and self.lr_scheduler is not None and 'lr_scheduler' in state @@ -344,7 +344,7 @@ class Engines(dict[str, Engine]): state_dict = callback( state_dict, config = engine.hyper_config, save_path = save_path ) torch.save(state_dict, save_path) - print(f"Exported {name} to {outpath}") + print(f"Exported {name} to {save_path}") def save_checkpoint(self, tag=None): if not tag: diff --git a/tortoise_tts/export.py b/tortoise_tts/export.py index 27329a0..9ee285b 100755 --- a/tortoise_tts/export.py +++ b/tortoise_tts/export.py @@ -8,54 +8,6 @@ from .engines import load_engines from .config import cfg from .models.lora import lora_get_state_dict -# stitches embeddings into one embedding & classifier => lm_head -def convert_to_hf( state_dict, config = None, save_path = None ): - n_tokens = 256 + (1024 * 8) + (1024 * 8) + 1 - token_dim = 1024 - embedding = torch.nn.Embedding(n_tokens, token_dim) - embedding.weight.requires_grad = False - - def move_value(k): - v = state_dict['module'][k] - del state_dict['module'][k] - return v - - separator = move_value('sep') - out_proj = move_value('classifier.weight') - text_emb = move_value('text_emb.weight') - langs_emb = move_value('langs_emb.weight') - tasks_emb = move_value('tasks_emb.weight') - tones_emb = move_value('tones_emb.weight') - - proms_emb_weight = [ move_value(f'proms_emb.weight.{i}').item() for i in range(8) ] if "proms_emb.weight.0" in state_dict['module'] else [ [ 1 for _ in range(8) ] ] - resps_emb_weight = [ move_value(f'resps_emb.weight.{i}').item() for i in range(8) ] if "resps_emb.weight.0" in state_dict['module'] else [ [ 1 for _ in range(8) ] ] - - proms_emb = [ move_value(f'proms_emb.embeddings.{i}.weight') for i in range(8) ] - resps_emb = [ move_value(f'resps_emb.embeddings.{i}.weight') for i in range(8) ] - - - start = 0 - for i in range(256): - embedding.weight[start + i] = text_emb[i] - - start = 256 - for layer in range(8): - for i in range(1024): - offset = start + 1024 * layer - embedding.weight[i + offset] = proms_emb[layer][i] * proms_emb_weight[layer] - - start = 256 + 1024 * 8 - for layer in range(8): - for i in range(1024): - offset = start + 1024 * layer - embedding.weight[i + offset] = resps_emb[layer][i] * proms_emb_weight[layer] - - state_dict['module']['model.embed_tokens.weight'] = embedding.state_dict() - state_dict['module']['lm_head.weight'] = out_proj - del state_dict['module']['classifier.bias'] - - return state_dict - def extract_lora( state_dict, config = None, save_path = None ): lora = state_dict["lora"] if "lora" in state_dict else None # should always be included, but just in case @@ -79,7 +31,6 @@ def extract_lora( state_dict, config = None, save_path = None ): def main(): parser = argparse.ArgumentParser("Save trained model to path.") parser.add_argument("--module-only", action='store_true') - parser.add_argument("--hf", action='store_true', default=None) # convert to HF-style parser.add_argument("--lora", action='store_true', default=None) # exports LoRA args, unknown = parser.parse_known_args() @@ -87,14 +38,9 @@ def main(): cfg.trainer.load_module_only = True callback = None - if args.hf: - callback = convert_to_hf - elif args.lora: + if args.lora: callback = extract_lora - if args.hf and args.lora: - raise Exception("Requesting more than one callback") - engines = load_engines() engines.export(userdata={"symmap": get_phone_symmap()}, callback=callback) diff --git a/tortoise_tts/train.py b/tortoise_tts/train.py index c796c86..eb45ac1 100755 --- a/tortoise_tts/train.py +++ b/tortoise_tts/train.py @@ -42,7 +42,7 @@ def train_feeder(engine, batch): text_tokens = pad_sequence([ text for text in batch["text"] ], batch_first = True) text_lengths = torch.Tensor([ text.shape[0] for text in batch["text"] ]).to(dtype=torch.int32) - mel_codes = pad_sequence([ codes[0] for codes in batch["mel"] ], batch_first = True, padding_value = stop_mel_token ) + mel_codes = pad_sequence([ codes[0] for codes in batch["mel"] ], batch_first = True, padding_value = engine.module.stop_mel_token ) wav_lengths = torch.Tensor([ x for x in batch["wav_length"] ]).to(dtype=torch.int32) engine.forward(autoregressive_latents, text_tokens, text_lengths, mel_codes, wav_lengths) diff --git a/tortoise_tts/utils/trainer.py b/tortoise_tts/utils/trainer.py index 947d0e6..d190125 100755 --- a/tortoise_tts/utils/trainer.py +++ b/tortoise_tts/utils/trainer.py @@ -131,10 +131,6 @@ def train( _logger.info(cfg) """ - # Setup global engines - global _engines - _engines = engines - events = [] eval_fn = global_leader_only(eval_fn)