vall-e/scripts/stitch_embs.py

28 lines
1.4 KiB
Python
Raw Normal View History

import sys

import torch
# Default action when none is given as the first command-line argument.
# One of "merge_resp_embs", "copy_resps_emb", "extend_resps_emb".
action = None

# State-dict entry (inside the checkpoint's "module" dict) that every action edits.
RESPS_EMB_KEY = "resps_emb.weight"


def merge_resps_emb(dst, src_ar, src_nar):
    """Stitch an AR's and a NAR's ``resps_emb`` tables into ``dst`` in place.

    Level 0 of ``dst``'s table is overwritten with level 0 of the AR's table.
    Levels 1+ (every token row except the last) are overwritten with the
    NAR's table.

    Args:
        dst: checkpoint dict with a ``"module"`` state dict; modified in place.
        src_ar: AR checkpoint dict providing level 0.
        src_nar: NAR checkpoint dict providing levels 1+.

    Returns:
        ``dst``, for convenience.
    """
    weight = dst["module"][RESPS_EMB_KEY]
    # Level 0 comes from the AR. The old ``[:0, :, :]`` slice was empty, so
    # level 0 was never actually written.
    weight[:1, :, :] = src_ar["module"][RESPS_EMB_KEY][:1]
    # The remaining levels come from the NAR.
    # NOTE(review): this assumes the NAR table has one token row fewer than
    # dst's (the last row of levels 1+ is left as-is). It also assumes dst
    # already has 1 + <NAR level count> levels. Confirm against the
    # checkpoints being merged.
    weight[1:, :-1, :] = src_nar["module"][RESPS_EMB_KEY]
    return dst


def copy_resps_emb(dst, src):
    """Replace ``dst``'s ``resps_emb`` table with ``src``'s (by reference).

    Returns:
        ``dst``, for convenience.
    """
    dst["module"][RESPS_EMB_KEY] = src["module"][RESPS_EMB_KEY]
    return dst


def extend_resps_emb(dst, levels=4):
    """Grow ``dst``'s ``resps_emb`` table to ``levels`` levels.

    Level 0 keeps the existing level-0 embedding. New levels 1..levels-1 are
    drawn from a standard normal distribution. The token count and embedding
    width are taken from the existing table rather than hard-coded.

    Args:
        dst: checkpoint dict with a ``"module"`` state dict; updated in place.
        levels: total number of levels in the resulting table (default 4).

    Returns:
        ``dst``, for convenience.
    """
    weight = dst["module"][RESPS_EMB_KEY]
    # Build a fresh tensor. Writing into ``weight.expand(...)`` raises in
    # PyTorch, because every expanded level aliases the same memory.
    noise = torch.randn(levels - 1, *weight.shape[1:]).to(weight.dtype)
    dst["module"][RESPS_EMB_KEY] = torch.cat([weight[:1], noise], dim=0)
    return dst


def main(action):
    """Apply ``action`` to checkpoints under ./data and save ./data/fp32.pth.

    Raises:
        ValueError: if ``action`` is not a known action name. ValueError
            subclasses Exception, so existing handlers still match.
    """
    if action == "merge_resp_embs":
        # Copies the resp_embs from a given AR and NAR into an AR, as a base
        # for converting into an AR+NAR monolithic model.
        src_ar = torch.load("./data/source-ar.pth", map_location="cpu")
        src_nar = torch.load("./data/source-nar.pth", map_location="cpu")
        # Copies all weights from the AR, since the AR is usually "better".
        # Might need to experiment more with using a NAR as the base.
        dst = torch.load("./data/source-ar.pth", map_location="cpu")
        merge_resps_emb(dst, src_ar, src_nar)
    elif action == "copy_resps_emb":
        # Copies an existing AR+NAR monolithic model's resp_emb onto an AR.
        src = torch.load("./data/source.pth", map_location="cpu")
        dst = torch.load("./data/destination.pth", map_location="cpu")
        copy_resps_emb(dst, src)
    elif action == "extend_resps_emb":
        dst = torch.load("./data/destination.pth", map_location="cpu")
        extend_resps_emb(dst)
    else:
        raise ValueError(f"invalid action: {action}")
    torch.save(dst, './data/fp32.pth')


if __name__ == "__main__":
    main(sys.argv[1] if len(sys.argv) > 1 else action)