From 7a7797809653ce60e96f8824a68f97613df8f9d7 Mon Sep 17 00:00:00 2001 From: mrq Date: Fri, 2 Aug 2024 20:28:49 -0500 Subject: [PATCH] oversight with using resize_modules --- vall_e/engines/__init__.py | 4 ++-- vall_e/utils/utils.py | 3 +++ 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/vall_e/engines/__init__.py b/vall_e/engines/__init__.py index 1a87c42..6d0f083 100755 --- a/vall_e/engines/__init__.py +++ b/vall_e/engines/__init__.py @@ -186,8 +186,8 @@ def load_engines(training=True): ("rvq_l_emb.weight", model.config.resp_levels ), ("resps_emb.embeddings.0.weight", model.config.audio_tokens + uses_stop_token ), ("model.embed_tokens.weight", model.config.audio_tokens + uses_stop_token ), - ("classifiers.proj.0.weight", model.config.audio_tokens + uses_stop_token ), - ("classifiers.proj.0.bias", model.config.audio_tokens + uses_stop_token ), + ("classifiers.proj.0.weight" if model.config.experimental.split_classifiers else 'classifier.weight', model.config.audio_tokens + uses_stop_token ), + ("classifiers.proj.0.bias" if model.config.experimental.split_classifiers else 'classifier.bias', model.config.audio_tokens + uses_stop_token ), ] for k, tokens in keys: state[k] = ml.resize_weight( state[k], tokens ) diff --git a/vall_e/utils/utils.py b/vall_e/utils/utils.py index f4c4d85..4cc6133 100755 --- a/vall_e/utils/utils.py +++ b/vall_e/utils/utils.py @@ -396,6 +396,9 @@ def get_module_size( module ): buffer_size = sum([b.nelement() * b.element_size() for b in module.buffers()]) return param_size + buffer_size +# to-do: rewrite all this shit, I don't know what I was thinking when implementing it this way +# it'd be better to just attach to layers itself rather than every single module + # assigns modules to requested devices for a given policy def get_model_offload_policy(module, policy=None): # handle any other weird values this is set to