28 lines
1.4 KiB
Python
28 lines
1.4 KiB
Python
import torch
|
|
|
|
action = None
|
|
# copies the resp_embs from a given AR and NAR into an AR as a base to convert into an AR+NAR monolithic odel
|
|
if action == "merge_resp_embs":
|
|
src_ar = torch.load("./data/source-ar.pth", map_location="cpu")
|
|
src_nar = torch.load("./data/source-nar.pth", map_location="cpu")
|
|
# copies all weights from the AR since the AR is usually "better", might need to experiment more with using a NAR as the base
|
|
dst = torch.load("./data/source-ar.pth", map_location="cpu")
|
|
|
|
# copy resps_emb to layer 0 from AR
|
|
dst['module']['resps_emb.weight'][:0, :, :] = src_ar['module']['resps_emb.weight']
|
|
# copy resps_emb to remaining layers from NAR
|
|
dst['module']['resps_emb.weight'][1:, :-1, :] = src_nar['module']['resps_emb.weight']
|
|
# copies an existing AR+NAR monolithic model's resp_emb onto an AR
|
|
elif action == "copy_resps_emb":
|
|
src = torch.load("./data/source.pth", map_location="cpu")
|
|
dst = torch.load("./data/destination.pth", map_location="cpu")
|
|
dst['module']['resps_emb.weight'] = src['module']['resps_emb.weight']
|
|
elif action == "extend_resps_emb":
|
|
dst = torch.load("./data/destination.pth", map_location="cpu")
|
|
dst['module']['resps_emb.weight'] = dst['module']['resps_emb.weight'].expand(4, -1, -1)
|
|
dst['module']['resps_emb.weight'][1:] = torch.randn(3, 1025, 1024)
|
|
|
|
else
|
|
raise Exception(f"invalid action: {action}")
|
|
|
|
torch.save(dst, './data/fp32.pth') |