vall-e/vall_e/export.py
2024-06-09 11:39:43 -05:00

77 lines
2.5 KiB
Python
Executable File

import argparse
import torch
import torch.nn
from .data import get_phone_symmap
from .engines import load_engines
from .config import cfg
# stitches embeddings into one embedding + classifier => lm_head
def convert_to_hf( state_dict, config = None ):
n_tokens = 256 + (1024 * 8) + (1024 * 8) + 1
token_dim = 1024
embedding = torch.nn.Embedding(n_tokens, token_dim)
embedding.weight.requires_grad = False
def move_value(k):
v = state_dict['module'][k]
del state_dict['module'][k]
return v
separator = move_value('sep')
out_proj = move_value('classifier.weight')
text_emb = move_value('text_emb.weight')
langs_emb = move_value('langs_emb.weight')
tasks_emb = move_value('tasks_emb.weight')
tones_emb = move_value('tones_emb.weight')
proms_emb_weight = [ move_value(f'proms_emb.weight.{i}').item() for i in range(8) ] if "proms_emb.weight.0" in state_dict['module'] else [ [ 1 for _ in range(8) ] ]
resps_emb_weight = [ move_value(f'resps_emb.weight.{i}').item() for i in range(8) ] if "resps_emb.weight.0" in state_dict['module'] else [ [ 1 for _ in range(8) ] ]
proms_emb = [ move_value(f'proms_emb.embeddings.{i}.weight') for i in range(8) ]
resps_emb = [ move_value(f'resps_emb.embeddings.{i}.weight') for i in range(8) ]
start = 0
for i in range(256):
embedding.weight[start + i] = text_emb[i]
start = 256
for layer in range(8):
for i in range(1024):
offset = start + 1024 * layer
embedding.weight[i + offset] = proms_emb[layer][i] * proms_emb_weight[layer]
start = 256 + 1024 * 8
for layer in range(8):
for i in range(1024):
offset = start + 1024 * layer
embedding.weight[i + offset] = resps_emb[layer][i] * proms_emb_weight[layer]
state_dict['module']['model.embed_tokens.weight'] = embedding.state_dict()
state_dict['module']['lm_head.weight'] = out_proj
del state_dict['module']['classifier.bias']
torch.save(state_dict, "./data/export_test.pth")
raise Exception("!")
return state_dict
def main():
parser = argparse.ArgumentParser("Save trained model to path.")
parser.add_argument("--module-only", action='store_true')
parser.add_argument("--hf", action='store_true', default=None) # convert to HF-style
args, unknown = parser.parse_known_args()
if args.module_only:
cfg.trainer.load_module_only = True
callback = convert_to_hf if args.hf else None
engines = load_engines()
engines.export(userdata={"symmap": get_phone_symmap()}, callback=callback)
if __name__ == "__main__":
main()