added a flag to convert to a HF compatible model on export by stitching things
This commit is contained in:
parent
934672252b
commit
e50edc3b48
|
@ -294,7 +294,7 @@ class Engines(dict[str, Engine]):
|
|||
for engine in self.values():
|
||||
engine.dispatch_attribute(*args, **kwargs)
|
||||
|
||||
def export(self, userdata={}):
|
||||
def export(self, userdata={}, callback=None):
|
||||
for name, engine in self.items():
|
||||
outpath = cfg.ckpt_dir / name / "fp32.pth"
|
||||
state_dict = {
|
||||
|
@ -307,6 +307,8 @@ class Engines(dict[str, Engine]):
|
|||
},
|
||||
"userdata": userdata
|
||||
}
|
||||
if callback:
|
||||
state_dict = callback( state_dict, engine.module )
|
||||
torch.save(state_dict, outpath)
|
||||
print(f"Exported {name} to {outpath}")
|
||||
|
||||
|
|
|
@ -1,21 +1,77 @@
|
|||
import argparse
|
||||
|
||||
import torch
|
||||
import torch.nn
|
||||
|
||||
from .data import get_phone_symmap
|
||||
from .engines import load_engines
|
||||
from .config import cfg
|
||||
|
||||
# stitches embeddings into one embedding + classifier => lm_head
|
||||
def convert_to_hf( state_dict, config = None ):
|
||||
n_tokens = 256 + (1024 * 8) + (1024 * 8) + 1
|
||||
token_dim = 1024
|
||||
embedding = torch.nn.Embedding(n_tokens, token_dim)
|
||||
embedding.weight.requires_grad = False
|
||||
|
||||
def move_value(k):
|
||||
v = state_dict['module'][k]
|
||||
del state_dict['module'][k]
|
||||
return v
|
||||
|
||||
separator = move_value('sep')
|
||||
out_proj = move_value('classifier.weight')
|
||||
text_emb = move_value('text_emb.weight')
|
||||
langs_emb = move_value('langs_emb.weight')
|
||||
tasks_emb = move_value('tasks_emb.weight')
|
||||
tones_emb = move_value('tones_emb.weight')
|
||||
|
||||
proms_emb_weight = [ move_value(f'proms_emb.weight.{i}').item() for i in range(8) ] if "proms_emb.weight.0" in state_dict['module'] else [ [ 1 for _ in range(8) ] ]
|
||||
resps_emb_weight = [ move_value(f'resps_emb.weight.{i}').item() for i in range(8) ] if "resps_emb.weight.0" in state_dict['module'] else [ [ 1 for _ in range(8) ] ]
|
||||
|
||||
proms_emb = [ move_value(f'proms_emb.embeddings.{i}.weight') for i in range(8) ]
|
||||
resps_emb = [ move_value(f'resps_emb.embeddings.{i}.weight') for i in range(8) ]
|
||||
|
||||
|
||||
start = 0
|
||||
for i in range(256):
|
||||
embedding.weight[start + i] = text_emb[i]
|
||||
|
||||
start = 256
|
||||
for layer in range(8):
|
||||
for i in range(1024):
|
||||
offset = start + 1024 * layer
|
||||
embedding.weight[i + offset] = proms_emb[layer][i] * proms_emb_weight[layer]
|
||||
|
||||
start = 256 + 1024 * 8
|
||||
for layer in range(8):
|
||||
for i in range(1024):
|
||||
offset = start + 1024 * layer
|
||||
embedding.weight[i + offset] = resps_emb[layer][i] * proms_emb_weight[layer]
|
||||
|
||||
state_dict['module']['model.embed_tokens.weight'] = embedding.state_dict()
|
||||
state_dict['module']['lm_head.weight'] = out_proj
|
||||
del state_dict['module']['classifier.bias']
|
||||
|
||||
torch.save(state_dict, "./data/export_test.pth")
|
||||
|
||||
raise Exception("!")
|
||||
|
||||
return state_dict
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser("Save trained model to path.")
|
||||
parser.add_argument("--module-only", action='store_true')
|
||||
parser.add_argument("--hf", action='store_true', default=None) # convert to HF-style
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.module_only:
|
||||
cfg.trainer.load_module_only = True
|
||||
|
||||
callback = convert_to_hf if args.hf else None
|
||||
|
||||
engines = load_engines()
|
||||
engines.export(userdata={"symmap": get_phone_symmap()})
|
||||
engines.export(userdata={"symmap": get_phone_symmap()}, callback=callback)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
|
@ -135,11 +135,9 @@ def run_eval(engines, eval_name, dl):
|
|||
|
||||
|
||||
def train():
|
||||
"""
|
||||
parser = argparse.ArgumentParser("VALL-E TTS")
|
||||
parser.add_argument("--eval", action="store_true")
|
||||
parser.add_argument("--eval", action="store_true", default=None)
|
||||
args = parser.parse_args()
|
||||
"""
|
||||
|
||||
setup_logging(cfg.log_dir)
|
||||
|
||||
|
@ -162,10 +160,8 @@ def train():
|
|||
|
||||
qnt.unload_model()
|
||||
|
||||
"""
|
||||
if args.eval:
|
||||
return eval_fn(engines=trainer.load_engines())
|
||||
"""
|
||||
|
||||
"""
|
||||
if cfg.trainer.load_webui:
|
||||
|
|
Loading…
Reference in New Issue
Block a user