mrq 2024-11-14 07:34:22 -06:00
parent c00fc18b62
commit e412e98125
4 changed files with 23 additions and 15 deletions

View File

@@ -36,7 +36,7 @@ from tqdm.auto import tqdm
_logger = logging.getLogger(__name__)
@cache
def get_random_prompts( validation=True, min_length=0, tokenized=False ):
def get_random_prompts( validation=False, min_length=0, tokenized=False ):
duration_range = [ 5.5, 12.0 ] # to-do: pull from cfg.dataset.duration_range
sentences = [
"The birch canoe slid on the smooth planks.",

View File

@@ -205,14 +205,24 @@ def load_engines(training=True, **model_kwargs):
state[k] = ml.resize_weight( state[k], tokens )
"""
if model.config.experimental.masking_separate_embeddings and "resps_emb.embeddings.8.weight" not in state:
if True:
# move STT one over
state['classifiers.proj.9.weight'] = state['classifiers.proj.8.weight'].clone()
state['classifiers.proj.9.bias'] = state['classifiers.proj.8.bias'].clone()
del state['classifiers.proj.8.weight']
del state['classifiers.proj.8.bias']
state['resps_emb.embeddings.8.weight'] = state['resps_emb.embeddings.0.weight'].clone()
# copy from AR:0:0 classifier
if False:
state['classifiers.proj.8.weight'] = state['classifiers.proj.0.weight'].clone()
state['classifiers.proj.8.bias'] = state['classifiers.proj.0.bias'].clone()
# copy from AR:0:0 embeddings
state['resps_emb.embeddings.8.weight'] = state['resps_emb.embeddings.0.weight'].clone()
# remove
else:
if 'classifiers.proj.8.weight' in state:
del state['classifiers.proj.8.weight']
if 'classifiers.proj.8.bias' in state:
del state['classifiers.proj.8.bias']
if 'resps_emb.embeddings.8.weight' in state:
del state['resps_emb.embeddings.8.weight']
"""
model.load_state_dict(state, strict=cfg.trainer.strict_loading)
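
The hunk above shuffles checkpoint keys before `load_state_dict`: the existing head at `classifiers.proj.8` moves to slot 9 (for STT) and `resps_emb.embeddings.8` is seeded from the level-0 (AR:0:0) table. A generic sketch of the same state-dict surgery, assuming plain PyTorch tensors; the key names are copied from the hunk for illustration and the guards are assumptions, not the repo's exact migration:

```python
def migrate_state_dict(state: dict) -> dict:
    """Shift an existing head one slot over and seed a new embedding table.

    Mirrors the pattern in the hunk above; the conditions are illustrative
    assumptions, not the exact checkpoint layout or migration logic.
    """
    if "classifiers.proj.8.weight" in state and "classifiers.proj.9.weight" not in state:
        # move the old head (slot 8) into the new slot (9)
        state["classifiers.proj.9.weight"] = state.pop("classifiers.proj.8.weight").clone()
        state["classifiers.proj.9.bias"] = state.pop("classifiers.proj.8.bias").clone()
    if "resps_emb.embeddings.8.weight" not in state:
        # initialise the new embedding table from the level-0 (AR:0:0) table
        state["resps_emb.embeddings.8.weight"] = state["resps_emb.embeddings.0.weight"].clone()
    return state

# usage sketch: model.load_state_dict(migrate_state_dict(state), strict=False)
```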

View File

@@ -148,7 +148,7 @@ class AR_NAR(Base):
resps_list[i][t, l] = clamp(token + offset, 1, 1022) # +- 1
# only apply stop token for RVQ level 0
if quant_level <= 0:
if quant_level <= 0 and timesteps[i] is not None:
# append stop tokens for AR
if task in text_task:
#text_list[i] = torch.cat([ resps, text_stop_sequence ])
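
The added `timesteps[i] is not None` check above gates the level-0 stop-token append per sample. A minimal sketch of appending a stop token to one response tensor under such a gate (the stop id, shapes, and helper name are assumptions, not the repo's values):

```python
import torch

def append_stop_token(resps: torch.Tensor, stop_id: int = 1024) -> torch.Tensor:
    """Append one stop-token row to a [t, levels] response tensor (sketch only)."""
    stop = torch.full((1, resps.shape[-1]), stop_id, dtype=resps.dtype, device=resps.device)
    return torch.cat([resps, stop], dim=0)

# per-sample gating, mirroring the condition in the hunk above:
# for i, resps in enumerate(resps_list):
#     if quant_level <= 0 and timesteps[i] is not None:
#         resps_list[i] = append_stop_token(resps)
```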

View File

@@ -996,24 +996,22 @@ class Base(nn.Module):
inputs[i].append( ( "tone", tone_list[i] ) )
# insert timestep token
if timestep is not None:
# it does not seem to matter whether this is provided or not, I assume the model attends more to the amount of masked tokens in the sequence
"""
# store timestep information
inputs[i].append( ("timestep", torch.tensor([timestep], device=device, dtype=self.time_emb.mlp[0].weight.dtype) ) )
"""
classifier_level = "NAR:0:0"
# force set to use this classifier level
classifier_level = "NAR:0:0" if self.masking_separate_embeddings else "AR:0:0"
# insert the current output response
if resps_list is not None and resps_list[i] is not None:
inputs[i].append( ( "resp", resps_list[i] ) )
# store dropout mask (if training, as this gets used later to mask the input embeddings if provided)
if timestep is not None and self.training:
"""
# a paper said to use a fixed masking ratio for training
p = 0.8
"""
# cosine scheduled timestep => masking ratio
p = math.cos(timestep * math.pi * 0.5)
"""
p = 0.8
dropout_mask = _dropout_mask( resps_list[i], p )
inputs[i].append( ("dropout_mask", dropout_mask ) )
@@ -1226,7 +1224,7 @@ class Base(nn.Module):
continue
embedding[i] = self.dropout_token
elif name == "timestep" and self.time_emb is not None and False:
elif name == "timestep" and self.time_emb is not None:
embedding = self.time_emb( input )
elif name == "len" and self.len_emb is not None:
embedding = self.len_emb( input )