diff --git a/vall_e/data.py b/vall_e/data.py
index 1a643c7..ede02e8 100755
--- a/vall_e/data.py
+++ b/vall_e/data.py
@@ -36,7 +36,7 @@ from tqdm.auto import tqdm
 _logger = logging.getLogger(__name__)
 
 @cache
-def get_random_prompts( validation=True, min_length=0, tokenized=False ):
+def get_random_prompts( validation=False, min_length=0, tokenized=False ):
 	duration_range = [ 5.5, 12.0 ] # to-do: pull from cfg.dataset.duration_range
 	sentences = [
 		"The birch canoe slid on the smooth planks.",
diff --git a/vall_e/engines/__init__.py b/vall_e/engines/__init__.py
index 7f7c609..86fb68e 100755
--- a/vall_e/engines/__init__.py
+++ b/vall_e/engines/__init__.py
@@ -205,14 +205,24 @@ def load_engines(training=True, **model_kwargs):
 				state[k] = ml.resize_weight( state[k], tokens )
 			"""
 
-			if model.config.experimental.masking_separate_embeddings and "resps_emb.embeddings.8.weight" not in state:
+			if True:
+				# move STT one over
 				state['classifiers.proj.9.weight'] = state['classifiers.proj.8.weight'].clone()
 				state['classifiers.proj.9.bias'] = state['classifiers.proj.8.bias'].clone()
-
-				del state['classifiers.proj.8.weight']
-				del state['classifiers.proj.8.bias']
-
-				state['resps_emb.embeddings.8.weight'] = state['resps_emb.embeddings.0.weight'].clone()
+				# copy from AR:0:0 classifier
+				if False:
+					state['classifiers.proj.8.weight'] = state['classifiers.proj.0.weight'].clone()
+					state['classifiers.proj.8.bias'] = state['classifiers.proj.0.bias'].clone()
+					# copy from AR:0:0 embeddings
+					state['resps_emb.embeddings.8.weight'] = state['resps_emb.embeddings.0.weight'].clone()
+				# remove
+				else:
+					if 'classifiers.proj.8.weight' in state:
+						del state['classifiers.proj.8.weight']
+					if 'classifiers.proj.8.bias' in state:
+						del state['classifiers.proj.8.bias']
+					if 'resps_emb.embeddings.8.weight' in state:
+						del state['resps_emb.embeddings.8.weight']
 			"""
 
 			model.load_state_dict(state, strict=cfg.trainer.strict_loading)
diff --git a/vall_e/models/ar_nar.py b/vall_e/models/ar_nar.py
index ca4cd77..e5ddfc4 100644
--- a/vall_e/models/ar_nar.py
+++ b/vall_e/models/ar_nar.py
@@ -148,7 +148,7 @@ class AR_NAR(Base):
 					resps_list[i][t, l] = clamp(token + offset, 1, 1022) # +- 1
 
 			# only apply stop token for RVQ level 0
-			if quant_level <= 0:
+			if quant_level <= 0 and timesteps[i] is not None:
 				# append stop tokens for AR
 				if task in text_task:
 					#text_list[i] = torch.cat([ resps, text_stop_sequence ])
diff --git a/vall_e/models/base.py b/vall_e/models/base.py
index 905e639..1724fd5 100755
--- a/vall_e/models/base.py
+++ b/vall_e/models/base.py
@@ -996,24 +996,22 @@ class Base(nn.Module):
 				inputs[i].append( ( "tone", tone_list[i] ) )
 			# insert timestep token
 			if timestep is not None:
-				# it does not seem to matter whether this is provided or not, I assume the model attends more to the amount of masked tokens in the sequence
-				"""
 				# store timestep information
 				inputs[i].append( ("timestep", torch.tensor([timestep], device=device, dtype=self.time_emb.mlp[0].weight.dtype) ) )
-				"""
-				classifier_level = "NAR:0:0"
+				# force set to use this classifier level
+				classifier_level = "NAR:0:0" if self.masking_separate_embeddings else "AR:0:0"
 			# insert the current output response
 			if resps_list is not None and resps_list[i] is not None:
 				inputs[i].append( ( "resp", resps_list[i] ) )
 
 			# store dropout mask (if training, as this gets used later to mask the input embeddings if provided)
 			if timestep is not None and self.training:
+				"""
 				# a paper said to use a fixed masking ratio for training
+				p = 0.8
 				"""
 				# cosine scheduled timestep => masking ratio
 				p = math.cos(timestep * math.pi * 0.5)
-				"""
-				p = 0.8
 				dropout_mask = _dropout_mask( resps_list[i], p )
 				inputs[i].append( ("dropout_mask", dropout_mask ) )
 
@@ -1226,7 +1224,7 @@ class Base(nn.Module):
 						continue
 					embedding[i] = self.dropout_token
 
-			elif name == "timestep" and self.time_emb is not None and False:
+			elif name == "timestep" and self.time_emb is not None:
 				embedding = self.time_emb( input )
 			elif name == "len" and self.len_emb is not None:
 				embedding = self.len_emb( input )
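
For reference, the base.py hunk above swaps the fixed training masking ratio (p = 0.8) for the cosine schedule p = cos(t * pi / 2), so small timesteps mask nearly all response tokens and large timesteps mask almost none. Below is a minimal sketch of that schedule, assuming _dropout_mask acts roughly as a per-token Bernoulli mask at ratio p; the helper names here are illustrative only and are not the project's actual API.

import math
import torch

def cosine_mask_ratio(timestep: float) -> float:
    # timestep in [0, 1]; t=0 -> p=1.0 (mask everything), t=1 -> p=0.0 (mask nothing)
    return math.cos(timestep * math.pi * 0.5)

def bernoulli_dropout_mask(resps: torch.Tensor, p: float) -> torch.Tensor:
    # True where a response token would be replaced by the mask/dropout token
    return torch.rand(resps.shape[0], device=resps.device) < p

resps = torch.randint(0, 1024, (75,))   # dummy RVQ level-0 token sequence
p = cosine_mask_ratio(0.5)              # ~= 0.707 halfway through the schedule
mask = bernoulli_dropout_mask(resps, p)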