Fix an oversight when using resize_modules
This commit is contained in:
parent
808a79ebaf
commit
7a77978096
|
@ -186,8 +186,8 @@ def load_engines(training=True):
|
||||||
("rvq_l_emb.weight", model.config.resp_levels ),
|
("rvq_l_emb.weight", model.config.resp_levels ),
|
||||||
("resps_emb.embeddings.0.weight", model.config.audio_tokens + uses_stop_token ),
|
("resps_emb.embeddings.0.weight", model.config.audio_tokens + uses_stop_token ),
|
||||||
("model.embed_tokens.weight", model.config.audio_tokens + uses_stop_token ),
|
("model.embed_tokens.weight", model.config.audio_tokens + uses_stop_token ),
|
||||||
("classifiers.proj.0.weight", model.config.audio_tokens + uses_stop_token ),
|
("classifiers.proj.0.weight" if model.config.experimental.split_classifiers else 'classifier.weight', model.config.audio_tokens + uses_stop_token ),
|
||||||
("classifiers.proj.0.bias", model.config.audio_tokens + uses_stop_token ),
|
("classifiers.proj.0.bias" if model.config.experimental.split_classifiers else 'classifier.bias', model.config.audio_tokens + uses_stop_token ),
|
||||||
]
|
]
|
||||||
for k, tokens in keys:
|
for k, tokens in keys:
|
||||||
state[k] = ml.resize_weight( state[k], tokens )
|
state[k] = ml.resize_weight( state[k], tokens )
|
||||||
|
|
|
@ -396,6 +396,9 @@ def get_module_size( module ):
|
||||||
buffer_size = sum([b.nelement() * b.element_size() for b in module.buffers()])
|
buffer_size = sum([b.nelement() * b.element_size() for b in module.buffers()])
|
||||||
return param_size + buffer_size
|
return param_size + buffer_size
|
||||||
|
|
||||||
|
# TODO: rewrite all of this; in hindsight it is unclear why it was implemented this way
|
||||||
|
# it would be better to attach to the layers themselves rather than to every single module
|
||||||
|
|
||||||
# assigns modules to requested devices for a given policy
|
# assigns modules to requested devices for a given policy
|
||||||
def get_model_offload_policy(module, policy=None):
|
def get_model_offload_policy(module, policy=None):
|
||||||
# handle any other weird values this is set to
|
# handle any other weird values this is set to
|
||||||
|
|
Loading…
Reference in New Issue
Block a user