separate mask token and stop token because this might cause issues
parent 6634d07576
commit 3019c88799
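For context on the commit message: before this change the mask (dropout) token reused the stop token's id, so masked positions in NAR-len training were indistinguishable from the stop signal. A minimal sketch of the intended id layout, assuming n_audio_tokens = 1024 as the "# id 1024" comment in the diff implies; the names are illustrative stand-ins, not the full Base class:

# Minimal sketch, assuming n_audio_tokens = 1024 (per the "# id 1024" comment below).
n_audio_tokens = 1024

# Before: the masked-position token reused the stop token's id.
old_stop_token = n_audio_tokens          # 1024
old_mask_token = old_stop_token          # 1024 -> collides with stop

# After: each special token has its own id, so the audio embedding
# needs two extra rows beyond the codebook (stop + mask).
stop_token = n_audio_tokens              # 1024
mask_token = n_audio_tokens + 1          # 1025
n_embedding_rows = n_audio_tokens + 2    # codes 0..1023, then stop, then mask

assert stop_token != mask_token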
@@ -148,8 +148,9 @@ def load_engines(training=True, **model_kwargs):
 		else:
 			raise ValueError(f'Optimizer specified not implemented: {cfg.hyperparameters.optimizer}')
 
-		muon_params = cfg.hyperparameters.optimizer_params.pop("muon", None)
-		params.update(cfg.hyperparameters.optimizer_params)
+		if cfg.hyperparameters.optimizer_params is not None:
+			muon_params = cfg.hyperparameters.optimizer_params.pop("muon", None)
+			params.update(cfg.hyperparameters.optimizer_params)
 
 		if muon_params is not None:
 			muon_params["params"] = [ param for name, param in model.model.named_parameters() if param.ndim >= 2 and f'model.{name}' not in model.config.frozen_params ]
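A small sketch of what the added guard handles, using a hypothetical stand-in for cfg.hyperparameters.optimizer_params rather than the real cfg object: popping "muon" from an unset (None) optimizer_params would raise, so both the pop and the update only run when the dict actually exists.

# Hedged sketch: a minimal stand-in for cfg.hyperparameters.optimizer_params.
def split_optimizer_params(optimizer_params):
	"""Return (muon_params, remaining_params) without assuming the dict exists."""
	muon_params = None
	params = {}
	if optimizer_params is not None:
		# pop() on None would raise AttributeError, hence the guard in the diff.
		muon_params = optimizer_params.pop("muon", None)
		params.update(optimizer_params)
	return muon_params, params

# Example usage with and without an explicit optimizer_params dict:
print(split_optimizer_params({"lr": 1e-4, "muon": {"momentum": 0.95}}))
print(split_optimizer_params(None))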
@@ -1428,7 +1428,7 @@ def example_usage():
 	available_tasks = ["tts-nar"]
 
 	model = AR_NAR(**kwargs).to(cfg.device)
-	steps = 100 // batch_size
+	steps = 250 // batch_size
 
 	optimizer = cfg.hyperparameters.optimizer.lower() if cfg.yaml_path is not None else "prodigy"
 	scheduler = cfg.hyperparameters.scheduler.lower() if cfg.yaml_path is not None else ""
@@ -1473,7 +1473,7 @@ def example_usage():
 
 	_logger.info(f"Optimizer: {optimizer}\tLearning rate: {learning_rate}")
 
-	muon_params = cfg.hyperparameters.optimizer_params.pop("muon", None)
+	muon_params = cfg.hyperparameters.optimizer_params.pop("muon", None) if cfg.hyperparameters.optimizer_params is not None else None
 	if muon_params is not None:
 		muon_params["params"] = [ param for name, param in model.model.named_parameters() if param.ndim >= 2 ]
 		adam_params = [ param for name, param in model.model.named_parameters() if param.ndim < 2 ] + [ param for name, param in model.named_parameters() if not name.startswith('model.') ]
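The example_usage() path gets the same guard plus the usual Muon split: matrix-shaped (ndim >= 2) weights go to Muon, everything else to an Adam-style optimizer. A rough, self-contained sketch of that partition using a toy module (not the actual AR_NAR model); the commented-out Muon constructor is only a placeholder for when the muon package is available.

import torch
import torch.nn as nn

# Toy stand-in for the model; the real code partitions AR_NAR's parameters the same way.
model = nn.Sequential(nn.Linear(8, 8), nn.LayerNorm(8))

# Matrix-shaped weights (ndim >= 2) are candidates for Muon ...
muon_tensors = [p for n, p in model.named_parameters() if p.ndim >= 2]
# ... everything else (biases, norm scales, other vectors) stays with AdamW.
adam_tensors = [p for n, p in model.named_parameters() if p.ndim < 2]

optimizers = [
	torch.optim.AdamW(adam_tensors, lr=1e-4),
	# Muon(muon_tensors, lr=0.02)  # only if the optional muon package imported successfully
]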
@@ -547,6 +547,7 @@ class Base(nn.Module):
 		self.gradient_checkpointing = self.config.gradient_checkpointing if self.config is not None else True
 
 		self.stop_token = self.n_audio_tokens # id 1024
+		self.mask_token = self.n_audio_tokens + 1 # id 1025
 		self.causal = "ar" in self.capabilities or "len" in self.capabilities
 		self.version = self.config.version if self.config is not None else 5
 		self.causal_size = self.config.experimental.causal_size if self.config is not None else (1 if self.causal else 0)
@@ -715,7 +716,7 @@ class Base(nn.Module):
 		if self.version >= 7:
 			if monolithic_audio_encoder:
 				self.audio_emb = AudioEncoder(
-					n_tokens=n_audio_tokens + 1, # masked token
+					n_tokens=n_audio_tokens + 2, # stop + masked token
 					n_levels=self.n_resp_levels,
 					token_dim=d_model,
 				)
@@ -726,7 +727,7 @@ class Base(nn.Module):
 					token_dim=d_model,
 				)
 				self.resps_emb = AudioEncoder(
-					n_tokens=n_audio_tokens + 1, # masked token
+					n_tokens=n_audio_tokens + 2, # stop + masked token
 					n_levels=self.n_resp_levels,
 					token_dim=d_model,
 				)
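The n_tokens bump to n_audio_tokens + 2 is what makes the new mask id addressable: with only n_audio_tokens + 1 rows, id 1025 would index past the end of the embedding table. A minimal sketch with a plain nn.Embedding, assuming the real AudioEncoder sizes its per-level embeddings from n_tokens:

import torch
import torch.nn as nn

n_audio_tokens = 1024
stop_token, mask_token = n_audio_tokens, n_audio_tokens + 1

# n_audio_tokens + 1 rows can only embed ids 0..1024; id 1025 (mask) would
# raise an index error, hence the bump to n_audio_tokens + 2 in the diff.
emb = nn.Embedding(n_audio_tokens + 2, 64)
vec = emb(torch.tensor([stop_token, mask_token]))  # both special ids now resolve
print(vec.shape)  # torch.Size([2, 64])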
@@ -1310,9 +1311,9 @@ class Base(nn.Module):
 			elif name == "resp":
 				if self.version >= 7:
 					if self.audio_emb is not None:
-						embedding = self.audio_emb( input, dropout_mask=dropout_mask, dropout_token=self.stop_token )
+						embedding = self.audio_emb( input, dropout_mask=dropout_mask, dropout_token=self.mask_token )
 					else:
-						embedding = self.resps_emb( input, dropout_mask=dropout_mask, dropout_token=self.stop_token )
+						embedding = self.resps_emb( input, dropout_mask=dropout_mask, dropout_token=self.mask_token )
 				# if training NAR-len RVQ level 0
 				elif dropout_mask is not None:
 					embedding = self.resps_emb(
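dropout_token is the value substituted at masked positions before embedding, so switching it from stop_token to mask_token is what actually separates the two roles. A rough sketch of that substitution, assuming dropout_mask is a boolean tensor over positions (the real AudioEncoder.forward also handles multiple RVQ levels):

import torch

n_audio_tokens = 1024
mask_token = n_audio_tokens + 1

def apply_dropout_mask(tokens, dropout_mask, dropout_token):
	# Replace masked positions with the dropout token before embedding;
	# passing mask_token here (instead of stop_token) keeps "masked" and
	# "stop" distinguishable downstream.
	return torch.where(dropout_mask, torch.full_like(tokens, dropout_token), tokens)

tokens = torch.randint(0, n_audio_tokens, (6,))
dropout_mask = torch.tensor([True, False, True, False, False, True])
print(apply_dropout_mask(tokens, dropout_mask, mask_token))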
@@ -132,9 +132,8 @@ except Exception as e:
 try:
 	from muon import Muon as Muon
 except Exception as e:
-	raise e
-	#_logger.warning(f'Error while importing Muon: {str(e)}')
-	#pass
+	_logger.warning(f'Error while importing Muon: {str(e)}')
+	pass
 
 # https://github.com/konstmish/prodigy
 try:
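The import block now warns instead of re-raising when muon is missing. A minimal sketch of the same pattern; binding Muon to None on failure is an assumption here (the diff itself just passes), added so later code could check availability before constructing the optimizer:

import logging

_logger = logging.getLogger(__name__)

try:
	from muon import Muon as Muon
except Exception as e:
	# Optional dependency is missing or broken: warn and carry on rather than raise.
	_logger.warning(f'Error while importing Muon: {str(e)}')
	Muon = None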