restructured some things with the model to remove dead weights

mrq 2023-09-20 19:10:59 -05:00
parent a6bfe43590
commit c0b25541e3
11 changed files with 166 additions and 191 deletions

View File

@@ -57,7 +57,7 @@ setup(
 		"auraloss[all]",
 		"vocos",
 		"h5py",
-		"torchscale @ git+https://github.com/microsoft/torchscale",
+		"torchscale @ git+https://git.ecker.tech/mrq/torchscale",
 	],
 	url="https://git.ecker.tech/mrq/vall-e",
 )

View File

@@ -157,16 +157,16 @@ class Dataset:
 @dataclass()
 class Model:
 	name: str = ""
+	version: int = 1
 	size: str | float | dict = "full"
 	resp_levels: int = 1
 	prom_levels: int = 8
-	tasks: int = 8 # ["tts", "ns", "sr", "tse", "cse", "nse"] and leaves two more for anything else I want (like "svc")
+	tasks: int = 0 # ["tts", "ns", "sr", "tse", "cse", "nse"] and leaves two more for anything else I want (like "svc")
+	langs: int = 0 # defined languages
 	arch_type: str = "retnet"
 	training: bool = True
 	interleave: bool = False
 	frozen_params: list[str] = field(default_factory=lambda: [])
-	p_ar_nar: float = 0.5
-	version: int = 1

 	@property
 	def full_name(self):
@@ -240,8 +240,8 @@ class Models:
 	_prom_levels: int = 1

 	_models: list[Model] = field(default_factory=lambda: [
-		Model(name="ar", resp_levels=1, prom_levels=8, tasks=8, training=True, interleave=False),
-		Model(name="nar", resp_levels=7, prom_levels=8, tasks=8, training=True, interleave=False),
+		Model(name="ar", resp_levels=1, prom_levels=8, tasks=8, langs=1, training=True, interleave=False),
+		Model(name="nar", resp_levels=7, prom_levels=8, tasks=8, langs=1, training=True, interleave=False),
 	])

 	def get(self, name=None):
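Read together, the two Model hunks amount to a small schema change: version moves next to name, langs is new, and the task-token reservation stops being the default. A minimal standalone sketch of the resulting dataclass (field names and defaults taken from the diff; the inline comments are my paraphrase, not part of the commit):

from dataclasses import dataclass, field

@dataclass()
class Model:
	name: str = ""
	version: int = 1   # version 1 selects the legacy embedding path elsewhere in this commit
	size: str | float | dict = "full"
	resp_levels: int = 1
	prom_levels: int = 8
	tasks: int = 0     # task tokens are no longer reserved by default
	langs: int = 0     # defined languages
	arch_type: str = "retnet"
	training: bool = True
	interleave: bool = False
	frozen_params: list[str] = field(default_factory=lambda: [])

# the default AR/NAR pair now declares a single language
ar = Model(name="ar", resp_levels=1, prom_levels=8, tasks=8, langs=1)
nar = Model(name="nar", resp_levels=7, prom_levels=8, tasks=8, langs=1)
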

View File

@@ -59,10 +59,10 @@ class Engine():
 		self.optimizer = kwargs['optimizer'] if 'optimizer' in kwargs else None
 		self.lr_scheduler = kwargs['lr_scheduler'] if 'lr_scheduler' in kwargs else None

-		self.global_steps = 0
-		self.micro_steps = 0
-		self.global_samples = 0
-		self.tokens_processed = 0
+		self.global_steps = kwargs.pop("global_steps", 0)
+		self.micro_steps = kwargs.pop("micro_steps", 0)
+		self.global_samples = kwargs.pop("global_samples", 0)
+		self.tokens_processed = kwargs.pop("tokens_processed", 0)

 		self._frozen_params = set()
@@ -117,10 +117,12 @@ class Engine():
 			"optimizer": self.optimizer.state_dict() if self.optimizer is not None else None,
 			"lr_scheduler": self.lr_scheduler.state_dict() if self.lr_scheduler is not None else None,

-			"global_step": self.global_step,
-			"micro_step": self.micro_step,
-			"global_samples": self.global_samples,
-			"tokens_processed": self.tokens_processed,
+			"stats": {
+				"global_step": self.global_step,
+				"micro_step": self.micro_step,
+				"global_samples": self.global_samples,
+				"tokens_processed": self.tokens_processed,
+			}
 		}, save_path)

 		open(save_dir / "latest", 'w').write( tag )
@@ -137,10 +139,10 @@ class Engine():
 			return

 		state = torch.load(load_path, map_location=torch.device(cfg.device))
-		self.global_steps = state['global_step']
-		self.micro_steps = state['micro_step']
-		self.global_samples = state['global_samples']
-		self.tokens_processed = state['tokens_processed']
+		self.global_steps = state['stats']['global_step'] if 'stats' in state else state['global_step']
+		self.micro_steps = state['stats']['micro_step'] if 'stats' in state else state['micro_step']
+		self.global_samples = state['stats']['global_samples'] if 'stats' in state else state['global_samples']
+		self.tokens_processed = state['stats']['tokens_processed'] if 'stats' in state else state['tokens_processed']
 		self.module.load_state_dict(state['module'])

 		load_optimizer_states = load_optimizer_states and self.optimizer is not None and 'optimizer' in state
@@ -261,12 +263,14 @@ class Engines(dict[str, Engine]):
 			outpath = cfg.ckpt_dir / name / "fp32.pth"
 			state_dict = {
 				'module': engine.module.state_dict(),
-				"global_step": engine.global_step,
-				"micro_step": engine.micro_step,
-				"global_samples": engine.global_samples,
-				"tokens_processed": engine.tokens_processed,
+				"stats": {
+					"global_step": engine.global_step,
+					"micro_step": engine.micro_step,
+					"global_samples": engine.global_samples,
+					"tokens_processed": engine.tokens_processed,
+				},
+				"userdata": userdata
 			}
-			state_dict.update(userdata)
 			torch.save(state_dict, outpath)
 			print(f"Exported {name} to {outpath}")
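The three hunks above change the on-disk checkpoint layout: trainer counters move from the top level into a nested "stats" dict, and exported userdata is stored under its own key instead of being splatted into the root, while loading keeps a fallback for the old flat layout. A small sketch of the new layout and a backwards-compatible read; every concrete number and path here is a placeholder, not from the commit:

import torch

checkpoint = {
	"module": {},        # model weights (state_dict) would go here
	"optimizer": None,
	"lr_scheduler": None,
	"stats": {
		"global_step": 1234,
		"micro_step": 4936,
		"global_samples": 316000,
		"tokens_processed": 98000000,
	},
	"userdata": { "symmap": {} },   # only present in fp32.pth exports
}

def read_stat(state: dict, key: str, default=0):
	# prefer the new nested "stats" dict, fall back to the old flat layout
	if "stats" in state:
		return state["stats"].get(key, default)
	return state.get(key, default)

torch.save(checkpoint, "/tmp/example.pth")
print(read_stat(torch.load("/tmp/example.pth"), "global_step"))  # 1234
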

View File

@@ -39,10 +39,24 @@ class Engine(DeepSpeedEngine):
 			kwargs['config'] = cfg.trainer.deepspeed.ds_cfg
 			kwargs['config_class'] = DeepSpeedConfig(kwargs['config'])

+		if "stats" in kwargs:
+			# stats COULD be = None
+			stats = kwargs.pop('stats')
+			if stats is None:
+				stats = {
+					"global_steps": 0,
+					"micro_steps": 0,
+					"global_samples": 0,
+					"tokens_processed": 0,
+				}
+
 		super().__init__(None, *args, **kwargs)
 		self._frozen_params = set()

-		self.tokens_processed = 0
+		self.global_steps = stats["global_steps"]
+		self.micro_steps = stats["micro_steps"]
+		self.global_samples = stats["global_samples"]
+		self.tokens_processed = stats["tokens_processed"]

 	def freeze(self, freeze_all=True):
 		if self._cfg is None or not hasattr(self._cfg, "frozen_params"):
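The DeepSpeed wrapper now takes the same counters through a stats kwarg; the caller in load_engines below always passes the key, possibly as None, which is why the None fallback exists. A self-contained sketch of just that pop-with-fallback behaviour (the helper name is mine, not the repo's):

def extract_stats(kwargs: dict) -> dict:
	# mirrors the wrapper: the "stats" key is always sent, but may carry None
	stats = kwargs.pop("stats", None)
	if stats is None:
		stats = { "global_steps": 0, "micro_steps": 0, "global_samples": 0, "tokens_processed": 0 }
	return stats

print(extract_stats({ "stats": None }))   # zeroed counters on a fresh run
print(extract_stats({ "stats": { "global_steps": 500, "micro_steps": 2000, "global_samples": 64000, "tokens_processed": 1000000 } }))
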

View File

@@ -34,7 +34,7 @@ class TTS():
 		if amp is None:
 			amp = cfg.inference.amp
 		if dtype is None:
-			dtype = cfg.inference.dtype
+			dtype = cfg.inference.weight_dtype
 		if device is None:
 			device = cfg.device
@@ -50,43 +50,41 @@ class TTS():
 		self.amp = amp

 		self.symmap = None

+		def parse( name, model, state ):
+			if "userdata" in state and 'symmap' in state['userdata']:
+				self.symmap = state['userdata']['symmap']
+			elif "symmap" in state:
+				self.symmap = state['symmap']
+
+			if "module" in state:
+				state = state['module']
+
+			model.load_state_dict(state)
+			return model
+
 		if ar_ckpt and nar_ckpt:
 			self.ar_ckpt = ar_ckpt
 			self.nar_ckpt = nar_ckpt

 			models = get_models(cfg.models.get())
 			for name, model in models.items():
-				if name.startswith("ar+nar"):
-					self.ar = model
+				if name.startswith("ar"):
 					state = torch.load(self.ar_ckpt)
-					if "symmap" in state:
-						self.symmap = state['symmap']
-					if "module" in state:
-						state = state['module']
-					self.ar.load_state_dict(state)
-					self.ar = self.ar.to(self.device, dtype=self.dtype if not self.amp else torch.float32)
-					self.nar = self.ar
-				elif name.startswith("ar"):
-					self.ar = model
-					state = torch.load(self.ar_ckpt)
-					if "symmap" in state:
-						self.symmap = state['symmap']
-					if "module" in state:
-						state = state['module']
-					self.ar.load_state_dict(state)
-					self.ar = self.ar.to(self.device, dtype=self.dtype if not self.amp else torch.float32)
+					self.ar = parse( name, model, state )
 				elif name.startswith("nar"):
-					self.nar = model
 					state = torch.load(self.nar_ckpt)
-					if "symmap" in state:
-						self.symmap = state['symmap']
-					if "module" in state:
-						state = state['module']
-					self.nar.load_state_dict(state)
-					self.nar = self.nar.to(self.device, dtype=self.dtype if not self.amp else torch.float32)
+					self.nar = parse( name, model, state )
+
+				if name.startswith("ar+nar"):
+					self.nar = self.ar
 		else:
 			self.load_models()

+		self.ar = self.ar.to(self.device, dtype=self.dtype if not self.amp else torch.float32)
+		self.nar = self.nar.to(self.device, dtype=self.dtype if not self.amp else torch.float32)
+
 		if self.symmap is None:
 			self.symmap = get_phone_symmap()
@@ -98,13 +96,13 @@ class TTS():
 	def load_models( self ):
 		engines = load_engines()
 		for name, engine in engines.items():
-			if name[:6] == "ar+nar":
-				self.ar = engine.module.to(self.device, dtype=self.dtype if not self.amp else torch.float32)
+			if name.startswith("ar"):
+				self.ar = engine.module
+			elif name.startswith("nar"):
+				self.nar = engine.module
+
+			if name.startswith("ar+nar"):
 				self.nar = self.ar
-			elif name[:2] == "ar":
-				self.ar = engine.module.to(self.device, dtype=self.dtype if not self.amp else torch.float32)
-			elif name[:3] == "nar":
-				self.nar = engine.module.to(self.device, dtype=self.dtype if not self.amp else torch.float32)

 	def encode_text( self, text, language="en" ):
 		# already a tensor, return it
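Both loaders now lean on prefix matching instead of a dedicated "ar+nar" branch: "ar+nar".startswith("ar") is true, so the monolithic model goes through the plain AR path and only gets aliased to the NAR slot afterwards. A tiny standalone sketch of that dispatch order, with strings standing in for the actual loading work:

def dispatch(name: str) -> list[str]:
	actions = []
	if name.startswith("ar"):
		actions.append("load AR weights")
	elif name.startswith("nar"):
		actions.append("load NAR weights")
	# the monolithic model reuses the module it just loaded for the NAR side
	if name.startswith("ar+nar"):
		actions.append("alias NAR to AR")
	return actions

for name in ("ar", "nar", "ar+nar"):
	print(name, "->", dispatch(name))
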

View File

@@ -41,12 +41,24 @@ class AR(Base):
 	def n_tasks(self) -> int:
 		return cfg.models.tasks

+	@property
+	def n_langs(self) -> int:
+		return cfg.models.langs
+
 	@property
 	def recurrent_chunk_size(self) -> int:
 		if cfg.mode == "training":
 			return 0
 		return cfg.inference.recurrent_chunk_size

+	"""
+	@property
+	def rotary_embedding_base(self) -> float:
+		if hasattr(self, "config") and self.config:
+			return self.config.rotary_embedding_base
+		return cfg.models.ar.rotary_embedding_base
+	"""
+
 	@property
 	def interleave(self) -> bool:
 		if hasattr(self, "config") and self.config:

View File

@@ -47,6 +47,14 @@ class AR_NAR(Base):
 	def recurrent_chunk_size(self) -> int:
 		return 0

+	"""
+	@property
+	def rotary_embedding_base(self) -> float:
+		if hasattr(self, "config") and self.config:
+			return self.config.rotary_embedding_base
+		return cfg.models.ar_nar.rotary_embedding_base
+	"""
+
 	@property
 	def interleave(self) -> bool:
 		return False
@@ -293,6 +301,10 @@ def example_usage():
 	optimizer = ml.AdamW(model.parameters(), lr=1.0e-4)
 	engine = Engine(model=model, optimizer=optimizer)

+	torch.save( {
+		'module': model.state_dict()
+	}, "./data/test.pth" )
+
 	print(f"AR+NAR parameter count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")

 	@torch.inference_mode()
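The extra torch.save in example_usage writes a bare {'module': ...} dict, which is exactly the shape the updated loaders unwrap. A quick round-trip sketch; the toy module and the /tmp path are made up for illustration:

import torch
from torch import nn

model = nn.Linear(4, 4)                     # stand-in for the AR+NAR model
torch.save({ "module": model.state_dict() }, "/tmp/test.pth")

state = torch.load("/tmp/test.pth")
if "module" in state:                       # same unwrap the TTS/parse path performs
	state = state["module"]
model.load_state_dict(state)
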

View File

@@ -235,11 +235,9 @@ class MultiEmbedding(nn.Module):
 # Embedding that sums each RVQ-bin level within a given input acoustic prompt
 class AudioEmbedding(nn.Module):
-	def __init__(self, n_levels, n_tokens, token_dim):
+	def __init__(self, l_tokens, token_dim):
 		super().__init__()
-		self.n_levels = n_levels
-		# would it be better to have embeddings[1:] reduced to 1024 tokens to attend to, so it's *not* factoring in the stop token?
-		self.embeddings = nn.ModuleList([nn.Embedding(n_tokens, token_dim) for _ in range(self.n_levels)])
+		self.embeddings = nn.ModuleList([nn.Embedding(n_tokens, token_dim) for n_tokens in l_tokens])

 	def forward(self, x_list: list[Tensor], quant_levels: Tensor | None = None ) -> list[Tensor]:
 		res_list = []
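AudioEmbedding now takes an explicit list of vocabulary sizes, one per RVQ level, instead of one size replicated n_levels times; that is what lets the non-AR levels drop the unused stop-token row (the "dead weights" of the commit message). A standalone sketch of the constructor and the parameter count it sheds, assuming EnCodec-style 1024-code levels and a 1024-dim model:

from torch import nn

class AudioEmbedding(nn.Module):
	# one embedding table per RVQ level, each with its own vocabulary size
	def __init__(self, l_tokens: list[int], token_dim: int):
		super().__init__()
		self.embeddings = nn.ModuleList([nn.Embedding(n_tokens, token_dim) for n_tokens in l_tokens])

def n_params(module: nn.Module) -> int:
	return sum(p.numel() for p in module.parameters())

old = AudioEmbedding([1025] * 8, token_dim=1024)            # stop-token row on every level
new = AudioEmbedding([1025] + [1024] * 7, token_dim=1024)   # stop token only where the AR needs it
print(n_params(old) - n_params(new))                        # 7 * 1024 weights shed
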
@@ -283,6 +281,10 @@ class Base(nn.Module):
 	def n_max_levels(self) -> int:
 		raise NotImplementedError

+	@property
+	def n_langs(self) -> int:
+		raise NotImplementedError
+
 	@property
 	def n_tasks(self) -> int:
 		raise NotImplementedError
@@ -290,6 +292,10 @@ class Base(nn.Module):
 	@property
 	def recurrent_chunk_size(self) -> int:
 		raise NotImplementedError

+	@property
+	def rotary_embedding_base(self) -> float:
+		return 10000
+
 	@property
 	def interleave(self) -> bool:
@@ -341,17 +347,24 @@ class Base(nn.Module):
 		self.n_layers = n_layers

-		# +1 to include the stop token
-		n_prom_tokens = n_tokens + (self.n_tasks - 1) # - 1 because tts is an inherent task
+		# to-do: undo this dogshit mistake; tasks tokens should be delegated to its own embedding
+		n_prom_tokens = n_tokens
 		n_resp_tokens = n_tokens + (1 if self.causal else 0) # AR requires a stop token to... know when to stop

 		self.text_emb = Embedding(n_tokens, d_model)

 		if self.version == 1: # legacy
+			n_prom_tokens += (self.n_tasks - 1) # old models have the task tokens in the prom
 			self.proms_emb = MultiEmbedding(self.n_prom_levels, n_prom_tokens, d_model)
 			self.resps_emb = MultiEmbedding(self.n_resp_levels, n_resp_tokens, d_model, monolithic=self.monolithic)
 		else:
-			self.proms_emb = AudioEmbedding(self.n_prom_levels, n_prom_tokens, d_model)
-			self.resps_emb = AudioEmbedding(self.n_resp_levels, n_resp_tokens, d_model)
+			# [1024] * 8
+			self.proms_emb = AudioEmbedding([n_prom_tokens] * self.n_prom_levels, d_model)
+			# [1025] + [1024] * 8
+			self.resps_emb = AudioEmbedding([n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1), d_model)
+
+		# self.langs_emb = Embedding(self.n_langs, d_model)
+		# self.tasks_emb = Embedding(self.n_tasks, d_model)

 		self.sep = nn.Parameter(torch.randn(d_model))
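The arithmetic behind those vocabulary sizes, written out with hypothetical but typical numbers (1024 codes per EnCodec level, 8 levels, 8 tasks): the prompt embedding no longer reserves task-token rows unless version == 1 demands it for old checkpoints, and only the first response level carries the AR stop token.

n_tokens, n_tasks, n_levels, causal = 1024, 8, 8, True
version = 2                                       # anything past the legacy version-1 layout

n_prom_tokens = n_tokens                          # task tokens no longer baked into the prompt vocab
n_resp_tokens = n_tokens + (1 if causal else 0)   # +1 stop token for the AR

if version == 1:                                  # legacy checkpoints still expect them
	n_prom_tokens += (n_tasks - 1)

prom_vocabs = [n_prom_tokens] * n_levels                                     # [1024] * 8
resp_vocabs = [n_resp_tokens] + [n_resp_tokens - 1] * (n_levels - 1)        # [1025] + [1024] * 7
print(prom_vocabs, resp_vocabs)
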
@@ -365,7 +378,6 @@ class Base(nn.Module):
 				norm_type=self.norm_type,
 				n_levels=self.n_resp_levels,
 			) for _ in range(n_layers) ])
-
 		elif self.arch_type == "retnet":
 			self.retnet = RetNetDecoder(RetNetConfig(
 				vocab_size=n_tokens,
@@ -380,6 +392,8 @@ class Base(nn.Module):
 				recurrent_chunkwise_size=self.recurrent_chunk_size if self.causal else 0,
 				no_output_layer=True,
 				decoder_normalize_before=True,
+
+				rotary_embedding_base=self.rotary_embedding_base, # 10000
 			))

 		self.classifier = nn.Linear(d_model, n_resp_tokens)
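rotary_embedding_base is surfaced as an overridable property on Base (defaulting to the usual 10000) and forwarded into RetNetConfig; the per-model overrides are left commented out for now, and the setup.py change at the top of this commit points torchscale at the mrq fork, presumably so the config actually accepts this keyword. A minimal sketch of the property pattern, independent of torchscale; the 500000 override is an invented example value:

class Base:
	@property
	def rotary_embedding_base(self) -> float:
		return 10000   # standard RoPE base

class AR(Base):
	# an override like the commented-out ones in the diff would change the base
	# frequency for this model only, e.g. for longer-context experiments
	@property
	def rotary_embedding_base(self) -> float:
		return 500000

print(Base().rotary_embedding_base, AR().rotary_embedding_base)   # 10000 500000
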
@@ -407,12 +421,17 @@ class Base(nn.Module):
 		resps_list: list[Tensor],
 		targ_list: list[Tensor] | None = None,
+
+		#langs_list: list[Tensor] | None = None,
+		#tasks_list: list[Tensor] | None = None,
 		quant_levels: Tensor | None = None,
 		state: dict | None = None,
 	):
 		x_list = self._samplewise_merge_tensors(
 			self.text_emb(text_list),
+			#self.langs_emb(langs_list),
 			self.proms_emb(proms_list),
+			#self.tasks_emb(tasks_list),
 			self.resps_emb(resps_list, quant_levels),
 			sep=self.sep,
 		)
@@ -422,7 +441,7 @@ class Base(nn.Module):
 		batch_size = len(text_list)
 		device = x.device

-		if state is not None:
+		if state is not None and self.arch_type == "retnet":
 			# prefill
 			if len(state) == 0:
 				prefill_size = x.shape[1]
@@ -443,7 +462,6 @@ class Base(nn.Module):
 			# pass our inputs through the transformer
 			for block in self.blocks:
 				x = block(x, m, l)
-
 		elif self.arch_type == "retnet":
 			# pass our inputs through the RetNet
 			x, _ = self.retnet(x, incremental_state=state, token_embeddings=x, features_only=True)
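Gating the incremental-decoding branch on arch_type means the recurrent state dict is only prefetched and consumed when the RetNet backend is active; other backends simply ignore a passed state instead of tripping the prefill path. A sketch of the guard in isolation (the helper name is mine):

def wants_incremental_state(state, arch_type: str) -> bool:
	# mirrors the updated condition: only the RetNet path understands the state dict
	return state is not None and arch_type == "retnet"

print(wants_incremental_state({}, "retnet"))        # True: RetNet consumes/prefills the state
print(wants_incremental_state({}, "transformer"))   # False: other backends ignore it
print(wants_incremental_state(None, "retnet"))      # False: no state requested
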

View File

@@ -39,6 +39,10 @@ class NAR(Base):
 	def n_tasks(self) -> int:
 		return cfg.models.tasks

+	@property
+	def n_langs(self) -> int:
+		return cfg.models.langs
+
 	@property
 	def version(self) -> int:
 		if hasattr(self, "config") and self.config:
@@ -49,6 +53,14 @@ class NAR(Base):
 	def recurrent_chunk_size(self) -> int:
 		return 0

+	"""
+	@property
+	def rotary_embedding_base(self) -> float:
+		if hasattr(self, "config") and self.config:
+			return self.config.rotary_embedding_base
+		return cfg.models.nar.rotary_embedding_base
+	"""
+
 	@property
 	def interleave(self) -> bool:
 		return False

View File

@@ -1,71 +1,3 @@
-"""
-# https://github.com/facebookresearch/fairseq/blob/main/fairseq/incremental_decoding_utils.py
-# Copied directly because even having fairseq installed WILL break logging, why are corposhitters like this
-"""
-import uuid
-from typing import Dict, Optional
-
-from torch import Tensor
-
 from torchscale.architecture.config import RetNetConfig
 from torchscale.architecture.retnet import RetNetDecoder
 # from retnet import RetNet
-
-"""
-class FairseqIncrementalState(object):
-	def __init__(self, *args, **kwargs):
-		super().__init__(*args, **kwargs)
-		self.init_incremental_state()
-
-	def init_incremental_state(self):
-		self._incremental_state_id = str(uuid.uuid4())
-
-	def _get_full_incremental_state_key(self, key: str) -> str:
-		return "{}.{}".format(self._incremental_state_id, key)
-
-	def get_incremental_state(
-		self,
-		incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]],
-		key: str,
-	) -> Optional[Dict[str, Optional[Tensor]]]:
-		full_key = self._get_full_incremental_state_key(key)
-		if incremental_state is None or full_key not in incremental_state:
-			return None
-		return incremental_state[full_key]
-
-	def set_incremental_state(
-		self,
-		incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]],
-		key: str,
-		value: Dict[str, Optional[Tensor]],
-	) -> Optional[Dict[str, Dict[str, Optional[Tensor]]]]:
-		if incremental_state is not None:
-			full_key = self._get_full_incremental_state_key(key)
-			incremental_state[full_key] = value
-		return incremental_state
-
-def with_incremental_state(cls):
-	cls.__bases__ = (FairseqIncrementalState,) + tuple(
-		b for b in cls.__bases__ if b != FairseqIncrementalState
-	)
-	return cls
-
-from torchscale.architecture.config import RetNetConfig
-from torchscale.architecture.retnet import RetNetDecoder as Decoder
-
-@with_incremental_state
-class RetNetDecoder(Decoder):
-	def forward(self, src_tokens, **kwargs):
-		return super().forward(src_tokens, **kwargs)
-
-	def max_positions(self):
-		return self.args.max_token_positions
-
-	def reorder_incremental_state( self, incremental_state, new_order ):
-		for module in incremental_state:
-			for key in incremental_state[module]:
-				result = incremental_state[module][key].index_select(0, new_order)
-				incremental_state[module][key] = result
-"""

View File

@@ -64,84 +64,57 @@ def load_engines(invert=False):
 		lr_scheduler = None

 		if cfg.trainer.backend == "local" or (cfg.trainer.backend == "deepspeed" and cfg.hyperparameters.torch_optimizer):
+			optimizer_class = None
+			params = {
+				"lr": cfg.hyperparameters.learning_rate,
+			}
 			if cfg.hyperparameters.optimizer.lower() == "adamw":
-				params = {
-					"lr": cfg.hyperparameters.learning_rate,
-					"betas": (0.9, 0.96),
-					"eps": 1e-07,
-					"weight_decay": 0.01,
-				}
-				params.update(cfg.hyperparameters.optimizer_params)
-				optimizer = ml.AdamW(
-					[ param for name, param in model.named_parameters() if name not in model._cfg.frozen_params ],
-					**params,
-				)
+				params["betas"] = (0.9, 0.96)
+				params["eps"] = 1e-07
+				params["weight_decay"] = 0.01
+
+				optimizer_class = ml.AdamW
 			elif cfg.hyperparameters.optimizer.lower() == "sgd":
-				params = {
-					"lr": cfg.hyperparameters.learning_rate,
-				}
-				params.update(cfg.hyperparameters.optimizer_params)
-				optimizer = ml.SGD(
-					[ param for name, param in model.named_parameters() if name not in model._cfg.frozen_params ],
-					**params,
-				)
+				optimizer = ml.SGD
 			elif cfg.hyperparameters.optimizer.lower() == "prodigy":
-				params = {
-					"lr": cfg.hyperparameters.learning_rate,
-				}
-				params.update(cfg.hyperparameters.optimizer_params)
-				optimizer = ml.Prodigy(
-					[ param for name, param in model.named_parameters() if name not in model._cfg.frozen_params ],
-					**params,
-				)
+				optimizer_class = ml.Prodigy
+			else:
+				raise ValueError(f'Optimizer specified not implemented: {cfg.hyperparameters.optimizer}')
+
+			params.update(cfg.hyperparameters.optimizer_params)
+			optimizer = optimizer_class(
+				[ param for name, param in model.named_parameters() if name not in model._cfg.frozen_params ],
+				**params,
+			)
+
+		# set up our LR scheduler here

 		if not model._cfg.training:
 			optimizer = None
 			lr_scheduler = None

+		stats = None
 		if cfg.trainer.load_state_dict or not model._cfg.training:
 			load_path = cfg.ckpt_dir / name / "fp32.pth"
 			state = torch.load(load_path, map_location=torch.device(cfg.device))
-			# exporting the model from the zero_to_fp32.py exports the actual module's dict
-			# exporting with vall_e.export exports the state dict under .module
+
+			# state dict is not just the module, extract the extra trainer details
+			if "stats" in state:
+				additionals = state["stats"]
+
 			if "module" in state:
 				state = state["module"]

-			# should decouple the following from this trainer script
-			# probably with passing a fun that defaults to a lambda x: x deal
-			"""
-			# can probably be done a lot more intelligently but oh well
-
-			# extend the proms_emb if we ever touch the n_prom_levels or n_prom_tokens (from adding tasks)
-			if model.proms_emb.weight.shape[0] > state['proms_emb.weight'].shape[0] or model.proms_emb.weight.shape[1] > state['proms_emb.weight'].shape[1]:
-				o_prom_levels, o_prom_tokens, d_model = state['proms_emb.weight'].shape
-
-				# copy weights from the dict into the old portion
-				model.proms_emb.weight.data[:o_prom_levels, :o_prom_tokens, :] = state['proms_emb.weight'].data[:o_prom_levels, :o_prom_tokens, :]
-				# copy the full tensors back
-				state['proms_emb.weight'] = model.proms_emb.weight
-
-			# extend the resps_emb if we ever touch the n_prom_levels or n_prom_tokens (from adding tasks)
-			if model.resps_emb.weight.shape[0] > state['resps_emb.weight'].shape[0] or model.resps_emb.weight.shape[1] > state['resps_emb.weight'].shape[1]:
-				o_resp_levels, o_resp_tokens, d_model = state['resps_emb.weight'].shape
-				n_resp_levels, n_resp_tokens, d_model = model.resps_emb.weight.shape
-
-				# copy weights from the dict into the old portion
-				model.resps_emb.weight.data[:o_resp_levels, :o_resp_tokens, :] = state['resps_emb.weight'].data[:o_resp_levels, :o_resp_tokens, :]
-				# copy the full tensors back
-				state['resps_emb.weight'] = model.resps_emb.weight
-			"""
-
 			model.load_state_dict(state, strict=cfg.trainer.strict_loading)

-		# use base engine because DeepSpeed memory leaks
+		# use base engine because DeepSpeed memory leaks if it's a non-training model
 		engines[name] = (Engine if model._cfg.training else _Engine)(
-			#engines[name] = Engine(
 			model=model,
 			optimizer=optimizer,
 			lr_scheduler=lr_scheduler,
 			_cfg=model._cfg,
+			stats=stats
 		)

 	engines = Engines(engines)
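The optimizer block collapses three copy-pasted branches into one: pick a class and its extra defaults, then construct once with the shared frozen-parameter filter. A condensed, standalone sketch of that pattern; plain torch.optim classes stand in for the repo's ml wrappers, and the toy model and frozen set are placeholders:

import torch
from torch import nn

model = nn.Linear(8, 8)           # stand-in model
frozen = set()                    # names from frozen_params would go here

params = { "lr": 1.0e-4 }
optimizer_name = "adamw"

if optimizer_name == "adamw":
	params.update({ "betas": (0.9, 0.96), "eps": 1e-07, "weight_decay": 0.01 })
	optimizer_class = torch.optim.AdamW
elif optimizer_name == "sgd":
	optimizer_class = torch.optim.SGD
else:
	raise ValueError(f"Optimizer specified not implemented: {optimizer_name}")

optimizer = optimizer_class(
	[ p for name, p in model.named_parameters() if name not in frozen ],
	**params,
)
print(type(optimizer).__name__)   # AdamW
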