diff --git a/vall_e/engines/__init__.py b/vall_e/engines/__init__.py index 1a87c42..6d0f083 100755 --- a/vall_e/engines/__init__.py +++ b/vall_e/engines/__init__.py @@ -186,8 +186,8 @@ def load_engines(training=True): ("rvq_l_emb.weight", model.config.resp_levels ), ("resps_emb.embeddings.0.weight", model.config.audio_tokens + uses_stop_token ), ("model.embed_tokens.weight", model.config.audio_tokens + uses_stop_token ), - ("classifiers.proj.0.weight", model.config.audio_tokens + uses_stop_token ), - ("classifiers.proj.0.bias", model.config.audio_tokens + uses_stop_token ), + ("classifiers.proj.0.weight" if model.config.experimental.split_classifiers else 'classifier.weight', model.config.audio_tokens + uses_stop_token ), + ("classifiers.proj.0.bias" if model.config.experimental.split_classifiers else 'classifier.bias', model.config.audio_tokens + uses_stop_token ), ] for k, tokens in keys: state[k] = ml.resize_weight( state[k], tokens ) diff --git a/vall_e/utils/utils.py b/vall_e/utils/utils.py index f4c4d85..4cc6133 100755 --- a/vall_e/utils/utils.py +++ b/vall_e/utils/utils.py @@ -396,6 +396,9 @@ def get_module_size( module ): buffer_size = sum([b.nelement() * b.element_size() for b in module.buffers()]) return param_size + buffer_size +# to-do: rewrite all this shit, I don't know what I was thinking when implementing it this way +# it'd be better to just attach to layers itself rather than every single module + # assigns modules to requested devices for a given policy def get_model_offload_policy(module, policy=None): # handle any other weird values this is set to