restructured some things with the model to remove dead weights

mrq 2023-09-20 19:10:59 -05:00
parent a6bfe43590
commit c0b25541e3
11 changed files with 166 additions and 191 deletions

View File

@@ -57,7 +57,7 @@ setup(
"auraloss[all]",
"vocos",
"h5py",
"torchscale @ git+https://github.com/microsoft/torchscale",
"torchscale @ git+https://git.ecker.tech/mrq/torchscale",
],
url="https://git.ecker.tech/mrq/vall-e",
)

View File

@@ -157,16 +157,16 @@ class Dataset:
@dataclass()
class Model:
name: str = ""
version: int = 1
size: str | float | dict = "full"
resp_levels: int = 1
prom_levels: int = 8
tasks: int = 8 # ["tts", "ns", "sr", "tse", "cse", "nse"] and leaves two more for anything else I want (like "svc")
tasks: int = 0 # ["tts", "ns", "sr", "tse", "cse", "nse"] and leaves two more for anything else I want (like "svc")
langs: int = 0 # defined languages
arch_type: str = "retnet"
training: bool = True
interleave: bool = False
frozen_params: list[str] = field(default_factory=lambda: [])
p_ar_nar: float = 0.5
version: int = 1
@property
def full_name(self):
@@ -240,8 +240,8 @@ class Models:
_prom_levels: int = 1
_models: list[Model] = field(default_factory=lambda: [
Model(name="ar", resp_levels=1, prom_levels=8, tasks=8, training=True, interleave=False),
Model(name="nar", resp_levels=7, prom_levels=8, tasks=8, training=True, interleave=False),
Model(name="ar", resp_levels=1, prom_levels=8, tasks=8, langs=1, training=True, interleave=False),
Model(name="nar", resp_levels=7, prom_levels=8, tasks=8, langs=1, training=True, interleave=False),
])
def get(self, name=None):

View File

@@ -59,10 +59,10 @@ class Engine():
self.optimizer = kwargs['optimizer'] if 'optimizer' in kwargs else None
self.lr_scheduler = kwargs['lr_scheduler'] if 'lr_scheduler' in kwargs else None
self.global_steps = 0
self.micro_steps = 0
self.global_samples = 0
self.tokens_processed = 0
self.global_steps = kwargs.pop("global_steps", 0)
self.micro_steps = kwargs.pop("micro_steps", 0)
self.global_samples = kwargs.pop("global_samples", 0)
self.tokens_processed = kwargs.pop("tokens_processed", 0)
self._frozen_params = set()
@@ -117,10 +117,12 @@ class Engine():
"optimizer": self.optimizer.state_dict() if self.optimizer is not None else None,
"lr_scheduler": self.lr_scheduler.state_dict() if self.lr_scheduler is not None else None,
"global_step": self.global_step,
"micro_step": self.micro_step,
"global_samples": self.global_samples,
"tokens_processed": self.tokens_processed,
"stats": {
"global_step": self.global_step,
"micro_step": self.micro_step,
"global_samples": self.global_samples,
"tokens_processed": self.tokens_processed,
}
}, save_path)
open(save_dir / "latest", 'w').write( tag )
@@ -137,10 +139,10 @@ class Engine():
return
state = torch.load(load_path, map_location=torch.device(cfg.device))
self.global_steps = state['global_step']
self.micro_steps = state['micro_step']
self.global_samples = state['global_samples']
self.tokens_processed = state['tokens_processed']
self.global_steps = state['stats']['global_step'] if 'stats' in state else state['global_step']
self.micro_steps = state['stats']['micro_step'] if 'stats' in state else state['micro_step']
self.global_samples = state['stats']['global_samples'] if 'stats' in state else state['global_samples']
self.tokens_processed = state['stats']['tokens_processed'] if 'stats' in state else state['tokens_processed']
self.module.load_state_dict(state['module'])
load_optimizer_states = load_optimizer_states and self.optimizer is not None and 'optimizer' in state
@@ -261,12 +263,14 @@ class Engines(dict[str, Engine]):
outpath = cfg.ckpt_dir / name / "fp32.pth"
state_dict = {
'module': engine.module.state_dict(),
"global_step": engine.global_step,
"micro_step": engine.micro_step,
"global_samples": engine.global_samples,
"tokens_processed": engine.tokens_processed,
"stats": {
"global_step": engine.global_step,
"micro_step": engine.micro_step,
"global_samples": engine.global_samples,
"tokens_processed": engine.tokens_processed,
},
"userdata": userdata
}
state_dict.update(userdata)
torch.save(state_dict, outpath)
print(f"Exported {name} to {outpath}")

View File

@@ -39,10 +39,24 @@ class Engine(DeepSpeedEngine):
kwargs['config'] = cfg.trainer.deepspeed.ds_cfg
kwargs['config_class'] = DeepSpeedConfig(kwargs['config'])
if "stats" in kwargs:
# stats COULD be = None
stats = kwargs.pop('stats')
if stats is None:
stats = {
"global_steps": 0,
"micro_steps": 0,
"global_samples": 0,
"tokens_processed": 0,
}
super().__init__(None, *args, **kwargs)
self._frozen_params = set()
self.tokens_processed = 0
self.global_steps = stats["global_steps"]
self.micro_steps = stats["micro_steps"]
self.global_samples = stats["global_samples"]
self.tokens_processed = stats["tokens_processed"]
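With the counters now read from the optional "stats" kwarg (falling back to zeros when it is absent or None), a caller that restored them from a checkpoint can thread them straight into the constructor. A hypothetical construction, with the other DeepSpeed kwargs omitted:

# illustrative only: counters recovered from a previously exported fp32.pth
engine = Engine(
    model=model,
    optimizer=optimizer,
    _cfg=model._cfg,
    stats={"global_steps": 0, "micro_steps": 0, "global_samples": 0, "tokens_processed": 0},
)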
def freeze(self, freeze_all=True):
if self._cfg is None or not hasattr(self._cfg, "frozen_params"):

View File

@@ -34,7 +34,7 @@ class TTS():
if amp is None:
amp = cfg.inference.amp
if dtype is None:
dtype = cfg.inference.dtype
dtype = cfg.inference.weight_dtype
if device is None:
device = cfg.device
@@ -50,43 +50,41 @@ class TTS():
self.amp = amp
self.symmap = None
def parse( name, model, state ):
if "userdata" in state and 'symmap' in state['userdata']:
self.symmap = state['userdata']['symmap']
elif "symmap" in state:
self.symmap = state['symmap']
if "module" in state:
state = state['module']
model.load_state_dict(state)
return model
if ar_ckpt and nar_ckpt:
self.ar_ckpt = ar_ckpt
self.nar_ckpt = nar_ckpt
models = get_models(cfg.models.get())
for name, model in models.items():
if name.startswith("ar+nar"):
self.ar = model
if name.startswith("ar"):
state = torch.load(self.ar_ckpt)
if "symmap" in state:
self.symmap = state['symmap']
if "module" in state:
state = state['module']
self.ar.load_state_dict(state)
self.ar = self.ar.to(self.device, dtype=self.dtype if not self.amp else torch.float32)
self.nar = self.ar
elif name.startswith("ar"):
self.ar = model
state = torch.load(self.ar_ckpt)
if "symmap" in state:
self.symmap = state['symmap']
if "module" in state:
state = state['module']
self.ar.load_state_dict(state)
self.ar = self.ar.to(self.device, dtype=self.dtype if not self.amp else torch.float32)
self.ar = parse( name, model, state )
elif name.startswith("nar"):
self.nar = model
state = torch.load(self.nar_ckpt)
if "symmap" in state:
self.symmap = state['symmap']
if "module" in state:
state = state['module']
self.nar.load_state_dict(state)
self.nar = self.nar.to(self.device, dtype=self.dtype if not self.amp else torch.float32)
self.nar = parse( name, model, state )
if name.startswith("ar+nar"):
self.nar = self.ar
else:
self.load_models()
self.ar = self.ar.to(self.device, dtype=self.dtype if not self.amp else torch.float32)
self.nar = self.nar.to(self.device, dtype=self.dtype if not self.amp else torch.float32)
if self.symmap is None:
self.symmap = get_phone_symmap()
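The repeated per-checkpoint loading above collapses into the local parse helper, which recovers the symmap (from "userdata" or the top level), unwraps "module", and loads the weights; the device/dtype cast is then applied once for both models. A rough standalone sketch of the same steps (function name and defaults are illustrative):

import torch

def load_tts_model(model, ckpt_path, device="cuda", dtype=torch.float32, amp=False):
    state = torch.load(ckpt_path, map_location="cpu")
    # symmap may live under "userdata" (vall_e.export) or at the top level (older exports)
    userdata = state.get("userdata") or {}
    symmap = userdata.get("symmap", state.get("symmap"))
    weights = state["module"] if "module" in state else state
    model.load_state_dict(weights)
    model = model.to(device, dtype=dtype if not amp else torch.float32)
    return model, symmap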
@@ -98,13 +96,13 @@ class TTS():
def load_models( self ):
engines = load_engines()
for name, engine in engines.items():
if name[:6] == "ar+nar":
self.ar = engine.module.to(self.device, dtype=self.dtype if not self.amp else torch.float32)
if name.startswith("ar"):
self.ar = engine.module
elif name.startswith("nar"):
self.nar = engine.module
if name.startswith("ar+nar"):
self.nar = self.ar
elif name[:2] == "ar":
self.ar = engine.module.to(self.device, dtype=self.dtype if not self.amp else torch.float32)
elif name[:3] == "nar":
self.nar = engine.module.to(self.device, dtype=self.dtype if not self.amp else torch.float32)
def encode_text( self, text, language="en" ):
# already a tensor, return it

View File

@@ -41,12 +41,24 @@ class AR(Base):
def n_tasks(self) -> int:
return cfg.models.tasks
@property
def n_langs(self) -> int:
return cfg.models.langs
@property
def recurrent_chunk_size(self) -> int:
if cfg.mode == "training":
return 0
return cfg.inference.recurrent_chunk_size
"""
@property
def rotary_embedding_base(self) -> float:
if hasattr(self, "config") and self.config:
return self.config.rotary_embedding_base
return cfg.models.ar.rotary_embedding_base
"""
@property
def interleave(self) -> bool:
if hasattr(self, "config") and self.config:

View File

@@ -47,6 +47,14 @@ class AR_NAR(Base):
def recurrent_chunk_size(self) -> int:
return 0
"""
@property
def rotary_embedding_base(self) -> float:
if hasattr(self, "config") and self.config:
return self.config.rotary_embedding_base
return cfg.models.ar_nar.rotary_embedding_base
"""
@property
def interleave(self) -> bool:
return False
@@ -293,6 +301,10 @@ def example_usage():
optimizer = ml.AdamW(model.parameters(), lr=1.0e-4)
engine = Engine(model=model, optimizer=optimizer)
torch.save( {
'module': model.state_dict()
}, "./data/test.pth" )
print(f"AR+NAR parameter count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
@torch.inference_mode()

View File

@@ -235,11 +235,9 @@ class MultiEmbedding(nn.Module):
# Embedding that sums each RVQ-bin level within a given input acoustic prompt
class AudioEmbedding(nn.Module):
def __init__(self, n_levels, n_tokens, token_dim):
def __init__(self, l_tokens, token_dim):
super().__init__()
self.n_levels = n_levels
# would it be better to have embeddings[1:] reduced to 1024 tokens to attend to, so it's *not* factoring in the stop token?
self.embeddings = nn.ModuleList([nn.Embedding(n_tokens, token_dim) for _ in range(self.n_levels)])
self.embeddings = nn.ModuleList([nn.Embedding(n_tokens, token_dim) for n_tokens in l_tokens])
def forward(self, x_list: list[Tensor], quant_levels: Tensor | None = None ) -> list[Tensor]:
res_list = []
@@ -283,6 +281,10 @@ class Base(nn.Module):
def n_max_levels(self) -> int:
raise NotImplementedError
@property
def n_langs(self) -> int:
raise NotImplementedError
@property
def n_tasks(self) -> int:
raise NotImplementedError
@@ -290,6 +292,10 @@
@property
def recurrent_chunk_size(self) -> int:
raise NotImplementedError
@property
def rotary_embedding_base(self) -> float:
return 10000
@property
def interleave(self) -> bool:
@@ -341,17 +347,24 @@ class Base(nn.Module):
self.n_layers = n_layers
# +1 to include the stop token
n_prom_tokens = n_tokens + (self.n_tasks - 1) # - 1 because tts is an inherent task
# to-do: undo this dogshit mistake; tasks tokens should be delegated to its own embedding
n_prom_tokens = n_tokens
n_resp_tokens = n_tokens + (1 if self.causal else 0) # AR requires a stop token to... know when to stop
self.text_emb = Embedding(n_tokens, d_model)
if self.version == 1: # legacy
n_prom_tokens += (self.n_tasks - 1) # old models have the task tokens in the prom
self.proms_emb = MultiEmbedding(self.n_prom_levels, n_prom_tokens, d_model)
self.resps_emb = MultiEmbedding(self.n_resp_levels, n_resp_tokens, d_model, monolithic=self.monolithic)
else:
self.proms_emb = AudioEmbedding(self.n_prom_levels, n_prom_tokens, d_model)
self.resps_emb = AudioEmbedding(self.n_resp_levels, n_resp_tokens, d_model)
# [1024] * 8
self.proms_emb = AudioEmbedding([n_prom_tokens] * self.n_prom_levels, d_model)
# [1025] + [1024] * 8
self.resps_emb = AudioEmbedding([n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1), d_model)
# self.langs_emb = Embedding(self.n_langs, d_model)
# self.tasks_emb = Embedding(self.n_tasks, d_model)
self.sep = nn.Parameter(torch.randn(d_model))
@@ -365,7 +378,6 @@ class Base(nn.Module):
norm_type=self.norm_type,
n_levels=self.n_resp_levels,
) for _ in range(n_layers) ])
elif self.arch_type == "retnet":
self.retnet = RetNetDecoder(RetNetConfig(
vocab_size=n_tokens,
@@ -380,6 +392,8 @@ class Base(nn.Module):
recurrent_chunkwise_size=self.recurrent_chunk_size if self.causal else 0,
no_output_layer=True,
decoder_normalize_before=True,
rotary_embedding_base=self.rotary_embedding_base, # 10000
))
self.classifier = nn.Linear(d_model, n_resp_tokens)
@@ -407,12 +421,17 @@ class Base(nn.Module):
resps_list: list[Tensor],
targ_list: list[Tensor] | None = None,
#langs_list: list[Tensor] | None = None,
#tasks_list: list[Tensor] | None = None,
quant_levels: Tensor | None = None,
state: dict | None = None,
):
x_list = self._samplewise_merge_tensors(
self.text_emb(text_list),
#self.langs_emb(langs_list),
self.proms_emb(proms_list),
#self.tasks_emb(tasks_list),
self.resps_emb(resps_list, quant_levels),
sep=self.sep,
)
@@ -422,7 +441,7 @@ class Base(nn.Module):
batch_size = len(text_list)
device = x.device
if state is not None:
if state is not None and self.arch_type == "retnet":
# prefill
if len(state) == 0:
prefill_size = x.shape[1]
@@ -443,7 +462,6 @@ class Base(nn.Module):
# pass our inputs through the transformer
for block in self.blocks:
x = block(x, m, l)
elif self.arch_type == "retnet":
# pass our inputs through the RetNet
x, _ = self.retnet(x, incremental_state=state, token_embeddings=x, features_only=True)
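The AudioEmbedding rework above drops the explicit n_levels argument and instead builds one nn.Embedding per RVQ level from a list of per-level vocabulary sizes, so only the level that actually needs the AR stop token pays for it. A minimal self-contained sketch of that construction (the sizes follow the comments in the diff; d_model is illustrative):

from torch import nn

class AudioEmbedding(nn.Module):
    def __init__(self, l_tokens, token_dim):
        super().__init__()
        # one embedding per RVQ level, each with its own vocabulary size
        self.embeddings = nn.ModuleList([nn.Embedding(n_tokens, token_dim) for n_tokens in l_tokens])

n_tokens, n_resp_levels, d_model = 1024, 8, 1024
proms_emb = AudioEmbedding([n_tokens] * 8, d_model)                                      # [1024] * 8
resps_emb = AudioEmbedding([n_tokens + 1] + [n_tokens] * (n_resp_levels - 1), d_model)   # [1025] + [1024] * 7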

View File

@@ -39,6 +39,10 @@ class NAR(Base):
def n_tasks(self) -> int:
return cfg.models.tasks
@property
def n_langs(self) -> int:
return cfg.models.langs
@property
def version(self) -> int:
if hasattr(self, "config") and self.config:
@@ -49,6 +53,14 @@ class NAR(Base):
def recurrent_chunk_size(self) -> int:
return 0
"""
@property
def rotary_embedding_base(self) -> float:
if hasattr(self, "config") and self.config:
return self.config.rotary_embedding_base
return cfg.models.nar.rotary_embedding_base
"""
@property
def interleave(self) -> bool:
return False

View File

@@ -1,71 +1,3 @@
"""
# https://github.com/facebookresearch/fairseq/blob/main/fairseq/incremental_decoding_utils.py
# Copied directly because even having fairseq installed WILL break logging, why are corposhitters like this
"""
import uuid
from typing import Dict, Optional
from torch import Tensor
from torchscale.architecture.config import RetNetConfig
from torchscale.architecture.retnet import RetNetDecoder
"""
class FairseqIncrementalState(object):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.init_incremental_state()
def init_incremental_state(self):
self._incremental_state_id = str(uuid.uuid4())
def _get_full_incremental_state_key(self, key: str) -> str:
return "{}.{}".format(self._incremental_state_id, key)
def get_incremental_state(
self,
incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]],
key: str,
) -> Optional[Dict[str, Optional[Tensor]]]:
full_key = self._get_full_incremental_state_key(key)
if incremental_state is None or full_key not in incremental_state:
return None
return incremental_state[full_key]
def set_incremental_state(
self,
incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]],
key: str,
value: Dict[str, Optional[Tensor]],
) -> Optional[Dict[str, Dict[str, Optional[Tensor]]]]:
if incremental_state is not None:
full_key = self._get_full_incremental_state_key(key)
incremental_state[full_key] = value
return incremental_state
def with_incremental_state(cls):
cls.__bases__ = (FairseqIncrementalState,) + tuple(
b for b in cls.__bases__ if b != FairseqIncrementalState
)
return cls
from torchscale.architecture.config import RetNetConfig
from torchscale.architecture.retnet import RetNetDecoder as Decoder
@with_incremental_state
class RetNetDecoder(Decoder):
def forward(self, src_tokens, **kwargs):
return super().forward(src_tokens, **kwargs)
def max_positions(self):
return self.args.max_token_positions
def reorder_incremental_state( self, incremental_state, new_order ):
for module in incremental_state:
for key in incremental_state[module]:
result = incremental_state[module][key].index_select(0, new_order)
incremental_state[module][key] = result
"""
# from retnet import RetNet

View File

@@ -64,84 +64,57 @@ def load_engines(invert=False):
lr_scheduler = None
if cfg.trainer.backend == "local" or (cfg.trainer.backend == "deepspeed" and cfg.hyperparameters.torch_optimizer):
optimizer_class = None
params = {
"lr": cfg.hyperparameters.learning_rate,
}
if cfg.hyperparameters.optimizer.lower() == "adamw":
params = {
"lr": cfg.hyperparameters.learning_rate,
"betas": (0.9, 0.96),
"eps": 1e-07,
"weight_decay": 0.01,
}
params.update(cfg.hyperparameters.optimizer_params)
optimizer = ml.AdamW(
[ param for name, param in model.named_parameters() if name not in model._cfg.frozen_params ],
**params,
)
params["betas"] = (0.9, 0.96)
params["eps"] = 1e-07
params["weight_decay"] = 0.01
optimizer_class = ml.AdamW
elif cfg.hyperparameters.optimizer.lower() == "sgd":
params = {
"lr": cfg.hyperparameters.learning_rate,
}
params.update(cfg.hyperparameters.optimizer_params)
optimizer = ml.SGD(
[ param for name, param in model.named_parameters() if name not in model._cfg.frozen_params ],
**params,
)
optimizer_class = ml.SGD
elif cfg.hyperparameters.optimizer.lower() == "prodigy":
params = {
"lr": cfg.hyperparameters.learning_rate,
}
params.update(cfg.hyperparameters.optimizer_params)
optimizer = ml.Prodigy(
[ param for name, param in model.named_parameters() if name not in model._cfg.frozen_params ],
**params,
)
optimizer_class = ml.Prodigy
else:
raise ValueError(f'Optimizer specified not implemented: {cfg.hyperparameters.optimizer}')
params.update(cfg.hyperparameters.optimizer_params)
optimizer = optimizer_class(
[ param for name, param in model.named_parameters() if name not in model._cfg.frozen_params ],
**params,
)
# set up our LR scheduler here
if not model._cfg.training:
optimizer = None
lr_scheduler = None
stats = None
if cfg.trainer.load_state_dict or not model._cfg.training:
load_path = cfg.ckpt_dir / name / "fp32.pth"
state = torch.load(load_path, map_location=torch.device(cfg.device))
# exporting the model from the zero_to_fp32.py exports the actual module's dict
# exporting with vall_e.export exports the state dict under .module
# state dict is not just the module, extract the extra trainer details
if "stats" in state:
stats = state["stats"]
if "module" in state:
state = state["module"]
# should decouple the following from this trainer script
# probably with passing a fun that defaults to a lambda x: x deal
"""
# can probably be done a lot more intelligently but oh well
# extend the proms_emb if we ever touch the n_prom_levels or n_prom_tokens (from adding tasks)
if model.proms_emb.weight.shape[0] > state['proms_emb.weight'].shape[0] or model.proms_emb.weight.shape[1] > state['proms_emb.weight'].shape[1]:
o_prom_levels, o_prom_tokens, d_model = state['proms_emb.weight'].shape
# copy weights from the dict into the old portion
model.proms_emb.weight.data[:o_prom_levels, :o_prom_tokens, :] = state['proms_emb.weight'].data[:o_prom_levels, :o_prom_tokens, :]
# copy the full tensors back
state['proms_emb.weight'] = model.proms_emb.weight
# extend the resps_emb if we ever touch the n_prom_levels or n_prom_tokens (from adding tasks)
if model.resps_emb.weight.shape[0] > state['resps_emb.weight'].shape[0] or model.resps_emb.weight.shape[1] > state['resps_emb.weight'].shape[1]:
o_resp_levels, o_resp_tokens, d_model = state['resps_emb.weight'].shape
n_resp_levels, n_resp_tokens, d_model = model.resps_emb.weight.shape
# copy weights from the dict into the old portion
model.resps_emb.weight.data[:o_resp_levels, :o_resp_tokens, :] = state['resps_emb.weight'].data[:o_resp_levels, :o_resp_tokens, :]
# copy the full tensors back
state['resps_emb.weight'] = model.resps_emb.weight
"""
model.load_state_dict(state, strict=cfg.trainer.strict_loading)
# use base engine because DeepSpeed memory leaks
# use base engine because DeepSpeed memory leaks if it's a non-training model
engines[name] = (Engine if model._cfg.training else _Engine)(
#engines[name] = Engine(
model=model,
optimizer=optimizer,
lr_scheduler=lr_scheduler,
_cfg=model._cfg,
stats=stats
)
engines = Engines(engines)
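The optimizer setup in load_engines now selects an optimizer class and builds one shared params dict instead of duplicating the construction in every branch, with cfg.hyperparameters.optimizer_params applied last so config overrides win. A condensed sketch of the pattern (the ml.* wrappers and cfg fields are the repo's; everything else is illustrative):

params = {"lr": cfg.hyperparameters.learning_rate}
if cfg.hyperparameters.optimizer.lower() == "adamw":
    params.update({"betas": (0.9, 0.96), "eps": 1e-07, "weight_decay": 0.01})
    optimizer_class = ml.AdamW
elif cfg.hyperparameters.optimizer.lower() == "sgd":
    optimizer_class = ml.SGD
elif cfg.hyperparameters.optimizer.lower() == "prodigy":
    optimizer_class = ml.Prodigy
else:
    raise ValueError(f'Optimizer specified not implemented: {cfg.hyperparameters.optimizer}')

params.update(cfg.hyperparameters.optimizer_params)  # per-config overrides applied last
optimizer = optimizer_class(
    [param for name, param in model.named_parameters() if name not in model._cfg.frozen_params],
    **params,
)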