added a flag to convert to a HF compatible model on export by stitching things
This commit is contained in:
parent
934672252b
commit
e50edc3b48
|
@ -294,7 +294,7 @@ class Engines(dict[str, Engine]):
|
||||||
for engine in self.values():
|
for engine in self.values():
|
||||||
engine.dispatch_attribute(*args, **kwargs)
|
engine.dispatch_attribute(*args, **kwargs)
|
||||||
|
|
||||||
def export(self, userdata={}):
|
def export(self, userdata={}, callback=None):
|
||||||
for name, engine in self.items():
|
for name, engine in self.items():
|
||||||
outpath = cfg.ckpt_dir / name / "fp32.pth"
|
outpath = cfg.ckpt_dir / name / "fp32.pth"
|
||||||
state_dict = {
|
state_dict = {
|
||||||
|
@ -307,6 +307,8 @@ class Engines(dict[str, Engine]):
|
||||||
},
|
},
|
||||||
"userdata": userdata
|
"userdata": userdata
|
||||||
}
|
}
|
||||||
|
if callback:
|
||||||
|
state_dict = callback( state_dict, engine.module )
|
||||||
torch.save(state_dict, outpath)
|
torch.save(state_dict, outpath)
|
||||||
print(f"Exported {name} to {outpath}")
|
print(f"Exported {name} to {outpath}")
|
||||||
|
|
||||||
|
|
|
@ -1,21 +1,77 @@
|
||||||
import argparse
|
import argparse
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
import torch.nn
|
||||||
|
|
||||||
from .data import get_phone_symmap
|
from .data import get_phone_symmap
|
||||||
from .engines import load_engines
|
from .engines import load_engines
|
||||||
from .config import cfg
|
from .config import cfg
|
||||||
|
|
||||||
|
# stitches embeddings into one embedding + classifier => lm_head
|
||||||
|
def convert_to_hf( state_dict, config = None ):
|
||||||
|
n_tokens = 256 + (1024 * 8) + (1024 * 8) + 1
|
||||||
|
token_dim = 1024
|
||||||
|
embedding = torch.nn.Embedding(n_tokens, token_dim)
|
||||||
|
embedding.weight.requires_grad = False
|
||||||
|
|
||||||
|
def move_value(k):
|
||||||
|
v = state_dict['module'][k]
|
||||||
|
del state_dict['module'][k]
|
||||||
|
return v
|
||||||
|
|
||||||
|
separator = move_value('sep')
|
||||||
|
out_proj = move_value('classifier.weight')
|
||||||
|
text_emb = move_value('text_emb.weight')
|
||||||
|
langs_emb = move_value('langs_emb.weight')
|
||||||
|
tasks_emb = move_value('tasks_emb.weight')
|
||||||
|
tones_emb = move_value('tones_emb.weight')
|
||||||
|
|
||||||
|
proms_emb_weight = [ move_value(f'proms_emb.weight.{i}').item() for i in range(8) ] if "proms_emb.weight.0" in state_dict['module'] else [ [ 1 for _ in range(8) ] ]
|
||||||
|
resps_emb_weight = [ move_value(f'resps_emb.weight.{i}').item() for i in range(8) ] if "resps_emb.weight.0" in state_dict['module'] else [ [ 1 for _ in range(8) ] ]
|
||||||
|
|
||||||
|
proms_emb = [ move_value(f'proms_emb.embeddings.{i}.weight') for i in range(8) ]
|
||||||
|
resps_emb = [ move_value(f'resps_emb.embeddings.{i}.weight') for i in range(8) ]
|
||||||
|
|
||||||
|
|
||||||
|
start = 0
|
||||||
|
for i in range(256):
|
||||||
|
embedding.weight[start + i] = text_emb[i]
|
||||||
|
|
||||||
|
start = 256
|
||||||
|
for layer in range(8):
|
||||||
|
for i in range(1024):
|
||||||
|
offset = start + 1024 * layer
|
||||||
|
embedding.weight[i + offset] = proms_emb[layer][i] * proms_emb_weight[layer]
|
||||||
|
|
||||||
|
start = 256 + 1024 * 8
|
||||||
|
for layer in range(8):
|
||||||
|
for i in range(1024):
|
||||||
|
offset = start + 1024 * layer
|
||||||
|
embedding.weight[i + offset] = resps_emb[layer][i] * proms_emb_weight[layer]
|
||||||
|
|
||||||
|
state_dict['module']['model.embed_tokens.weight'] = embedding.state_dict()
|
||||||
|
state_dict['module']['lm_head.weight'] = out_proj
|
||||||
|
del state_dict['module']['classifier.bias']
|
||||||
|
|
||||||
|
torch.save(state_dict, "./data/export_test.pth")
|
||||||
|
|
||||||
|
raise Exception("!")
|
||||||
|
|
||||||
|
return state_dict
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
parser = argparse.ArgumentParser("Save trained model to path.")
|
parser = argparse.ArgumentParser("Save trained model to path.")
|
||||||
parser.add_argument("--module-only", action='store_true')
|
parser.add_argument("--module-only", action='store_true')
|
||||||
|
parser.add_argument("--hf", action='store_true', default=None) # convert to HF-style
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
if args.module_only:
|
if args.module_only:
|
||||||
cfg.trainer.load_module_only = True
|
cfg.trainer.load_module_only = True
|
||||||
|
|
||||||
|
callback = convert_to_hf if args.hf else None
|
||||||
|
|
||||||
engines = load_engines()
|
engines = load_engines()
|
||||||
engines.export(userdata={"symmap": get_phone_symmap()})
|
engines.export(userdata={"symmap": get_phone_symmap()}, callback=callback)
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
|
@ -135,11 +135,9 @@ def run_eval(engines, eval_name, dl):
|
||||||
|
|
||||||
|
|
||||||
def train():
|
def train():
|
||||||
"""
|
|
||||||
parser = argparse.ArgumentParser("VALL-E TTS")
|
parser = argparse.ArgumentParser("VALL-E TTS")
|
||||||
parser.add_argument("--eval", action="store_true")
|
parser.add_argument("--eval", action="store_true", default=None)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
"""
|
|
||||||
|
|
||||||
setup_logging(cfg.log_dir)
|
setup_logging(cfg.log_dir)
|
||||||
|
|
||||||
|
@ -162,10 +160,8 @@ def train():
|
||||||
|
|
||||||
qnt.unload_model()
|
qnt.unload_model()
|
||||||
|
|
||||||
"""
|
|
||||||
if args.eval:
|
if args.eval:
|
||||||
return eval_fn(engines=trainer.load_engines())
|
return eval_fn(engines=trainer.load_engines())
|
||||||
"""
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
if cfg.trainer.load_webui:
|
if cfg.trainer.load_webui:
|
||||||
|
|
Loading…
Reference in New Issue
Block a user