restructured some things with the model to remove dead weights
parent a6bfe43590
commit c0b25541e3

setup.py
@@ -57,7 +57,7 @@ setup(
         "auraloss[all]",
         "vocos",
         "h5py",
-        "torchscale @ git+https://github.com/microsoft/torchscale",
+        "torchscale @ git+https://git.ecker.tech/mrq/torchscale",
     ],
     url="https://git.ecker.tech/mrq/vall-e",
 )
@@ -157,16 +157,16 @@ class Dataset:
 @dataclass()
 class Model:
     name: str = ""
+    version: int = 1
     size: str | float | dict = "full"
     resp_levels: int = 1
     prom_levels: int = 8
-    tasks: int = 8 # ["tts", "ns", "sr", "tse", "cse", "nse"] and leaves two more for anything else I want (like "svc")
+    tasks: int = 0 # ["tts", "ns", "sr", "tse", "cse", "nse"] and leaves two more for anything else I want (like "svc")
+    langs: int = 0 # defined languages
     arch_type: str = "retnet"
     training: bool = True
     interleave: bool = False
     frozen_params: list[str] = field(default_factory=lambda: [])
-    p_ar_nar: float = 0.5
-    version: int = 1

     @property
     def full_name(self):
@@ -240,8 +240,8 @@ class Models:
     _prom_levels: int = 1

     _models: list[Model] = field(default_factory=lambda: [
-        Model(name="ar", resp_levels=1, prom_levels=8, tasks=8, training=True, interleave=False),
-        Model(name="nar", resp_levels=7, prom_levels=8, tasks=8, training=True, interleave=False),
+        Model(name="ar", resp_levels=1, prom_levels=8, tasks=8, langs=1, training=True, interleave=False),
+        Model(name="nar", resp_levels=7, prom_levels=8, tasks=8, langs=1, training=True, interleave=False),
     ])

     def get(self, name=None):
@@ -59,10 +59,10 @@ class Engine():
         self.optimizer = kwargs['optimizer'] if 'optimizer' in kwargs else None
         self.lr_scheduler = kwargs['lr_scheduler'] if 'lr_scheduler' in kwargs else None

-        self.global_steps = 0
-        self.micro_steps = 0
-        self.global_samples = 0
-        self.tokens_processed = 0
+        self.global_steps = kwargs.pop("global_steps", 0)
+        self.micro_steps = kwargs.pop("micro_steps", 0)
+        self.global_samples = kwargs.pop("global_samples", 0)
+        self.tokens_processed = kwargs.pop("tokens_processed", 0)

         self._frozen_params = set()

@@ -117,10 +117,12 @@ class Engine():
             "optimizer": self.optimizer.state_dict() if self.optimizer is not None else None,
             "lr_scheduler": self.lr_scheduler.state_dict() if self.lr_scheduler is not None else None,

-            "global_step": self.global_step,
-            "micro_step": self.micro_step,
-            "global_samples": self.global_samples,
-            "tokens_processed": self.tokens_processed,
+            "stats": {
+                "global_step": self.global_step,
+                "micro_step": self.micro_step,
+                "global_samples": self.global_samples,
+                "tokens_processed": self.tokens_processed,
+            }
         }, save_path)

         open(save_dir / "latest", 'w').write( tag )
@@ -137,10 +139,10 @@ class Engine():
             return

         state = torch.load(load_path, map_location=torch.device(cfg.device))
-        self.global_steps = state['global_step']
-        self.micro_steps = state['micro_step']
-        self.global_samples = state['global_samples']
-        self.tokens_processed = state['tokens_processed']
+        self.global_steps = state['stats']['global_step'] if 'stats' in state else state['global_step']
+        self.micro_steps = state['stats']['micro_step'] if 'stats' in state else state['micro_step']
+        self.global_samples = state['stats']['global_samples'] if 'stats' in state else state['global_samples']
+        self.tokens_processed = state['stats']['tokens_processed'] if 'stats' in state else state['tokens_processed']
         self.module.load_state_dict(state['module'])

         load_optimizer_states = load_optimizer_states and self.optimizer is not None and 'optimizer' in state
@@ -261,12 +263,14 @@ class Engines(dict[str, Engine]):
             outpath = cfg.ckpt_dir / name / "fp32.pth"
             state_dict = {
                 'module': engine.module.state_dict(),
-                "global_step": engine.global_step,
-                "micro_step": engine.micro_step,
-                "global_samples": engine.global_samples,
-                "tokens_processed": engine.tokens_processed,
+                "stats": {
+                    "global_step": engine.global_step,
+                    "micro_step": engine.micro_step,
+                    "global_samples": engine.global_samples,
+                    "tokens_processed": engine.tokens_processed,
+                },
+                "userdata": userdata
             }
-            state_dict.update(userdata)
             torch.save(state_dict, outpath)
             print(f"Exported {name} to {outpath}")

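The three hunks above settle on one checkpoint layout: trainer counters now live under a nested "stats" dict instead of at the top level of the saved state, and the load path falls back to the old flat keys so older checkpoints still restore. A minimal sketch of reading either layout, assuming only torch; the helper name read_trainer_stats is illustrative and not part of the repo:

import torch

def read_trainer_stats(path, device="cpu"):
    state = torch.load(path, map_location=torch.device(device))
    # new layout nests the counters under "stats"; old checkpoints keep them flat
    stats = state["stats"] if "stats" in state else state
    return {
        "global_step": stats["global_step"],
        "micro_step": stats["micro_step"],
        "global_samples": stats["global_samples"],
        "tokens_processed": stats["tokens_processed"],
    }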
@@ -39,10 +39,24 @@ class Engine(DeepSpeedEngine):
         kwargs['config'] = cfg.trainer.deepspeed.ds_cfg
         kwargs['config_class'] = DeepSpeedConfig(kwargs['config'])

+        if "stats" in kwargs:
+            # stats COULD be = None
+            stats = kwargs.pop('stats')
+            if stats is None:
+                stats = {
+                    "global_steps": 0,
+                    "micro_steps": 0,
+                    "global_samples": 0,
+                    "tokens_processed": 0,
+                }
+
         super().__init__(None, *args, **kwargs)
         self._frozen_params = set()

-        self.tokens_processed = 0
+        self.global_steps = stats["global_steps"]
+        self.micro_steps = stats["micro_steps"]
+        self.global_samples = stats["global_samples"]
+        self.tokens_processed = stats["tokens_processed"]

     def freeze(self, freeze_all=True):
         if self._cfg is None or not hasattr(self._cfg, "frozen_params"):
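The DeepSpeed wrapper now accepts a stats kwarg (possibly None) and pops it before DeepSpeedEngine's constructor runs, then seeds the step counters from it. A standalone sketch of that handshake, under the assumption that the caller always passes the key, as load_engines() does further down; pop_stats is an illustrative name:

def pop_stats(kwargs: dict) -> dict:
    # remove "stats" before forwarding kwargs, defaulting every counter to zero
    stats = kwargs.pop("stats", None)
    if stats is None:
        stats = { "global_steps": 0, "micro_steps": 0, "global_samples": 0, "tokens_processed": 0 }
    return stats

kwargs = { "stats": None }
stats = pop_stats(kwargs)
assert stats["global_steps"] == 0 and "stats" not in kwargs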
@@ -34,7 +34,7 @@ class TTS():
         if amp is None:
             amp = cfg.inference.amp
         if dtype is None:
-            dtype = cfg.inference.dtype
+            dtype = cfg.inference.weight_dtype
         if device is None:
             device = cfg.device

@@ -50,43 +50,41 @@ class TTS():
         self.amp = amp

         self.symmap = None

+        def parse( name, model, state ):
+            if "userdata" in state and 'symmap' in state['userdata']:
+                self.symmap = state['userdata']['symmap']
+            elif "symmap" in state:
+                self.symmap = state['symmap']
+
+            if "module" in state:
+                state = state['module']
+
+            model.load_state_dict(state)
+            return model
+
         if ar_ckpt and nar_ckpt:
             self.ar_ckpt = ar_ckpt
             self.nar_ckpt = nar_ckpt

             models = get_models(cfg.models.get())

             for name, model in models.items():
-                if name.startswith("ar+nar"):
-                    self.ar = model
+                if name.startswith("ar"):
                     state = torch.load(self.ar_ckpt)
-                    if "symmap" in state:
-                        self.symmap = state['symmap']
-                    if "module" in state:
-                        state = state['module']
-                    self.ar.load_state_dict(state)
-                    self.ar = self.ar.to(self.device, dtype=self.dtype if not self.amp else torch.float32)
-                    self.nar = self.ar
-                elif name.startswith("ar"):
-                    self.ar = model
-                    state = torch.load(self.ar_ckpt)
-                    if "symmap" in state:
-                        self.symmap = state['symmap']
-                    if "module" in state:
-                        state = state['module']
-                    self.ar.load_state_dict(state)
-                    self.ar = self.ar.to(self.device, dtype=self.dtype if not self.amp else torch.float32)
+                    self.ar = parse( name, model, state )
                 elif name.startswith("nar"):
-                    self.nar = model
                     state = torch.load(self.nar_ckpt)
-                    if "symmap" in state:
-                        self.symmap = state['symmap']
-                    if "module" in state:
-                        state = state['module']
-                    self.nar.load_state_dict(state)
-                    self.nar = self.nar.to(self.device, dtype=self.dtype if not self.amp else torch.float32)
+                    self.nar = parse( name, model, state )
+
+                if name.startswith("ar+nar"):
+                    self.nar = self.ar
         else:
             self.load_models()

+        self.ar = self.ar.to(self.device, dtype=self.dtype if not self.amp else torch.float32)
+        self.nar = self.nar.to(self.device, dtype=self.dtype if not self.amp else torch.float32)
+
         if self.symmap is None:
             self.symmap = get_phone_symmap()

@@ -98,13 +96,13 @@ class TTS():
     def load_models( self ):
         engines = load_engines()
         for name, engine in engines.items():
-            if name[:6] == "ar+nar":
-                self.ar = engine.module.to(self.device, dtype=self.dtype if not self.amp else torch.float32)
+            if name.startswith("ar"):
+                self.ar = engine.module
+            elif name.startswith("nar"):
+                self.nar = engine.module
+
+            if name.startswith("ar+nar"):
                 self.nar = self.ar
-            elif name[:2] == "ar":
-                self.ar = engine.module.to(self.device, dtype=self.dtype if not self.amp else torch.float32)
-            elif name[:3] == "nar":
-                self.nar = engine.module.to(self.device, dtype=self.dtype if not self.amp else torch.float32)

     def encode_text( self, text, language="en" ):
         # already a tensor, return it
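Both TTS loaders now funnel through one pattern: read a checkpoint, pull the phoneme symmap from userdata (new exports) or the top level (old exports), unwrap the module dict if present, then load the weights; the ar+nar case simply aliases self.nar to self.ar afterwards. A self-contained sketch of that loader, with load_model_checkpoint as an illustrative name rather than an actual function in the repo:

import torch

def load_model_checkpoint(model: torch.nn.Module, path: str):
    state = torch.load(path, map_location="cpu")
    symmap = None
    # new exports keep the symmap under "userdata", older ones at the top level
    if "userdata" in state and "symmap" in state["userdata"]:
        symmap = state["userdata"]["symmap"]
    elif "symmap" in state:
        symmap = state["symmap"]
    # exports wrap the weights under "module"; raw fp32 dumps are the module dict itself
    if "module" in state:
        state = state["module"]
    model.load_state_dict(state)
    return model, symmap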
@@ -41,12 +41,24 @@ class AR(Base):
     def n_tasks(self) -> int:
         return cfg.models.tasks

+    @property
+    def n_langs(self) -> int:
+        return cfg.models.langs
+
     @property
     def recurrent_chunk_size(self) -> int:
         if cfg.mode == "training":
             return 0
         return cfg.inference.recurrent_chunk_size

+    """
+    @property
+    def rotary_embedding_base(self) -> float:
+        if hasattr(self, "config") and self.config:
+            return self.config.rotary_embedding_base
+        return cfg.models.ar.rotary_embedding_base
+    """
+
     @property
     def interleave(self) -> bool:
         if hasattr(self, "config") and self.config:
@@ -47,6 +47,14 @@ class AR_NAR(Base):
     def recurrent_chunk_size(self) -> int:
         return 0

+    """
+    @property
+    def rotary_embedding_base(self) -> float:
+        if hasattr(self, "config") and self.config:
+            return self.config.rotary_embedding_base
+        return cfg.models.ar_nar.rotary_embedding_base
+    """
+
     @property
     def interleave(self) -> bool:
         return False
@@ -293,6 +301,10 @@ def example_usage():
     optimizer = ml.AdamW(model.parameters(), lr=1.0e-4)
     engine = Engine(model=model, optimizer=optimizer)

+    torch.save( {
+        'module': model.state_dict()
+    }, "./data/test.pth" )
+
     print(f"AR+NAR parameter count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")

     @torch.inference_mode()
@@ -235,11 +235,9 @@ class MultiEmbedding(nn.Module):

 # Embedding that sums each RVQ-bin level within a given input acoustic prompt
 class AudioEmbedding(nn.Module):
-    def __init__(self, n_levels, n_tokens, token_dim):
+    def __init__(self, l_tokens, token_dim):
         super().__init__()
-        self.n_levels = n_levels
-        # would it be better to have embeddings[1:] reduced to 1024 tokens to attend to, so it's *not* factoring in the stop token?
-        self.embeddings = nn.ModuleList([nn.Embedding(n_tokens, token_dim) for _ in range(self.n_levels)])
+        self.embeddings = nn.ModuleList([nn.Embedding(n_tokens, token_dim) for n_tokens in l_tokens])

     def forward(self, x_list: list[Tensor], quant_levels: Tensor | None = None ) -> list[Tensor]:
         res_list = []
@@ -283,6 +281,10 @@ class Base(nn.Module):
     def n_max_levels(self) -> int:
         raise NotImplementedError

+    @property
+    def n_langs(self) -> int:
+        raise NotImplementedError
+
     @property
     def n_tasks(self) -> int:
         raise NotImplementedError
@@ -290,6 +292,10 @@ class Base(nn.Module):
     @property
     def recurrent_chunk_size(self) -> int:
         raise NotImplementedError

+    @property
+    def rotary_embedding_base(self) -> float:
+        return 10000
+
     @property
     def interleave(self) -> bool:
@@ -341,17 +347,24 @@ class Base(nn.Module):
         self.n_layers = n_layers

         # +1 to include the stop token
-        n_prom_tokens = n_tokens + (self.n_tasks - 1) # - 1 because tts is an inherent task
+        # to-do: undo this dogshit mistake; tasks tokens should be delegated to its own embedding
+        n_prom_tokens = n_tokens
         n_resp_tokens = n_tokens + (1 if self.causal else 0) # AR requires a stop token to... know when to stop

         self.text_emb = Embedding(n_tokens, d_model)

         if self.version == 1: # legacy
+            n_prom_tokens += (self.n_tasks - 1) # old models have the task tokens in the prom
             self.proms_emb = MultiEmbedding(self.n_prom_levels, n_prom_tokens, d_model)
             self.resps_emb = MultiEmbedding(self.n_resp_levels, n_resp_tokens, d_model, monolithic=self.monolithic)
         else:
-            self.proms_emb = AudioEmbedding(self.n_prom_levels, n_prom_tokens, d_model)
-            self.resps_emb = AudioEmbedding(self.n_resp_levels, n_resp_tokens, d_model)
+            # [1024] * 8
+            self.proms_emb = AudioEmbedding([n_prom_tokens] * self.n_prom_levels, d_model)
+            # [1025] + [1024] * 8
+            self.resps_emb = AudioEmbedding([n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1), d_model)
+
+        # self.langs_emb = Embedding(self.n_langs, d_model)
+        # self.tasks_emb = Embedding(self.n_tasks, d_model)

         self.sep = nn.Parameter(torch.randn(d_model))

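The embedding change shows up in two places: AudioEmbedding now takes a list of per-level vocab sizes rather than one size repeated n_levels times, and Base.__init__ uses that to give only the first response level the extra stop token. A compact PyTorch sketch of the construction, assuming a hypothetical model width of 1024 and 8 RVQ levels for illustration:

import torch
from torch import nn

class AudioEmbedding(nn.Module):
    def __init__(self, l_tokens: list[int], token_dim: int):
        super().__init__()
        # one embedding table per RVQ level, each with its own vocab size
        self.embeddings = nn.ModuleList([nn.Embedding(n_tokens, token_dim) for n_tokens in l_tokens])

d_model = 1024        # hypothetical model width
n_prom_levels = 8
n_resp_levels = 8
n_prom_tokens = 1024
n_resp_tokens = 1025  # 1024 codes + 1 stop token for the causal (AR) level

proms_emb = AudioEmbedding([n_prom_tokens] * n_prom_levels, d_model)                              # [1024] * 8
resps_emb = AudioEmbedding([n_resp_tokens] + [n_resp_tokens - 1] * (n_resp_levels - 1), d_model)  # [1025] + [1024] * 7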
@@ -365,7 +378,6 @@ class Base(nn.Module):
                 norm_type=self.norm_type,
                 n_levels=self.n_resp_levels,
             ) for _ in range(n_layers) ])
-
         elif self.arch_type == "retnet":
             self.retnet = RetNetDecoder(RetNetConfig(
                 vocab_size=n_tokens,
@@ -380,6 +392,8 @@ class Base(nn.Module):
                 recurrent_chunkwise_size=self.recurrent_chunk_size if self.causal else 0,
                 no_output_layer=True,
                 decoder_normalize_before=True,
+
+                rotary_embedding_base=self.rotary_embedding_base, # 10000
             ))

         self.classifier = nn.Linear(d_model, n_resp_tokens)
@@ -407,12 +421,17 @@ class Base(nn.Module):
         resps_list: list[Tensor],
         targ_list: list[Tensor] | None = None,

+        #langs_list: list[Tensor] | None = None,
+        #tasks_list: list[Tensor] | None = None,
+
         quant_levels: Tensor | None = None,
         state: dict | None = None,
     ):
         x_list = self._samplewise_merge_tensors(
             self.text_emb(text_list),
+            #self.langs_emb(langs_list),
             self.proms_emb(proms_list),
+            #self.tasks_emb(tasks_list),
             self.resps_emb(resps_list, quant_levels),
             sep=self.sep,
         )
@@ -422,7 +441,7 @@ class Base(nn.Module):
         batch_size = len(text_list)
         device = x.device

-        if state is not None:
+        if state is not None and self.arch_type == "retnet":
             # prefill
             if len(state) == 0:
                 prefill_size = x.shape[1]
@@ -443,7 +462,6 @@ class Base(nn.Module):
             # pass our inputs through the transformer
             for block in self.blocks:
                 x = block(x, m, l)
-
         elif self.arch_type == "retnet":
             # pass our inputs through the RetNet
             x, _ = self.retnet(x, incremental_state=state, token_embeddings=x, features_only=True)
@@ -39,6 +39,10 @@ class NAR(Base):
     def n_tasks(self) -> int:
         return cfg.models.tasks

+    @property
+    def n_langs(self) -> int:
+        return cfg.models.langs
+
     @property
     def version(self) -> int:
         if hasattr(self, "config") and self.config:
@@ -49,6 +53,14 @@ class NAR(Base):
     def recurrent_chunk_size(self) -> int:
         return 0

+    """
+    @property
+    def rotary_embedding_base(self) -> float:
+        if hasattr(self, "config") and self.config:
+            return self.config.rotary_embedding_base
+        return cfg.models.nar.rotary_embedding_base
+    """
+
     @property
     def interleave(self) -> bool:
         return False
@@ -1,71 +1,3 @@
-"""
-# https://github.com/facebookresearch/fairseq/blob/main/fairseq/incremental_decoding_utils.py
-# Copied directly because even having fairseq installed WILL break logging, why are corposhitters like this
-"""
-
-import uuid
-from typing import Dict, Optional
-
-from torch import Tensor
-
 from torchscale.architecture.config import RetNetConfig
 from torchscale.architecture.retnet import RetNetDecoder
-"""
-class FairseqIncrementalState(object):
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        self.init_incremental_state()
-
-    def init_incremental_state(self):
-        self._incremental_state_id = str(uuid.uuid4())
-
-    def _get_full_incremental_state_key(self, key: str) -> str:
-        return "{}.{}".format(self._incremental_state_id, key)
-
-    def get_incremental_state(
-        self,
-        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]],
-        key: str,
-    ) -> Optional[Dict[str, Optional[Tensor]]]:
-        full_key = self._get_full_incremental_state_key(key)
-        if incremental_state is None or full_key not in incremental_state:
-            return None
-        return incremental_state[full_key]
-
-    def set_incremental_state(
-        self,
-        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]],
-        key: str,
-        value: Dict[str, Optional[Tensor]],
-    ) -> Optional[Dict[str, Dict[str, Optional[Tensor]]]]:
-        if incremental_state is not None:
-            full_key = self._get_full_incremental_state_key(key)
-            incremental_state[full_key] = value
-        return incremental_state
-
-
-def with_incremental_state(cls):
-    cls.__bases__ = (FairseqIncrementalState,) + tuple(
-        b for b in cls.__bases__ if b != FairseqIncrementalState
-    )
-    return cls
-
-
-from torchscale.architecture.config import RetNetConfig
-from torchscale.architecture.retnet import RetNetDecoder as Decoder
-
-@with_incremental_state
-class RetNetDecoder(Decoder):
-    def forward(self, src_tokens, **kwargs):
-        return super().forward(src_tokens, **kwargs)
-
-    def max_positions(self):
-        return self.args.max_token_positions
-
-    def reorder_incremental_state( self, incremental_state, new_order ):
-        for module in incremental_state:
-            for key in incremental_state[module]:
-                result = incremental_state[module][key].index_select(0, new_order)
-                incremental_state[module][key] = result
-"""
+# from retnet import RetNet
@@ -64,84 +64,57 @@ def load_engines(invert=False):
         lr_scheduler = None

         if cfg.trainer.backend == "local" or (cfg.trainer.backend == "deepspeed" and cfg.hyperparameters.torch_optimizer):
+            optimizer_class = None
+            params = {
+                "lr": cfg.hyperparameters.learning_rate,
+            }
             if cfg.hyperparameters.optimizer.lower() == "adamw":
-                params = {
-                    "lr": cfg.hyperparameters.learning_rate,
-                    "betas": (0.9, 0.96),
-                    "eps": 1e-07,
-                    "weight_decay": 0.01,
-                }
-                params.update(cfg.hyperparameters.optimizer_params)
-                optimizer = ml.AdamW(
-                    [ param for name, param in model.named_parameters() if name not in model._cfg.frozen_params ],
-                    **params,
-                )
+                params["betas"] = (0.9, 0.96)
+                params["eps"] = 1e-07
+                params["weight_decay"] = 0.01
+
+                optimizer_class = ml.AdamW
             elif cfg.hyperparameters.optimizer.lower() == "sgd":
-                params = {
-                    "lr": cfg.hyperparameters.learning_rate,
-                }
-                params.update(cfg.hyperparameters.optimizer_params)
-                optimizer = ml.SGD(
-                    [ param for name, param in model.named_parameters() if name not in model._cfg.frozen_params ],
-                    **params,
-                )
+                optimizer = ml.SGD
             elif cfg.hyperparameters.optimizer.lower() == "prodigy":
-                params = {
-                    "lr": cfg.hyperparameters.learning_rate,
-                }
-                params.update(cfg.hyperparameters.optimizer_params)
-                optimizer = ml.Prodigy(
-                    [ param for name, param in model.named_parameters() if name not in model._cfg.frozen_params ],
-                    **params,
-                )
+                optimizer_class = ml.Prodigy
+            else:
+                raise ValueError(f'Optimizer specified not implemented: {cfg.hyperparameters.optimizer}')
+
+            params.update(cfg.hyperparameters.optimizer_params)
+            optimizer = optimizer_class(
+                [ param for name, param in model.named_parameters() if name not in model._cfg.frozen_params ],
+                **params,
+            )
+
+        # set up our LR scheduler here

         if not model._cfg.training:
             optimizer = None
             lr_scheduler = None

+        stats = None
         if cfg.trainer.load_state_dict or not model._cfg.training:
             load_path = cfg.ckpt_dir / name / "fp32.pth"
             state = torch.load(load_path, map_location=torch.device(cfg.device))
-            # exporting the model from the zero_to_fp32.py exports the actual module's dict
-            # exporting with vall_e.export exports the state dict under .module
+            # state dict is not just the module, extract the extra trainer details
+            if "stats" in state:
+                additionals = state["stats"]
+
             if "module" in state:
                 state = state["module"]

-            # should decouple the following from this trainer script
-            # probably with passing a fun that defaults to a lambda x: x deal
-
-            """
-            # can probably be done a lot more intelligently but oh well
-            # extend the proms_emb if we ever touch the n_prom_levels or n_prom_tokens (from adding tasks)
-            if model.proms_emb.weight.shape[0] > state['proms_emb.weight'].shape[0] or model.proms_emb.weight.shape[1] > state['proms_emb.weight'].shape[1]:
-                o_prom_levels, o_prom_tokens, d_model = state['proms_emb.weight'].shape
-
-                # copy weights from the dict into the old portion
-                model.proms_emb.weight.data[:o_prom_levels, :o_prom_tokens, :] = state['proms_emb.weight'].data[:o_prom_levels, :o_prom_tokens, :]
-                # copy the full tensors back
-                state['proms_emb.weight'] = model.proms_emb.weight
-
-            # extend the resps_emb if we ever touch the n_prom_levels or n_prom_tokens (from adding tasks)
-            if model.resps_emb.weight.shape[0] > state['resps_emb.weight'].shape[0] or model.resps_emb.weight.shape[1] > state['resps_emb.weight'].shape[1]:
-                o_resp_levels, o_resp_tokens, d_model = state['resps_emb.weight'].shape
-                n_resp_levels, n_resp_tokens, d_model = model.resps_emb.weight.shape
-
-                # copy weights from the dict into the old portion
-                model.resps_emb.weight.data[:o_resp_levels, :o_resp_tokens, :] = state['resps_emb.weight'].data[:o_resp_levels, :o_resp_tokens, :]
-                # copy the full tensors back
-                state['resps_emb.weight'] = model.resps_emb.weight
-            """
-
             model.load_state_dict(state, strict=cfg.trainer.strict_loading)

-        # use base engine because DeepSpeed memory leaks
+        # use base engine because DeepSpeed memory leaks if it's a non-training model
         engines[name] = (Engine if model._cfg.training else _Engine)(
-        #engines[name] = Engine(
             model=model,
             optimizer=optimizer,
             lr_scheduler=lr_scheduler,

             _cfg=model._cfg,
+            stats=stats
        )

     engines = Engines(engines)
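The load_engines() refactor collapses three near-identical optimizer branches into one: shared kwargs are built first, each branch only picks the optimizer class (plus AdamW-specific settings), unknown names now raise, and a single construction call filters out frozen parameters. A self-contained sketch of the same pattern against torch.optim; the repo's ml.* wrappers and Prodigy are swapped out here, and since the SGD branch in the diff assigns optimizer rather than optimizer_class, the sketch uses the class in both branches:

import torch

def build_optimizer(model: torch.nn.Module, name: str, lr: float, extra_params: dict | None = None, frozen_params: set[str] = frozenset()):
    params = { "lr": lr }
    if name.lower() == "adamw":
        params.update({ "betas": (0.9, 0.96), "eps": 1e-07, "weight_decay": 0.01 })
        optimizer_class = torch.optim.AdamW
    elif name.lower() == "sgd":
        optimizer_class = torch.optim.SGD
    else:
        raise ValueError(f'Optimizer specified not implemented: {name}')
    params.update(extra_params or {})
    return optimizer_class(
        [ param for pname, param in model.named_parameters() if pname not in frozen_params ],
        **params,
    )

# usage: optimizer = build_optimizer(model, "adamw", 1.0e-4)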