move NAR-len rvq level 0 to separate embedding
commit 269648605e (parent 29e45be0b4)
@@ -260,6 +260,7 @@ class ModelExperimentalSettings:
     masking_train_p: float = 0.0 # odds of training with masking
     masking_train_rvq_levels: list = field(default_factory=lambda: [0,0]) # determines which levels to do mask training on
+    masking_separate_embeddings: bool = False
 
     # classifier-free guidance shit
     cfg_cond_dropout_p: float = 0.0 # 0.2 # probability to drop out text and audio during training
@@ -204,6 +204,17 @@ def load_engines(training=True, **model_kwargs):
                 continue
             state[k] = ml.resize_weight( state[k], tokens )
 
+        """
+        if model.config.experimental.masking_separate_embeddings and "resps_emb.embeddings.8.weight" not in state:
+            state['classifiers.proj.9.weight'] = state['classifiers.proj.8.weight'].clone()
+            state['classifiers.proj.9.bias'] = state['classifiers.proj.8.bias'].clone()
+
+            del state['classifiers.proj.8.weight']
+            del state['classifiers.proj.8.bias']
+
+            state['resps_emb.embeddings.8.weight'] = state['resps_emb.embeddings.0.weight'].clone()
+        """
+
         model.load_state_dict(state, strict=cfg.trainer.strict_loading)
 
     # load lora weights if exists
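The added block is wrapped in a bare triple-quoted string, so it parses as an inert expression and never executes; the checkpoint migration is left disabled. For reference, a minimal standalone sketch of the remap it describes (key names and the level-8/9 indices come from the disabled block; the helper name is hypothetical):

def migrate_nar_len_checkpoint(state: dict) -> dict:
    # sketch of the disabled migration above, applied to a plain state dict of tensors
    if "resps_emb.embeddings.8.weight" not in state:
        # shift the old final classifier head up one slot (8 -> 9) ...
        state['classifiers.proj.9.weight'] = state['classifiers.proj.8.weight'].clone()
        state['classifiers.proj.9.bias'] = state['classifiers.proj.8.bias'].clone()
        del state['classifiers.proj.8.weight']
        del state['classifiers.proj.8.bias']
        # ... and seed the new separate NAR-len embedding from RVQ level 0's weights
        state['resps_emb.embeddings.8.weight'] = state['resps_emb.embeddings.0.weight'].clone()
    return state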
@@ -896,7 +896,7 @@ def example_usage():
     text, audio = load_artifact(f"./data/qnt.{'dac' if cfg.audio_backend == 'dac' else 'enc'}")
     batch_size = cfg.hyperparameters.batch_size
-    cfg.model.experimental.masking_train_p = 0.5
+    cfg.model.experimental.masking_train_p = 1.0
 
     text_list = [ text ] * batch_size
     proms_list = [ audio[:cfg.dataset.frames_per_second, :] ] * batch_size
@@ -476,6 +476,8 @@ class Base(nn.Module):
         layerskip_p_max = self.config.experimental.layerskip_p_max if self.config is not None else 0.1
         layerskip_e_scale = self.config.experimental.layerskip_e_scale if self.config is not None else 0.1
 
+        masking_separate_embeddings = self.config.experimental.masking_separate_embeddings if self.config is not None else False
+
         n_tasks = self.config.tasks if self.config is not None else 8
         n_langs = self.config.langs if self.config is not None else 2
         n_tones = self.config.tones if self.config is not None else 1
@@ -484,7 +486,12 @@ class Base(nn.Module):
         if "nar" not in self.capabilities:
             n_resp_tokens = n_audio_tokens + 1
             l_tokens = [n_resp_tokens] * self.n_resp_levels
-        # AR+NAR model / NAR-len model
+        # NAR-len model
+        elif "len" in self.capabilities and masking_separate_embeddings:
+            # +1 to include the stop or mask token
+            n_resp_tokens = n_audio_tokens + ( 1 if self.causal_size > 0 else 0 )
+            l_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1) + [n_resp_tokens]
+        # AR+NAR model
         else:
             # +1 to include the stop or mask token
             n_resp_tokens = n_audio_tokens + ( 1 if self.causal_size > 0 else 0 )
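The new branch gives the NAR-len path a dedicated trailing entry in l_tokens, so masked level 0 no longer shares the AR level 0 embedding. A minimal sketch of how such a list plausibly maps onto per-level tables (sizes are illustrative; the project's real embedding wrapper may differ):

import torch.nn as nn

# illustrative sizes, not the project's actual config
n_audio_tokens, n_resp_levels, d_model = 1024, 8, 512

n_resp_tokens = n_audio_tokens + 1  # +1 for the stop or mask token
l_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (n_resp_levels - 1) + [n_resp_tokens]

# one table per entry: levels 0..7, plus a trailing table for the masked NAR-len level
embeddings = nn.ModuleList([ nn.Embedding(n, d_model) for n in l_tokens ])
assert len(embeddings) == n_resp_levels + 1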
@@ -495,6 +502,7 @@ class Base(nn.Module):
         self.layerskip = layerskip
         self.special_tasks = [ "len", "stt" ]
         self.inject_timestep_embedding = False # results in bad output
+        self.masking_separate_embeddings = masking_separate_embeddings
 
         self.text_emb = Embedding(n_text_tokens, d_model)
         self.langs_emb = None
@@ -1182,7 +1190,7 @@ class Base(nn.Module):
                 embedding = self.resps_emb(
                     # if masked use masked token, else original token
                     torch.where( dropout_mask, self.stop_token, input if input.dim() == 1 else input[:, 0] ),
-                    offset = 0,
+                    offset = -1 if self.masking_separate_embeddings else 0, # pick last
                     quant_level = 0,
                 )
             # cheat-y way to handle performing STT across all levels
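With separate embeddings enabled, masked tokens are embedded with offset = -1 instead of 0. A sketch of the presumed mechanism, assuming resps_emb selects a table at quant_level + offset so that Python's negative indexing wraps -1 around to the extra trailing NAR-len table (the stand-in class below is illustrative, not the project's actual implementation):

import torch
import torch.nn as nn

class AudioEmbeddingSketch(nn.Module):
    # stand-in for resps_emb; assumes table index = quant_level + offset
    def __init__(self, l_tokens, d_model):
        super().__init__()
        self.embeddings = nn.ModuleList([ nn.Embedding(n, d_model) for n in l_tokens ])

    def forward(self, tokens, offset=0, quant_level=0):
        # offset=-1 with quant_level=0 indexes table -1: the separate NAR-len embedding
        return self.embeddings[quant_level + offset](tokens)

emb = AudioEmbeddingSketch([1025] + [1024] * 7 + [1025], d_model=512)
masked = emb(torch.zeros(4, dtype=torch.long), offset=-1)  # routed to the last table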
@@ -1325,10 +1333,9 @@ class Base(nn.Module):
         device = logits[0].device
         batch_size = len(logits)
         summed_embeddings_task = [ "stt" ]
-        #classifier_quant_levels = quant_levels if self.classifier is not None else [ -1 if inputs[i][0][-1] in self.special_tasks else l for i, l in enumerate( quant_levels ) ]
 
         tasks = [ self.get_input(inputs, "task", at=i) for i in range( batch_size ) ]
-        classifier_quant_levels = quant_levels if self.classifier is not None else [ -1 if tasks[i] in self.special_tasks else l for i, l in enumerate( quant_levels ) ]
+        is_nar_len = [ self.get_input(inputs, "dropout_mask", at=i) is not None and self.masking_separate_embeddings for i in range( batch_size ) ]
+        classifier_quant_levels = quant_levels if self.classifier is not None else [ -1 if tasks[i] in self.special_tasks else (-2 if is_nar_len[i] else l) for i, l in enumerate( quant_levels ) ]
 
         # handles tasks where the prompt has task tokens injected in the middle
         def prompt_input_to_token( input, quant_level ):
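When per-level classifier heads are in use (self.classifier is None), the negative indices presumably wrap around to the extra heads at the end of the list: -1 for the special len/stt head, -2 for the NAR-len masked head just before it (that head ordering is an assumption, consistent with the 8 -> 9 shift in the disabled migration earlier). A toy rendering of the selection:

# toy inputs; tasks and dropout masks come from get_input in the real code
quant_levels  = [0, 3, 0]
tasks         = ["tts", "tts", "len"]
is_nar_len    = [True, False, False]   # dropout_mask present + separate embeddings enabled
special_tasks = ["len", "stt"]

classifier_quant_levels = [
    -1 if tasks[i] in special_tasks     # special tasks -> last head
    else (-2 if is_nar_len[i] else l)   # masked NAR-len rows -> second-to-last head
    for i, l in enumerate(quant_levels)
]
assert classifier_quant_levels == [-2, 3, -1]

The same two lines are hoisted into the forward pass in the next hunk, and the older classifier_quant_levels computation there is removed.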
@@ -1623,6 +1630,8 @@ class Base(nn.Module):
         position_ids = self.inputs_to_position_ids( inputs, mask=mask ) if not self.unified_position_ids else None
 
         tasks = [ self.get_input(inputs, "task", at=i) for i in range( batch_size ) ]
+        is_nar_len = [ self.get_input(inputs, "dropout_mask", at=i) is not None and self.masking_separate_embeddings for i in range( batch_size ) ]
+        classifier_quant_levels = quant_levels if self.classifier is not None else [ -1 if tasks[i] in self.special_tasks else (-2 if is_nar_len[i] else l) for i, l in enumerate( quant_levels ) ]
 
         if self.inject_timestep_embedding:
             timesteps = [ self.get_input(inputs, "timestep", at=i) for i in range( batch_size ) ]
@@ -1630,8 +1639,6 @@ class Base(nn.Module):
         else:
             timesteps = []
 
-        classifier_quant_levels = [ -1 if tasks[i] in self.special_tasks else l for i, l in enumerate( quant_levels ) ]
-
         output = self._forward(
             inputs=x,
             mask=mask,