oops, needed some fixes

parent 286681c87c
commit 6ee5f21ddc
@@ -118,10 +118,15 @@ def load_engines(training=True):
 		optimizer = None
 		lr_scheduler = None
 
+		checkpoint_path = cfg.ckpt_dir / name / "latest"
 		# automatically load from state dict if one is provided, but no DeepSpeed checkpoint is present
 		load_path = cfg.ckpt_dir / name / "fp32.pth"
 
-		if not loads_state_dict and not (cfg.ckpt_dir / name / "latest").exists() and load_path.exists():
+		# actually use the lora-specific checkpoint if available
+		if cfg.lora is not None:
+			checkpoint_path = cfg.ckpt_dir / lora.full_name / "latest"
+
+		if not loads_state_dict and not checkpoint_path.exists() and load_path.exists():
 			print("Checkpoint missing, but weights found.")
 			loads_state_dict = True
 
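For context, the resolution logic this hunk ends up with, as a standalone sketch — `cfg.ckpt_dir` and `lora.full_name` are taken from the surrounding code; the function name and signature are illustrative, not the repo's API:

	from pathlib import Path

	def resolve_checkpoint(ckpt_dir: Path, name: str, lora_name, loads_state_dict: bool):
		# default: the per-model DeepSpeed checkpoint directory
		checkpoint_path = ckpt_dir / name / "latest"
		# exported raw weights, used as a fallback
		load_path = ckpt_dir / name / "fp32.pth"

		# prefer the lora-specific checkpoint when a LoRA is configured
		if lora_name is not None:
			checkpoint_path = ckpt_dir / lora_name / "latest"

		# no checkpoint on disk, but exported weights exist: load the state dict instead
		if not loads_state_dict and not checkpoint_path.exists() and load_path.exists():
			loads_state_dict = True

		return checkpoint_path, load_path, loads_state_dict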
@@ -176,7 +176,7 @@ class Engine():
 		self.micro_steps = state['stats']['micro_step'] if 'stats' in state else state['micro_step']
 		self.global_samples = state['stats']['global_samples'] if 'stats' in state else state['global_samples']
 		self.tokens_processed = state['stats']['tokens_processed'] if 'stats' in state else state['tokens_processed']
-		self.module.load_state_dict(state['module'])
+		self.module.load_state_dict(state['module'], strict=cfg.trainer.strict_loading)
 
 		load_optimizer_states = load_optimizer_states and self.optimizer is not None and 'optimizer' in state
 		load_lr_scheduler_states = load_lr_scheduler_states and self.lr_scheduler is not None and 'lr_scheduler' in state
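The new `strict` argument maps straight onto PyTorch's `Module.load_state_dict`, which with `strict=False` loads whatever keys match and reports the rest instead of raising — useful when a checkpoint predates newly added modules. A minimal sketch:

	import torch

	model = torch.nn.Linear(4, 2)
	state = {"weight": torch.zeros(2, 4)}  # 'bias' deliberately missing

	# strict=False loads the matching keys and reports the mismatches
	result = model.load_state_dict(state, strict=False)
	print(result.missing_keys)     # ['bias']
	print(result.unexpected_keys)  # []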
@@ -344,7 +344,7 @@ class Engines(dict[str, Engine]):
 				state_dict = callback( state_dict, config = engine.hyper_config, save_path = save_path )
 
 			torch.save(state_dict, save_path)
-			print(f"Exported {name} to {outpath}")
+			print(f"Exported {name} to {save_path}")
 
 	def save_checkpoint(self, tag=None):
 		if not tag:
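The call site above also documents the export callback contract: it receives the state dict plus `config` and `save_path` keywords and returns the (possibly modified) state dict, which is then handed to `torch.save`. A hypothetical callback of that shape — the name and behavior here are made up for illustration:

	def strip_optimizer(state_dict, config=None, save_path=None):
		# hypothetical example: keep only module weights in the export
		state_dict.pop("optimizer", None)
		return state_dict

	# engines.export(..., callback=strip_optimizer)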
@@ -8,54 +8,6 @@ from .engines import load_engines
 from .config import cfg
 from .models.lora import lora_get_state_dict
 
-# stitches embeddings into one embedding & classifier => lm_head
-def convert_to_hf( state_dict, config = None, save_path = None ):
-	n_tokens = 256 + (1024 * 8) + (1024 * 8) + 1
-	token_dim = 1024
-	embedding = torch.nn.Embedding(n_tokens, token_dim)
-	embedding.weight.requires_grad = False
-
-	def move_value(k):
-		v = state_dict['module'][k]
-		del state_dict['module'][k]
-		return v
-
-	separator = move_value('sep')
-	out_proj = move_value('classifier.weight')
-	text_emb = move_value('text_emb.weight')
-	langs_emb = move_value('langs_emb.weight')
-	tasks_emb = move_value('tasks_emb.weight')
-	tones_emb = move_value('tones_emb.weight')
-
-	proms_emb_weight = [ move_value(f'proms_emb.weight.{i}').item() for i in range(8) ] if "proms_emb.weight.0" in state_dict['module'] else [ [ 1 for _ in range(8) ] ]
-	resps_emb_weight = [ move_value(f'resps_emb.weight.{i}').item() for i in range(8) ] if "resps_emb.weight.0" in state_dict['module'] else [ [ 1 for _ in range(8) ] ]
-
-	proms_emb = [ move_value(f'proms_emb.embeddings.{i}.weight') for i in range(8) ]
-	resps_emb = [ move_value(f'resps_emb.embeddings.{i}.weight') for i in range(8) ]
-
-
-	start = 0
-	for i in range(256):
-		embedding.weight[start + i] = text_emb[i]
-
-	start = 256
-	for layer in range(8):
-		for i in range(1024):
-			offset = start + 1024 * layer
-			embedding.weight[i + offset] = proms_emb[layer][i] * proms_emb_weight[layer]
-
-	start = 256 + 1024 * 8
-	for layer in range(8):
-		for i in range(1024):
-			offset = start + 1024 * layer
-			embedding.weight[i + offset] = resps_emb[layer][i] * proms_emb_weight[layer]
-
-	state_dict['module']['model.embed_tokens.weight'] = embedding.state_dict()
-	state_dict['module']['lm_head.weight'] = out_proj
-	del state_dict['module']['classifier.bias']
-
-	return state_dict
-
 def extract_lora( state_dict, config = None, save_path = None ):
 	lora = state_dict["lora"] if "lora" in state_dict else None
 	# should always be included, but just in case
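For reference, the deleted stitcher packed the merged vocabulary as 256 text tokens, then 8 prompt codebooks and 8 response codebooks of 1024 entries each, plus one spare slot; the offsets fall out directly:

	n_text, n_levels, codebook = 256, 8, 1024

	text_start  = 0                                  # ids 0..255
	proms_start = n_text                             # ids 256..8447
	resps_start = n_text + n_levels * codebook       # ids 8448..16639
	n_tokens    = n_text + 2 * n_levels * codebook + 1  # 16641, matching the removed n_tokens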
@@ -79,7 +31,6 @@ def extract_lora( state_dict, config = None, save_path = None ):
 
 def main():
 	parser = argparse.ArgumentParser("Save trained model to path.")
 	parser.add_argument("--module-only", action='store_true')
-	parser.add_argument("--hf", action='store_true', default=None) # convert to HF-style
 	parser.add_argument("--lora", action='store_true', default=None) # exports LoRA
 	args, unknown = parser.parse_known_args()
@@ -87,14 +38,9 @@ def main():
 		cfg.trainer.load_module_only = True
 
 	callback = None
-	if args.hf:
-		callback = convert_to_hf
-	elif args.lora:
+	if args.lora:
 		callback = extract_lora
 
-	if args.hf and args.lora:
-		raise Exception("Requesting more than one callback")
-
 	engines = load_engines()
 	engines.export(userdata={"symmap": get_phone_symmap()}, callback=callback)
 
@@ -42,7 +42,7 @@ def train_feeder(engine, batch):
 
 	text_tokens = pad_sequence([ text for text in batch["text"] ], batch_first = True)
 	text_lengths = torch.Tensor([ text.shape[0] for text in batch["text"] ]).to(dtype=torch.int32)
-	mel_codes = pad_sequence([ codes[0] for codes in batch["mel"] ], batch_first = True, padding_value = stop_mel_token )
+	mel_codes = pad_sequence([ codes[0] for codes in batch["mel"] ], batch_first = True, padding_value = engine.module.stop_mel_token )
 	wav_lengths = torch.Tensor([ x for x in batch["wav_length"] ]).to(dtype=torch.int32)
 
 	engine.forward(autoregressive_latents, text_tokens, text_lengths, mel_codes, wav_lengths)
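The fix matters because `pad_sequence` zero-fills by default and 0 can be a valid mel code; padding with the stop token keeps the padded tail inert. A self-contained illustration (8193 is only a stand-in value here — the real one comes from `engine.module`):

	import torch
	from torch.nn.utils.rnn import pad_sequence

	stop_mel_token = 8193  # stand-in for illustration
	seqs = [torch.tensor([5, 9, 2]), torch.tensor([7])]

	# pads the shorter sequence with the stop token instead of 0
	padded = pad_sequence(seqs, batch_first=True, padding_value=stop_mel_token)
	# tensor([[   5,    9,    2],
	#         [   7, 8193, 8193]])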
@@ -131,10 +131,6 @@ def train(
 		_logger.info(cfg)
 	"""
 
-	# Setup global engines
-	global _engines
-	_engines = engines
-
 	events = []
 
 	eval_fn = global_leader_only(eval_fn)