restructured some things with the model to remove dead weights
This commit is contained in:
parent
a6bfe43590
commit
c0b25541e3
2
setup.py
2
setup.py
|
@ -57,7 +57,7 @@ setup(
|
|||
"auraloss[all]",
|
||||
"vocos",
|
||||
"h5py",
|
||||
"torchscale @ git+https://github.com/microsoft/torchscale",
|
||||
"torchscale @ git+https://git.ecker.tech/mrq/torchscale",
|
||||
],
|
||||
url="https://git.ecker.tech/mrq/vall-e",
|
||||
)
|
||||
|
|
|
@ -157,16 +157,16 @@ class Dataset:
|
|||
@dataclass()
|
||||
class Model:
|
||||
name: str = ""
|
||||
version: int = 1
|
||||
size: str | float | dict = "full"
|
||||
resp_levels: int = 1
|
||||
prom_levels: int = 8
|
||||
tasks: int = 8 # ["tts", "ns", "sr", "tse", "cse", "nse"] and leaves two more for anything else I want (like "svc")
|
||||
tasks: int = 0 # ["tts", "ns", "sr", "tse", "cse", "nse"] and leaves two more for anything else I want (like "svc")
|
||||
langs: int = 0 # defined languages
|
||||
arch_type: str = "retnet"
|
||||
training: bool = True
|
||||
interleave: bool = False
|
||||
frozen_params: list[str] = field(default_factory=lambda: [])
|
||||
p_ar_nar: float = 0.5
|
||||
version: int = 1
|
||||
|
||||
@property
|
||||
def full_name(self):
|
||||
|
@ -240,8 +240,8 @@ class Models:
|
|||
_prom_levels: int = 1
|
||||
|
||||
_models: list[Model] = field(default_factory=lambda: [
|
||||
Model(name="ar", resp_levels=1, prom_levels=8, tasks=8, training=True, interleave=False),
|
||||
Model(name="nar", resp_levels=7, prom_levels=8, tasks=8, training=True, interleave=False),
|
||||
Model(name="ar", resp_levels=1, prom_levels=8, tasks=8, langs=1, training=True, interleave=False),
|
||||
Model(name="nar", resp_levels=7, prom_levels=8, tasks=8, langs=1, training=True, interleave=False),
|
||||
])
|
||||
|
||||
def get(self, name=None):
|
||||
|
|
|
@ -59,10 +59,10 @@ class Engine():
|
|||
self.optimizer = kwargs['optimizer'] if 'optimizer' in kwargs else None
|
||||
self.lr_scheduler = kwargs['lr_scheduler'] if 'lr_scheduler' in kwargs else None
|
||||
|
||||
self.global_steps = 0
|
||||
self.micro_steps = 0
|
||||
self.global_samples = 0
|
||||
self.tokens_processed = 0
|
||||
self.global_steps = kwargs.pop("global_steps", 0)
|
||||
self.micro_steps = kwargs.pop("micro_steps", 0)
|
||||
self.global_samples = kwargs.pop("global_samples", 0)
|
||||
self.tokens_processed = kwargs.pop("tokens_processed", 0)
|
||||
|
||||
self._frozen_params = set()
|
||||
|
||||
|
@ -117,10 +117,12 @@ class Engine():
|
|||
"optimizer": self.optimizer.state_dict() if self.optimizer is not None else None,
|
||||
"lr_scheduler": self.lr_scheduler.state_dict() if self.lr_scheduler is not None else None,
|
||||
|
||||
"global_step": self.global_step,
|
||||
"micro_step": self.micro_step,
|
||||
"global_samples": self.global_samples,
|
||||
"tokens_processed": self.tokens_processed,
|
||||
"stats": {
|
||||
"global_step": self.global_step,
|
||||
"micro_step": self.micro_step,
|
||||
"global_samples": self.global_samples,
|
||||
"tokens_processed": self.tokens_processed,
|
||||
}
|
||||
}, save_path)
|
||||
|
||||
open(save_dir / "latest", 'w').write( tag )
|
||||
|
@ -137,10 +139,10 @@ class Engine():
|
|||
return
|
||||
|
||||
state = torch.load(load_path, map_location=torch.device(cfg.device))
|
||||
self.global_steps = state['global_step']
|
||||
self.micro_steps = state['micro_step']
|
||||
self.global_samples = state['global_samples']
|
||||
self.tokens_processed = state['tokens_processed']
|
||||
self.global_steps = state['stats']['global_step'] if 'stats' in state else state['global_step']
|
||||
self.micro_steps = state['stats']['micro_step'] if 'stats' in state else state['micro_step']
|
||||
self.global_samples = state['stats']['global_samples'] if 'stats' in state else state['global_samples']
|
||||
self.tokens_processed = state['stats']['tokens_processed'] if 'stats' in state else state['tokens_processed']
|
||||
self.module.load_state_dict(state['module'])
|
||||
|
||||
load_optimizer_states = load_optimizer_states and self.optimizer is not None and 'optimizer' in state
|
||||
|
@ -261,12 +263,14 @@ class Engines(dict[str, Engine]):
|
|||
outpath = cfg.ckpt_dir / name / "fp32.pth"
|
||||
state_dict = {
|
||||
'module': engine.module.state_dict(),
|
||||
"global_step": engine.global_step,
|
||||
"micro_step": engine.micro_step,
|
||||
"global_samples": engine.global_samples,
|
||||
"tokens_processed": engine.tokens_processed,
|
||||
"stats": {
|
||||
"global_step": engine.global_step,
|
||||
"micro_step": engine.micro_step,
|
||||
"global_samples": engine.global_samples,
|
||||
"tokens_processed": engine.tokens_processed,
|
||||
},
|
||||
"userdata": userdata
|
||||
}
|
||||
state_dict.update(userdata)
|
||||
torch.save(state_dict, outpath)
|
||||
print(f"Exported {name} to {outpath}")
|
||||
|
||||
|
|
|
@ -39,10 +39,24 @@ class Engine(DeepSpeedEngine):
|
|||
kwargs['config'] = cfg.trainer.deepspeed.ds_cfg
|
||||
kwargs['config_class'] = DeepSpeedConfig(kwargs['config'])
|
||||
|
||||
if "stats" in kwargs:
|
||||
# stats COULD be = None
|
||||
stats = kwargs.pop('stats')
|
||||
if stats is None:
|
||||
stats = {
|
||||
"global_steps": 0,
|
||||
"micro_steps": 0,
|
||||
"global_samples": 0,
|
||||
"tokens_processed": 0,
|
||||
}
|
||||
|
||||
super().__init__(None, *args, **kwargs)
|
||||
self._frozen_params = set()
|
||||
|
||||
self.tokens_processed = 0
|
||||
self.global_steps = stats["global_steps"]
|
||||
self.micro_steps = stats["micro_steps"]
|
||||
self.global_samples = stats["global_samples"]
|
||||
self.tokens_processed = stats["tokens_processed"]
|
||||
|
||||
def freeze(self, freeze_all=True):
|
||||
if self._cfg is None or not hasattr(self._cfg, "frozen_params"):
|
||||
|
|
|
@ -34,7 +34,7 @@ class TTS():
|
|||
if amp is None:
|
||||
amp = cfg.inference.amp
|
||||
if dtype is None:
|
||||
dtype = cfg.inference.dtype
|
||||
dtype = cfg.inference.weight_dtype
|
||||
if device is None:
|
||||
device = cfg.device
|
||||
|
||||
|
@ -50,43 +50,41 @@ class TTS():
|
|||
self.amp = amp
|
||||
|
||||
self.symmap = None
|
||||
|
||||
def parse( name, model, state ):
|
||||
if "userdata" in state and 'symmap' in state['userdata']:
|
||||
self.symmap = state['userdata']['symmap']
|
||||
elif "symmap" in state:
|
||||
self.symmap = state['symmap']
|
||||
|
||||
if "module" in state:
|
||||
state = state['module']
|
||||
|
||||
model.load_state_dict(state)
|
||||
return model
|
||||
|
||||
if ar_ckpt and nar_ckpt:
|
||||
self.ar_ckpt = ar_ckpt
|
||||
self.nar_ckpt = nar_ckpt
|
||||
|
||||
models = get_models(cfg.models.get())
|
||||
|
||||
for name, model in models.items():
|
||||
if name.startswith("ar+nar"):
|
||||
self.ar = model
|
||||
if name.startswith("ar"):
|
||||
state = torch.load(self.ar_ckpt)
|
||||
if "symmap" in state:
|
||||
self.symmap = state['symmap']
|
||||
if "module" in state:
|
||||
state = state['module']
|
||||
self.ar.load_state_dict(state)
|
||||
self.ar = self.ar.to(self.device, dtype=self.dtype if not self.amp else torch.float32)
|
||||
self.nar = self.ar
|
||||
elif name.startswith("ar"):
|
||||
self.ar = model
|
||||
state = torch.load(self.ar_ckpt)
|
||||
if "symmap" in state:
|
||||
self.symmap = state['symmap']
|
||||
if "module" in state:
|
||||
state = state['module']
|
||||
self.ar.load_state_dict(state)
|
||||
self.ar = self.ar.to(self.device, dtype=self.dtype if not self.amp else torch.float32)
|
||||
self.ar = parse( name, model, state )
|
||||
elif name.startswith("nar"):
|
||||
self.nar = model
|
||||
state = torch.load(self.nar_ckpt)
|
||||
if "symmap" in state:
|
||||
self.symmap = state['symmap']
|
||||
if "module" in state:
|
||||
state = state['module']
|
||||
self.nar.load_state_dict(state)
|
||||
self.nar = self.nar.to(self.device, dtype=self.dtype if not self.amp else torch.float32)
|
||||
self.nar = parse( name, model, state )
|
||||
|
||||
if name.startswith("ar+nar"):
|
||||
self.nar = self.ar
|
||||
else:
|
||||
self.load_models()
|
||||
|
||||
self.ar = self.ar.to(self.device, dtype=self.dtype if not self.amp else torch.float32)
|
||||
self.nar = self.nar.to(self.device, dtype=self.dtype if not self.amp else torch.float32)
|
||||
|
||||
if self.symmap is None:
|
||||
self.symmap = get_phone_symmap()
|
||||
|
||||
|
@ -98,13 +96,13 @@ class TTS():
|
|||
def load_models( self ):
|
||||
engines = load_engines()
|
||||
for name, engine in engines.items():
|
||||
if name[:6] == "ar+nar":
|
||||
self.ar = engine.module.to(self.device, dtype=self.dtype if not self.amp else torch.float32)
|
||||
if name.startswith("ar"):
|
||||
self.ar = engine.module
|
||||
elif name.startswith("nar"):
|
||||
self.nar = engine.module
|
||||
|
||||
if name.startswith("ar+nar"):
|
||||
self.nar = self.ar
|
||||
elif name[:2] == "ar":
|
||||
self.ar = engine.module.to(self.device, dtype=self.dtype if not self.amp else torch.float32)
|
||||
elif name[:3] == "nar":
|
||||
self.nar = engine.module.to(self.device, dtype=self.dtype if not self.amp else torch.float32)
|
||||
|
||||
def encode_text( self, text, language="en" ):
|
||||
# already a tensor, return it
|
||||
|
|
|
@ -41,12 +41,24 @@ class AR(Base):
|
|||
def n_tasks(self) -> int:
|
||||
return cfg.models.tasks
|
||||
|
||||
@property
|
||||
def n_langs(self) -> int:
|
||||
return cfg.models.langs
|
||||
|
||||
@property
|
||||
def recurrent_chunk_size(self) -> int:
|
||||
if cfg.mode == "training":
|
||||
return 0
|
||||
return cfg.inference.recurrent_chunk_size
|
||||
|
||||
"""
|
||||
@property
|
||||
def rotary_embedding_base(self) -> float:
|
||||
if hasattr(self, "config") and self.config:
|
||||
return self.config.rotary_embedding_base
|
||||
return cfg.models.ar.rotary_embedding_base
|
||||
"""
|
||||
|
||||
@property
|
||||
def interleave(self) -> bool:
|
||||
if hasattr(self, "config") and self.config:
|
||||
|
|
|
@ -47,6 +47,14 @@ class AR_NAR(Base):
|
|||
def recurrent_chunk_size(self) -> int:
|
||||
return 0
|
||||
|
||||
"""
|
||||
@property
|
||||
def rotary_embedding_base(self) -> float:
|
||||
if hasattr(self, "config") and self.config:
|
||||
return self.config.rotary_embedding_base
|
||||
return cfg.models.ar_nar.rotary_embedding_base
|
||||
"""
|
||||
|
||||
@property
|
||||
def interleave(self) -> bool:
|
||||
return False
|
||||
|
@ -293,6 +301,10 @@ def example_usage():
|
|||
optimizer = ml.AdamW(model.parameters(), lr=1.0e-4)
|
||||
engine = Engine(model=model, optimizer=optimizer)
|
||||
|
||||
torch.save( {
|
||||
'module': model.state_dict()
|
||||
}, "./data/test.pth" )
|
||||
|
||||
print(f"AR+NAR parameter count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
|
||||
|
||||
@torch.inference_mode()
|
||||
|
|
|
@ -235,11 +235,9 @@ class MultiEmbedding(nn.Module):
|
|||
|
||||
# Embedding that sums each RVQ-bin level within a given input acoustic prompt
|
||||
class AudioEmbedding(nn.Module):
|
||||
def __init__(self, n_levels, n_tokens, token_dim):
|
||||
def __init__(self, l_tokens, token_dim):
|
||||
super().__init__()
|
||||
self.n_levels = n_levels
|
||||
# would it be better to have embeddings[1:] reduced to 1024 tokens to attend to, so it's *not* factoring in the stop token?
|
||||
self.embeddings = nn.ModuleList([nn.Embedding(n_tokens, token_dim) for _ in range(self.n_levels)])
|
||||
self.embeddings = nn.ModuleList([nn.Embedding(n_tokens, token_dim) for n_tokens in l_tokens])
|
||||
|
||||
def forward(self, x_list: list[Tensor], quant_levels: Tensor | None = None ) -> list[Tensor]:
|
||||
res_list = []
|
||||
|
@ -283,6 +281,10 @@ class Base(nn.Module):
|
|||
def n_max_levels(self) -> int:
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def n_langs(self) -> int:
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def n_tasks(self) -> int:
|
||||
raise NotImplementedError
|
||||
|
@ -290,6 +292,10 @@ class Base(nn.Module):
|
|||
@property
|
||||
def recurrent_chunk_size(self) -> int:
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def rotary_embedding_base(self) -> float:
|
||||
return 10000
|
||||
|
||||
@property
|
||||
def interleave(self) -> bool:
|
||||
|
@ -341,17 +347,24 @@ class Base(nn.Module):
|
|||
self.n_layers = n_layers
|
||||
|
||||
# +1 to include the stop token
|
||||
n_prom_tokens = n_tokens + (self.n_tasks - 1) # - 1 because tts is an inherent task
|
||||
# to-do: undo this dogshit mistake; tasks tokens should be delegated to its own embedding
|
||||
n_prom_tokens = n_tokens
|
||||
n_resp_tokens = n_tokens + (1 if self.causal else 0) # AR requires a stop token to... know when to stop
|
||||
|
||||
self.text_emb = Embedding(n_tokens, d_model)
|
||||
|
||||
if self.version == 1: # legacy
|
||||
n_prom_tokens += (self.n_tasks - 1) # old models have the task tokens in the prom
|
||||
self.proms_emb = MultiEmbedding(self.n_prom_levels, n_prom_tokens, d_model)
|
||||
self.resps_emb = MultiEmbedding(self.n_resp_levels, n_resp_tokens, d_model, monolithic=self.monolithic)
|
||||
else:
|
||||
self.proms_emb = AudioEmbedding(self.n_prom_levels, n_prom_tokens, d_model)
|
||||
self.resps_emb = AudioEmbedding(self.n_resp_levels, n_resp_tokens, d_model)
|
||||
# [1024] * 8
|
||||
self.proms_emb = AudioEmbedding([n_prom_tokens] * self.n_prom_levels, d_model)
|
||||
# [1025] + [1024] * 8
|
||||
self.resps_emb = AudioEmbedding([n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1), d_model)
|
||||
|
||||
# self.langs_emb = Embedding(self.n_langs, d_model)
|
||||
# self.tasks_emb = Embedding(self.n_tasks, d_model)
|
||||
|
||||
self.sep = nn.Parameter(torch.randn(d_model))
|
||||
|
||||
|
@ -365,7 +378,6 @@ class Base(nn.Module):
|
|||
norm_type=self.norm_type,
|
||||
n_levels=self.n_resp_levels,
|
||||
) for _ in range(n_layers) ])
|
||||
|
||||
elif self.arch_type == "retnet":
|
||||
self.retnet = RetNetDecoder(RetNetConfig(
|
||||
vocab_size=n_tokens,
|
||||
|
@ -380,6 +392,8 @@ class Base(nn.Module):
|
|||
recurrent_chunkwise_size=self.recurrent_chunk_size if self.causal else 0,
|
||||
no_output_layer=True,
|
||||
decoder_normalize_before=True,
|
||||
|
||||
rotary_embedding_base=self.rotary_embedding_base, # 10000
|
||||
))
|
||||
|
||||
self.classifier = nn.Linear(d_model, n_resp_tokens)
|
||||
|
@ -407,12 +421,17 @@ class Base(nn.Module):
|
|||
resps_list: list[Tensor],
|
||||
targ_list: list[Tensor] | None = None,
|
||||
|
||||
#langs_list: list[Tensor] | None = None,
|
||||
#tasks_list: list[Tensor] | None = None,
|
||||
|
||||
quant_levels: Tensor | None = None,
|
||||
state: dict | None = None,
|
||||
):
|
||||
x_list = self._samplewise_merge_tensors(
|
||||
self.text_emb(text_list),
|
||||
#self.langs_emb(langs_list),
|
||||
self.proms_emb(proms_list),
|
||||
#self.tasks_emb(tasks_list),
|
||||
self.resps_emb(resps_list, quant_levels),
|
||||
sep=self.sep,
|
||||
)
|
||||
|
@ -422,7 +441,7 @@ class Base(nn.Module):
|
|||
batch_size = len(text_list)
|
||||
device = x.device
|
||||
|
||||
if state is not None:
|
||||
if state is not None and self.arch_type == "retnet":
|
||||
# prefill
|
||||
if len(state) == 0:
|
||||
prefill_size = x.shape[1]
|
||||
|
@ -443,7 +462,6 @@ class Base(nn.Module):
|
|||
# pass our inputs through the transformer
|
||||
for block in self.blocks:
|
||||
x = block(x, m, l)
|
||||
|
||||
elif self.arch_type == "retnet":
|
||||
# pass our inputs through the RetNet
|
||||
x, _ = self.retnet(x, incremental_state=state, token_embeddings=x, features_only=True)
|
||||
|
|
|
@ -39,6 +39,10 @@ class NAR(Base):
|
|||
def n_tasks(self) -> int:
|
||||
return cfg.models.tasks
|
||||
|
||||
@property
|
||||
def n_langs(self) -> int:
|
||||
return cfg.models.langs
|
||||
|
||||
@property
|
||||
def version(self) -> int:
|
||||
if hasattr(self, "config") and self.config:
|
||||
|
@ -49,6 +53,14 @@ class NAR(Base):
|
|||
def recurrent_chunk_size(self) -> int:
|
||||
return 0
|
||||
|
||||
"""
|
||||
@property
|
||||
def rotary_embedding_base(self) -> float:
|
||||
if hasattr(self, "config") and self.config:
|
||||
return self.config.rotary_embedding_base
|
||||
return cfg.models.nar.rotary_embedding_base
|
||||
"""
|
||||
|
||||
@property
|
||||
def interleave(self) -> bool:
|
||||
return False
|
||||
|
|
|
@ -1,71 +1,3 @@
|
|||
"""
|
||||
# https://github.com/facebookresearch/fairseq/blob/main/fairseq/incremental_decoding_utils.py
|
||||
# Copied directly because even having fairseq installed WILL break logging, why are corposhitters like this
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from typing import Dict, Optional
|
||||
|
||||
from torch import Tensor
|
||||
|
||||
from torchscale.architecture.config import RetNetConfig
|
||||
from torchscale.architecture.retnet import RetNetDecoder
|
||||
|
||||
"""
|
||||
class FairseqIncrementalState(object):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.init_incremental_state()
|
||||
|
||||
def init_incremental_state(self):
|
||||
self._incremental_state_id = str(uuid.uuid4())
|
||||
|
||||
def _get_full_incremental_state_key(self, key: str) -> str:
|
||||
return "{}.{}".format(self._incremental_state_id, key)
|
||||
|
||||
def get_incremental_state(
|
||||
self,
|
||||
incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]],
|
||||
key: str,
|
||||
) -> Optional[Dict[str, Optional[Tensor]]]:
|
||||
full_key = self._get_full_incremental_state_key(key)
|
||||
if incremental_state is None or full_key not in incremental_state:
|
||||
return None
|
||||
return incremental_state[full_key]
|
||||
|
||||
def set_incremental_state(
|
||||
self,
|
||||
incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]],
|
||||
key: str,
|
||||
value: Dict[str, Optional[Tensor]],
|
||||
) -> Optional[Dict[str, Dict[str, Optional[Tensor]]]]:
|
||||
if incremental_state is not None:
|
||||
full_key = self._get_full_incremental_state_key(key)
|
||||
incremental_state[full_key] = value
|
||||
return incremental_state
|
||||
|
||||
|
||||
def with_incremental_state(cls):
|
||||
cls.__bases__ = (FairseqIncrementalState,) + tuple(
|
||||
b for b in cls.__bases__ if b != FairseqIncrementalState
|
||||
)
|
||||
return cls
|
||||
|
||||
|
||||
from torchscale.architecture.config import RetNetConfig
|
||||
from torchscale.architecture.retnet import RetNetDecoder as Decoder
|
||||
|
||||
@with_incremental_state
|
||||
class RetNetDecoder(Decoder):
|
||||
def forward(self, src_tokens, **kwargs):
|
||||
return super().forward(src_tokens, **kwargs)
|
||||
|
||||
def max_positions(self):
|
||||
return self.args.max_token_positions
|
||||
|
||||
def reorder_incremental_state( self, incremental_state, new_order ):
|
||||
for module in incremental_state:
|
||||
for key in incremental_state[module]:
|
||||
result = incremental_state[module][key].index_select(0, new_order)
|
||||
incremental_state[module][key] = result
|
||||
"""
|
||||
# from retnet import RetNet
|
|
@ -64,84 +64,57 @@ def load_engines(invert=False):
|
|||
lr_scheduler = None
|
||||
|
||||
if cfg.trainer.backend == "local" or (cfg.trainer.backend == "deepspeed" and cfg.hyperparameters.torch_optimizer):
|
||||
optimizer_class = None
|
||||
params = {
|
||||
"lr": cfg.hyperparameters.learning_rate,
|
||||
}
|
||||
if cfg.hyperparameters.optimizer.lower() == "adamw":
|
||||
params = {
|
||||
"lr": cfg.hyperparameters.learning_rate,
|
||||
"betas": (0.9, 0.96),
|
||||
"eps": 1e-07,
|
||||
"weight_decay": 0.01,
|
||||
}
|
||||
params.update(cfg.hyperparameters.optimizer_params)
|
||||
optimizer = ml.AdamW(
|
||||
[ param for name, param in model.named_parameters() if name not in model._cfg.frozen_params ],
|
||||
**params,
|
||||
)
|
||||
params["betas"] = (0.9, 0.96)
|
||||
params["eps"] = 1e-07
|
||||
params["weight_decay"] = 0.01
|
||||
|
||||
optimizer_class = ml.AdamW
|
||||
elif cfg.hyperparameters.optimizer.lower() == "sgd":
|
||||
params = {
|
||||
"lr": cfg.hyperparameters.learning_rate,
|
||||
}
|
||||
params.update(cfg.hyperparameters.optimizer_params)
|
||||
optimizer = ml.SGD(
|
||||
[ param for name, param in model.named_parameters() if name not in model._cfg.frozen_params ],
|
||||
**params,
|
||||
)
|
||||
optimizer = ml.SGD
|
||||
elif cfg.hyperparameters.optimizer.lower() == "prodigy":
|
||||
params = {
|
||||
"lr": cfg.hyperparameters.learning_rate,
|
||||
}
|
||||
params.update(cfg.hyperparameters.optimizer_params)
|
||||
optimizer = ml.Prodigy(
|
||||
[ param for name, param in model.named_parameters() if name not in model._cfg.frozen_params ],
|
||||
**params,
|
||||
)
|
||||
optimizer_class = ml.Prodigy
|
||||
else:
|
||||
raise ValueError(f'Optimizer specified not implemented: {cfg.hyperparameters.optimizer}')
|
||||
|
||||
params.update(cfg.hyperparameters.optimizer_params)
|
||||
optimizer = optimizer_class(
|
||||
[ param for name, param in model.named_parameters() if name not in model._cfg.frozen_params ],
|
||||
**params,
|
||||
)
|
||||
|
||||
# set up our LR scheduler here
|
||||
|
||||
if not model._cfg.training:
|
||||
optimizer = None
|
||||
lr_scheduler = None
|
||||
|
||||
stats = None
|
||||
if cfg.trainer.load_state_dict or not model._cfg.training:
|
||||
load_path = cfg.ckpt_dir / name / "fp32.pth"
|
||||
state = torch.load(load_path, map_location=torch.device(cfg.device))
|
||||
# exporting the model from the zero_to_fp32.py exports the actual module's dict
|
||||
# exporting with vall_e.export exports the state dict under .module
|
||||
|
||||
# state dict is not just the module, extract the extra trainer details
|
||||
if "stats" in state:
|
||||
additionals = state["stats"]
|
||||
|
||||
if "module" in state:
|
||||
state = state["module"]
|
||||
|
||||
# should decouple the following from this trainer script
|
||||
# probably with passing a fun that defaults to a lambda x: x deal
|
||||
|
||||
"""
|
||||
# can probably be done a lot more intelligently but oh well
|
||||
# extend the proms_emb if we ever touch the n_prom_levels or n_prom_tokens (from adding tasks)
|
||||
if model.proms_emb.weight.shape[0] > state['proms_emb.weight'].shape[0] or model.proms_emb.weight.shape[1] > state['proms_emb.weight'].shape[1]:
|
||||
o_prom_levels, o_prom_tokens, d_model = state['proms_emb.weight'].shape
|
||||
|
||||
# copy weights from the dict into the old portion
|
||||
model.proms_emb.weight.data[:o_prom_levels, :o_prom_tokens, :] = state['proms_emb.weight'].data[:o_prom_levels, :o_prom_tokens, :]
|
||||
# copy the full tensors back
|
||||
state['proms_emb.weight'] = model.proms_emb.weight
|
||||
|
||||
# extend the resps_emb if we ever touch the n_prom_levels or n_prom_tokens (from adding tasks)
|
||||
if model.resps_emb.weight.shape[0] > state['resps_emb.weight'].shape[0] or model.resps_emb.weight.shape[1] > state['resps_emb.weight'].shape[1]:
|
||||
o_resp_levels, o_resp_tokens, d_model = state['resps_emb.weight'].shape
|
||||
n_resp_levels, n_resp_tokens, d_model = model.resps_emb.weight.shape
|
||||
|
||||
# copy weights from the dict into the old portion
|
||||
model.resps_emb.weight.data[:o_resp_levels, :o_resp_tokens, :] = state['resps_emb.weight'].data[:o_resp_levels, :o_resp_tokens, :]
|
||||
# copy the full tensors back
|
||||
state['resps_emb.weight'] = model.resps_emb.weight
|
||||
"""
|
||||
|
||||
model.load_state_dict(state, strict=cfg.trainer.strict_loading)
|
||||
|
||||
# use base engine because DeepSpeed memory leaks
|
||||
# use base engine because DeepSpeed memory leaks if it's a non-training model
|
||||
engines[name] = (Engine if model._cfg.training else _Engine)(
|
||||
#engines[name] = Engine(
|
||||
model=model,
|
||||
optimizer=optimizer,
|
||||
lr_scheduler=lr_scheduler,
|
||||
|
||||
_cfg=model._cfg,
|
||||
stats=stats
|
||||
)
|
||||
|
||||
engines = Engines(engines)
|
||||
|
|
Loading…
Reference in New Issue
Block a user