oops, needed some fixes
parent 286681c87c
commit 6ee5f21ddc
@@ -118,10 +118,15 @@ def load_engines(training=True):
 		optimizer = None
 		lr_scheduler = None
 
+		checkpoint_path = cfg.ckpt_dir / name / "latest"
 		# automatically load from state dict if one is provided, but no DeepSpeed checkpoint is present
 		load_path = cfg.ckpt_dir / name / "fp32.pth"
 
-		if not loads_state_dict and not (cfg.ckpt_dir / name / "latest").exists() and load_path.exists():
+		# actually use the lora-specific checkpoint if available
+		if cfg.lora is not None:
+			checkpoint_path = cfg.ckpt_dir / lora.full_name / "latest"
+
+		if not loads_state_dict and not checkpoint_path.exists() and load_path.exists():
 			print("Checkpoint missing, but weights found.")
 			loads_state_dict = True
 
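To make the resulting control flow easier to read than the interleaved hunk, here is a self-contained sketch of the checkpoint-resolution order this change introduces; the helper name and parameters (resolve_checkpoint, ckpt_dir, lora_name) are illustrative stand-ins, not the repo's actual API.

from pathlib import Path

def resolve_checkpoint(ckpt_dir: Path, name: str, lora_name: str | None, loads_state_dict: bool):
	# default: the model's own DeepSpeed checkpoint directory
	checkpoint_path = ckpt_dir / name / "latest"
	# exported raw weights, usable when no checkpoint exists
	load_path = ckpt_dir / name / "fp32.pth"

	# a configured LoRA gets its own checkpoint directory instead
	if lora_name is not None:
		checkpoint_path = ckpt_dir / lora_name / "latest"

	# checkpoint missing but raw weights found: fall back to loading the state dict
	if not loads_state_dict and not checkpoint_path.exists() and load_path.exists():
		loads_state_dict = True

	return checkpoint_path, load_path, loads_state_dict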
@@ -176,7 +176,7 @@ class Engine():
 		self.micro_steps = state['stats']['micro_step'] if 'stats' in state else state['micro_step']
 		self.global_samples = state['stats']['global_samples'] if 'stats' in state else state['global_samples']
 		self.tokens_processed = state['stats']['tokens_processed'] if 'stats' in state else state['tokens_processed']
-		self.module.load_state_dict(state['module'])
+		self.module.load_state_dict(state['module'], strict=cfg.trainer.strict_loading)
 
 		load_optimizer_states = load_optimizer_states and self.optimizer is not None and 'optimizer' in state
 		load_lr_scheduler_states = load_lr_scheduler_states and self.lr_scheduler is not None and 'lr_scheduler' in state
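On the new strict=cfg.trainer.strict_loading argument: a minimal, standalone PyTorch example of what non-strict loading tolerates that strict loading would reject (the tiny model and the extra/missing keys here are made up for illustration).

import torch

model = torch.nn.Linear(4, 2)
state = model.state_dict()
state["extra.weight"] = torch.zeros(1)  # key the module does not have
del state["bias"]                       # key the module expects

# strict=True (the default) raises a RuntimeError on either mismatch above;
# strict=False loads whatever matches and reports the rest.
result = model.load_state_dict(state, strict=False)
print(result.missing_keys)     # ['bias']
print(result.unexpected_keys)  # ['extra.weight']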
@@ -344,7 +344,7 @@ class Engines(dict[str, Engine]):
 				state_dict = callback( state_dict, config = engine.hyper_config, save_path = save_path )
 
 			torch.save(state_dict, save_path)
-			print(f"Exported {name} to {outpath}")
+			print(f"Exported {name} to {save_path}")
 
 	def save_checkpoint(self, tag=None):
 		if not tag:
@@ -8,54 +8,6 @@ from .engines import load_engines
 from .config import cfg
 from .models.lora import lora_get_state_dict
 
-# stitches embeddings into one embedding & classifier => lm_head
-def convert_to_hf( state_dict, config = None, save_path = None ):
-	n_tokens = 256 + (1024 * 8) + (1024 * 8) + 1
-	token_dim = 1024
-	embedding = torch.nn.Embedding(n_tokens, token_dim)
-	embedding.weight.requires_grad = False
-
-	def move_value(k):
-		v = state_dict['module'][k]
-		del state_dict['module'][k]
-		return v
-
-	separator = move_value('sep')
-	out_proj = move_value('classifier.weight')
-	text_emb = move_value('text_emb.weight')
-	langs_emb = move_value('langs_emb.weight')
-	tasks_emb = move_value('tasks_emb.weight')
-	tones_emb = move_value('tones_emb.weight')
-
-	proms_emb_weight = [ move_value(f'proms_emb.weight.{i}').item() for i in range(8) ] if "proms_emb.weight.0" in state_dict['module'] else [ [ 1 for _ in range(8) ] ]
-	resps_emb_weight = [ move_value(f'resps_emb.weight.{i}').item() for i in range(8) ] if "resps_emb.weight.0" in state_dict['module'] else [ [ 1 for _ in range(8) ] ]
-
-	proms_emb = [ move_value(f'proms_emb.embeddings.{i}.weight') for i in range(8) ]
-	resps_emb = [ move_value(f'resps_emb.embeddings.{i}.weight') for i in range(8) ]
-
-
-	start = 0
-	for i in range(256):
-		embedding.weight[start + i] = text_emb[i]
-
-	start = 256
-	for layer in range(8):
-		for i in range(1024):
-			offset = start + 1024 * layer
-			embedding.weight[i + offset] = proms_emb[layer][i] * proms_emb_weight[layer]
-
-	start = 256 + 1024 * 8
-	for layer in range(8):
-		for i in range(1024):
-			offset = start + 1024 * layer
-			embedding.weight[i + offset] = resps_emb[layer][i] * proms_emb_weight[layer]
-
-	state_dict['module']['model.embed_tokens.weight'] = embedding.state_dict()
-	state_dict['module']['lm_head.weight'] = out_proj
-	del state_dict['module']['classifier.bias']
-
-	return state_dict
-
 def extract_lora( state_dict, config = None, save_path = None ):
 	lora = state_dict["lora"] if "lora" in state_dict else None
 	# should always be included, but just in case
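For reference, the flat vocabulary layout the removed convert_to_hf() assumed, written out as arithmetic (the variable names below are mine, not the code's):

n_text   = 256        # text tokens occupy ids [0, 256)
n_proms  = 8 * 1024   # 8 prompt codebooks: ids [256, 256 + 8 * 1024)
n_resps  = 8 * 1024   # 8 response codebooks: ids [256 + 8 * 1024, 256 + 16 * 1024)
n_tokens = n_text + n_proms + n_resps + 1  # +1 for the separator => 16641, matching n_tokens above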
@@ -79,7 +31,6 @@ def extract_lora( state_dict, config = None, save_path = None ):
 def main():
 	parser = argparse.ArgumentParser("Save trained model to path.")
 	parser.add_argument("--module-only", action='store_true')
-	parser.add_argument("--hf", action='store_true', default=None) # convert to HF-style
 	parser.add_argument("--lora", action='store_true', default=None) # exports LoRA
 	args, unknown = parser.parse_known_args()
 
@@ -87,14 +38,9 @@ def main():
 		cfg.trainer.load_module_only = True
 
 	callback = None
-	if args.hf:
-		callback = convert_to_hf
-	elif args.lora:
+	if args.lora:
 		callback = extract_lora
 
-	if args.hf and args.lora:
-		raise Exception("Requesting more than one callback")
-
 	engines = load_engines()
 	engines.export(userdata={"symmap": get_phone_symmap()}, callback=callback)
 
@@ -42,7 +42,7 @@ def train_feeder(engine, batch):
 
	text_tokens = pad_sequence([ text for text in batch["text"] ], batch_first = True)
 	text_lengths = torch.Tensor([ text.shape[0] for text in batch["text"] ]).to(dtype=torch.int32)
-	mel_codes = pad_sequence([ codes[0] for codes in batch["mel"] ], batch_first = True, padding_value = stop_mel_token )
+	mel_codes = pad_sequence([ codes[0] for codes in batch["mel"] ], batch_first = True, padding_value = engine.module.stop_mel_token )
 	wav_lengths = torch.Tensor([ x for x in batch["wav_length"] ]).to(dtype=torch.int32)
 
 	engine.forward(autoregressive_latents, text_tokens, text_lengths, mel_codes, wav_lengths)
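A tiny runnable illustration of what padding_value does in pad_sequence here; the stop token value (8192) is a placeholder, while the real one now comes from engine.module.stop_mel_token.

import torch
from torch.nn.utils.rnn import pad_sequence

stop_mel_token = 8192  # placeholder value for illustration only
codes = [ torch.tensor([5, 9, 2]), torch.tensor([7]) ]

mel_codes = pad_sequence(codes, batch_first=True, padding_value=stop_mel_token)
print(mel_codes)
# tensor([[   5,    9,    2],
#         [   7, 8192, 8192]])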
@@ -131,10 +131,6 @@ def train(
 		_logger.info(cfg)
 	"""
 
-	# Setup global engines
-	global _engines
-	_engines = engines
-
 	events = []
 
 	eval_fn = global_leader_only(eval_fn)