ugh
parent c00fc18b62
commit e412e98125
@@ -36,7 +36,7 @@ from tqdm.auto import tqdm
_logger = logging.getLogger(__name__)

@cache
-def get_random_prompts( validation=True, min_length=0, tokenized=False ):
+def get_random_prompts( validation=False, min_length=0, tokenized=False ):
    duration_range = [ 5.5, 12.0 ] # to-do: pull from cfg.dataset.duration_range
    sentences = [
        "The birch canoe slid on the smooth planks.",
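Note: because get_random_prompts is wrapped in functools.cache, results are memoized per argument tuple, so flipping the validation default changes what no-argument call sites get back. A minimal sketch of that behavior; the body here is a stand-in, not the module's actual sentence sampling:

import random
from functools import cache

@cache
def get_random_prompts(validation=False, min_length=0, tokenized=False):
    # stand-in body: the real function samples sentences within a duration range
    return "validation split" if validation else "training split"

# functools.cache keys on the exact call arguments; after this commit the
# bare call resolves to the validation=False behavior
assert get_random_prompts() == "training split"
assert get_random_prompts(validation=True) == "validation split"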
@@ -205,14 +205,24 @@ def load_engines(training=True, **model_kwargs):
                state[k] = ml.resize_weight( state[k], tokens )

        """
        if model.config.experimental.masking_separate_embeddings and "resps_emb.embeddings.8.weight" not in state:
        if True:
            # move STT one over
            state['classifiers.proj.9.weight'] = state['classifiers.proj.8.weight'].clone()
            state['classifiers.proj.9.bias'] = state['classifiers.proj.8.bias'].clone()

            del state['classifiers.proj.8.weight']
            del state['classifiers.proj.8.bias']

            state['resps_emb.embeddings.8.weight'] = state['resps_emb.embeddings.0.weight'].clone()
            # copy from AR:0:0 classifier
            if False:
                state['classifiers.proj.8.weight'] = state['classifiers.proj.0.weight'].clone()
                state['classifiers.proj.8.bias'] = state['classifiers.proj.0.bias'].clone()
            # copy from AR:0:0 embeddings
            state['resps_emb.embeddings.8.weight'] = state['resps_emb.embeddings.0.weight'].clone()
            # remove
        else:
            if 'classifiers.proj.8.weight' in state:
                del state['classifiers.proj.8.weight']
            if 'classifiers.proj.8.bias' in state:
                del state['classifiers.proj.8.bias']
            if 'resps_emb.embeddings.8.weight' in state:
                del state['resps_emb.embeddings.8.weight']
        """

        model.load_state_dict(state, strict=cfg.trainer.strict_loading)
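Note: the commented-out block above is checkpoint surgery performed before load_state_dict: the classifier head at slot 8 (STT) gets shifted to slot 9 and a new slot-8 response embedding is seeded from the level-0 table. A hedged sketch of that kind of state-dict key remapping, using the key names shown in the diff; the helper itself is illustrative, not part of the repo:

import torch

def remap_masking_keys(state: dict) -> dict:
    # illustrative only: shift the STT classifier from slot 8 to slot 9
    if 'classifiers.proj.8.weight' in state and 'classifiers.proj.9.weight' not in state:
        state['classifiers.proj.9.weight'] = state.pop('classifiers.proj.8.weight').clone()
        state['classifiers.proj.9.bias']   = state.pop('classifiers.proj.8.bias').clone()
    # seed the new masked-response embedding from the level-0 (AR:0:0) embedding
    if 'resps_emb.embeddings.8.weight' not in state:
        state['resps_emb.embeddings.8.weight'] = state['resps_emb.embeddings.0.weight'].clone()
    return state

# usage sketch: remap before loading, strictness left to the trainer config
# state = remap_masking_keys(torch.load(path, map_location='cpu'))
# model.load_state_dict(state, strict=cfg.trainer.strict_loading)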
@@ -148,7 +148,7 @@ class AR_NAR(Base):
                    resps_list[i][t, l] = clamp(token + offset, 1, 1022) # +- 1

            # only apply stop token for RVQ level 0
-           if quant_level <= 0:
+           if quant_level <= 0 and timesteps[i] is not None:
                # append stop tokens for AR
                if task in text_task:
                    #text_list[i] = torch.cat([ resps, text_stop_sequence ])
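Note: the stop token only exists on RVQ level 0, where the AR decides when to halt; higher levels are fixed-length refinements, and this change additionally gates the append on the per-item timestep. A hedged sketch of appending a stop column to a [T, n_levels] response tensor; the stop id and shapes are assumptions, not the repo's constants:

import torch

STOP_TOKEN = 1023   # assumed id; the real value comes from the codec config
N_LEVELS   = 8      # assumed number of RVQ levels

def append_stop(resps: torch.Tensor, quant_level: int) -> torch.Tensor:
    # resps: [T, N_LEVELS] codes; only the AR (level 0) sequence gets a stop token
    if quant_level > 0:
        return resps
    stop = torch.full((1, resps.shape[-1]), STOP_TOKEN, dtype=resps.dtype)
    return torch.cat([resps, stop], dim=0)

resps = torch.randint(0, 1024, (75, N_LEVELS))
print(append_stop(resps, quant_level=0).shape)  # torch.Size([76, 8])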
@@ -996,24 +996,22 @@ class Base(nn.Module):
                inputs[i].append( ( "tone", tone_list[i] ) )
            # insert timestep token
            if timestep is not None:
                # it does not seem to matter whether this is provided or not, I assume the model attends more to the amount of masked tokens in the sequence
                """
                # store timestep information
                inputs[i].append( ("timestep", torch.tensor([timestep], device=device, dtype=self.time_emb.mlp[0].weight.dtype) ) )
                """
                classifier_level = "NAR:0:0"
                # force set to use this classifier level
                classifier_level = "NAR:0:0" if self.masking_separate_embeddings else "AR:0:0"
            # insert the current output response
            if resps_list is not None and resps_list[i] is not None:
                inputs[i].append( ( "resp", resps_list[i] ) )

                # store dropout mask (if training, as this gets used later to mask the input embeddings if provided)
                if timestep is not None and self.training:
                    """
                    # a paper said to use a fixed masking ratio for training
                    p = 0.8
                    """
                    # cosine scheduled timestep => masking ratio
                    p = math.cos(timestep * math.pi * 0.5)
                    """
                    p = 0.8
                    dropout_mask = _dropout_mask( resps_list[i], p )
                    inputs[i].append( ("dropout_mask", dropout_mask ) )
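Note: the masking ratio in this path is derived from the timestep with a cosine schedule, p = cos(t * pi / 2), so t = 0 masks everything and t = 1 masks nothing. A small hedged sketch of that mapping and a Bernoulli mask over a response sequence; _dropout_mask here is a stand-in for the repo's helper, and the shapes are assumptions:

import math
import torch

def masking_ratio(timestep: float) -> float:
    # cosine scheduled timestep => masking ratio, as in the diff above
    return math.cos(timestep * math.pi * 0.5)

def _dropout_mask(resps: torch.Tensor, p: float) -> torch.Tensor:
    # stand-in: True marks positions whose input embedding gets masked later
    return torch.rand(resps.shape[0]) < p

for t in (0.0, 0.5, 1.0):
    print(f"t={t:.2f} -> p={masking_ratio(t):.3f}")   # 1.000, 0.707, 0.000

resps = torch.randint(0, 1024, (75, 8))               # [T, n_RVQ_levels]
mask = _dropout_mask(resps, masking_ratio(0.5))
print(mask.float().mean())                            # roughly 0.707 of positions masked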
@@ -1226,7 +1224,7 @@ class Base(nn.Module):
                        continue

                    embedding[i] = self.dropout_token
-           elif name == "timestep" and self.time_emb is not None and False:
+           elif name == "timestep" and self.time_emb is not None:
                embedding = self.time_emb( input )
            elif name == "len" and self.len_emb is not None:
                embedding = self.len_emb( input )
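Note: with the `and False` guard removed, the "timestep" input is actually routed through self.time_emb. The reference to self.time_emb.mlp[0].weight elsewhere in this commit suggests a sinusoidal embedding followed by a small MLP, as is common in diffusion-style models; the sketch below is an assumption about that shape, not the repo's definition:

import math
import torch
from torch import nn

class TimeEmbedding(nn.Module):
    # assumed structure: sinusoidal features -> 2-layer MLP (exposes .mlp as the diff implies)
    def __init__(self, dim: int):
        super().__init__()
        self.dim = dim
        self.mlp = nn.Sequential(nn.Linear(dim, dim), nn.SiLU(), nn.Linear(dim, dim))

    def forward(self, t: torch.Tensor) -> torch.Tensor:
        half = self.dim // 2
        freqs = torch.exp(-math.log(10000) * torch.arange(half, dtype=t.dtype, device=t.device) / half)
        angles = t[..., None] * freqs
        emb = torch.cat([angles.sin(), angles.cos()], dim=-1)
        return self.mlp(emb)

time_emb = TimeEmbedding(dim=64)
print(time_emb(torch.tensor([0.5])).shape)  # torch.Size([1, 64])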