oops, needed some fixes

mrq 2024-06-25 13:40:39 -05:00
parent 286681c87c
commit 6ee5f21ddc
5 changed files with 10 additions and 63 deletions

View File

@@ -118,10 +118,15 @@ def load_engines(training=True):
     optimizer = None
     lr_scheduler = None
     checkpoint_path = cfg.ckpt_dir / name / "latest"
     # automatically load from state dict if one is provided, but no DeepSpeed checkpoint is present
     load_path = cfg.ckpt_dir / name / "fp32.pth"
-    if not loads_state_dict and not (cfg.ckpt_dir / name / "latest").exists() and load_path.exists():
+    # actually use the lora-specific checkpoint if available
+    if cfg.lora is not None:
+        checkpoint_path = cfg.ckpt_dir / lora.full_name / "latest"
+    if not loads_state_dict and not checkpoint_path.exists() and load_path.exists():
         print("Checkpoint missing, but weights found.")
         loads_state_dict = True
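The hunk above boils down to a checkpoint-resolution rule: prefer a DeepSpeed "latest" checkpoint, taken from the LoRA's own directory when a LoRA is configured, and only fall back to loading the raw fp32.pth state dict when no such checkpoint exists. A minimal standalone sketch of that rule, with hypothetical ckpt_dir/name/lora_name arguments standing in for the cfg and lora objects used in the real code:

    from pathlib import Path

    def should_fall_back_to_state_dict(ckpt_dir: Path, name: str, lora_name: str | None = None) -> bool:
        # prefer the LoRA-specific checkpoint directory when a LoRA is configured
        checkpoint_path = ckpt_dir / (lora_name if lora_name else name) / "latest"
        load_path = ckpt_dir / name / "fp32.pth"
        # no DeepSpeed checkpoint, but exported weights exist -> load the state dict instead
        return not checkpoint_path.exists() and load_path.exists()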

View File

@@ -176,7 +176,7 @@ class Engine():
         self.micro_steps = state['stats']['micro_step'] if 'stats' in state else state['micro_step']
         self.global_samples = state['stats']['global_samples'] if 'stats' in state else state['global_samples']
         self.tokens_processed = state['stats']['tokens_processed'] if 'stats' in state else state['tokens_processed']
-        self.module.load_state_dict(state['module'])
+        self.module.load_state_dict(state['module'], strict=cfg.trainer.strict_loading)
         load_optimizer_states = load_optimizer_states and self.optimizer is not None and 'optimizer' in state
         load_lr_scheduler_states = load_lr_scheduler_states and self.lr_scheduler is not None and 'lr_scheduler' in state
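The strict=cfg.trainer.strict_loading change forwards the trainer's strictness setting to PyTorch's Module.load_state_dict, so a checkpoint whose keys only partially match the module (for example after attaching a LoRA or tweaking the architecture) can still be loaded when strict loading is disabled. A small illustrative sketch of what non-strict loading does, using a throwaway two-layer model rather than anything from this repo:

    import torch

    model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.Linear(4, 2))
    partial_state = {"0.weight": torch.zeros(4, 4), "0.bias": torch.zeros(4)}  # deliberately missing the "1.*" keys

    # strict=True (the default) raises RuntimeError on missing keys;
    # strict=False loads whatever matches and reports the rest
    result = model.load_state_dict(partial_state, strict=False)
    print(result.missing_keys)     # ['1.weight', '1.bias']
    print(result.unexpected_keys)  # []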
@@ -344,7 +344,7 @@ class Engines(dict[str, Engine]):
             state_dict = callback( state_dict, config = engine.hyper_config, save_path = save_path )
         torch.save(state_dict, save_path)
-        print(f"Exported {name} to {outpath}")
+        print(f"Exported {name} to {save_path}")
     def save_checkpoint(self, tag=None):
         if not tag:

View File

@@ -8,54 +8,6 @@ from .engines import load_engines
 from .config import cfg
 from .models.lora import lora_get_state_dict
-# stitches embeddings into one embedding & classifier => lm_head
-def convert_to_hf( state_dict, config = None, save_path = None ):
-    n_tokens = 256 + (1024 * 8) + (1024 * 8) + 1
-    token_dim = 1024
-    embedding = torch.nn.Embedding(n_tokens, token_dim)
-    embedding.weight.requires_grad = False
-    def move_value(k):
-        v = state_dict['module'][k]
-        del state_dict['module'][k]
-        return v
-    separator = move_value('sep')
-    out_proj = move_value('classifier.weight')
-    text_emb = move_value('text_emb.weight')
-    langs_emb = move_value('langs_emb.weight')
-    tasks_emb = move_value('tasks_emb.weight')
-    tones_emb = move_value('tones_emb.weight')
-    proms_emb_weight = [ move_value(f'proms_emb.weight.{i}').item() for i in range(8) ] if "proms_emb.weight.0" in state_dict['module'] else [ [ 1 for _ in range(8) ] ]
-    resps_emb_weight = [ move_value(f'resps_emb.weight.{i}').item() for i in range(8) ] if "resps_emb.weight.0" in state_dict['module'] else [ [ 1 for _ in range(8) ] ]
-    proms_emb = [ move_value(f'proms_emb.embeddings.{i}.weight') for i in range(8) ]
-    resps_emb = [ move_value(f'resps_emb.embeddings.{i}.weight') for i in range(8) ]
-    start = 0
-    for i in range(256):
-        embedding.weight[start + i] = text_emb[i]
-    start = 256
-    for layer in range(8):
-        for i in range(1024):
-            offset = start + 1024 * layer
-            embedding.weight[i + offset] = proms_emb[layer][i] * proms_emb_weight[layer]
-    start = 256 + 1024 * 8
-    for layer in range(8):
-        for i in range(1024):
-            offset = start + 1024 * layer
-            embedding.weight[i + offset] = resps_emb[layer][i] * proms_emb_weight[layer]
-    state_dict['module']['model.embed_tokens.weight'] = embedding.state_dict()
-    state_dict['module']['lm_head.weight'] = out_proj
-    del state_dict['module']['classifier.bias']
-    return state_dict
 def extract_lora( state_dict, config = None, save_path = None ):
     lora = state_dict["lora"] if "lora" in state_dict else None
     # should always be included, but just in case
@@ -79,7 +31,6 @@ def extract_lora( state_dict, config = None, save_path = None ):
 def main():
     parser = argparse.ArgumentParser("Save trained model to path.")
     parser.add_argument("--module-only", action='store_true')
-    parser.add_argument("--hf", action='store_true', default=None) # convert to HF-style
     parser.add_argument("--lora", action='store_true', default=None) # exports LoRA
     args, unknown = parser.parse_known_args()
@@ -87,14 +38,9 @@ def main():
     cfg.trainer.load_module_only = True
     callback = None
-    if args.hf:
-        callback = convert_to_hf
-    elif args.lora:
+    if args.lora:
         callback = extract_lora
-    if args.hf and args.lora:
-        raise Exception("Requesting more than one callback")
     engines = load_engines()
     engines.export(userdata={"symmap": get_phone_symmap()}, callback=callback)
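With --hf and convert_to_hf gone, extract_lora is the only export callback left. The contract visible in the surrounding code is simple: a callback is a plain function taking (state_dict, config=..., save_path=...) and returning the (possibly modified) state dict, which Engines.export then hands to torch.save. A hedged sketch of a custom callback following that contract; the fp16 downcast is purely illustrative and not something this repo does:

    import torch

    def cast_to_fp16(state_dict, config=None, save_path=None):
        # downcast floating-point module weights before they are torch.save'd
        module = state_dict.get("module", {})
        for k, v in module.items():
            if torch.is_tensor(v) and v.is_floating_point():
                module[k] = v.half()
        return state_dict

    # usage mirrors main(): engines.export(userdata={"symmap": get_phone_symmap()}, callback=cast_to_fp16)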

View File

@@ -42,7 +42,7 @@ def train_feeder(engine, batch):
     text_tokens = pad_sequence([ text for text in batch["text"] ], batch_first = True)
     text_lengths = torch.Tensor([ text.shape[0] for text in batch["text"] ]).to(dtype=torch.int32)
-    mel_codes = pad_sequence([ codes[0] for codes in batch["mel"] ], batch_first = True, padding_value = stop_mel_token )
+    mel_codes = pad_sequence([ codes[0] for codes in batch["mel"] ], batch_first = True, padding_value = engine.module.stop_mel_token )
     wav_lengths = torch.Tensor([ x for x in batch["wav_length"] ]).to(dtype=torch.int32)
     engine.forward(autoregressive_latents, text_tokens, text_lengths, mel_codes, wav_lengths)
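The padding fix only changes where the stop token comes from: pad_sequence fills the tail of every shorter sequence with padding_value, and that value is now read off the wrapped module (engine.module.stop_mel_token) rather than relying on a bare stop_mel_token name in the feeder's scope. A tiny self-contained example of the padding behaviour, with a made-up token id:

    import torch
    from torch.nn.utils.rnn import pad_sequence

    stop_mel_token = 8192  # hypothetical id, purely for illustration
    codes = [torch.tensor([5, 9, 2]), torch.tensor([7])]

    # pads every sequence to the length of the longest one, filling with the stop token
    batch = pad_sequence(codes, batch_first=True, padding_value=stop_mel_token)
    print(batch)  # tensor([[   5,    9,    2],
                  #         [   7, 8192, 8192]])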

View File

@@ -131,10 +131,6 @@ def train(
         _logger.info(cfg)
     """
-    # Setup global engines
-    global _engines
-    _engines = engines
     events = []
     eval_fn = global_leader_only(eval_fn)