Experimental: add a token that tells the model which RVQ level it is currently on, which seems to help all models (mamba almost works, but it might just have to be relegated to being a pure AR model)
This commit is contained in:
parent
e0886c5a78
commit
9e3f2e300f
|
@ -167,7 +167,9 @@ class AR_NAR(Base):
|
|||
resps_list=resps_list,
|
||||
targ_list=targ_list,
|
||||
lang_list=lang_list,
|
||||
tone_list=tone_list
|
||||
tone_list=tone_list,
|
||||
|
||||
quant_levels=quant_levels,
|
||||
)
|
||||
|
||||
return super().forward(
|
||||
|
@ -193,6 +195,7 @@ class AR_NAR(Base):
|
|||
resps_list=prev_list,
|
||||
lang_list=lang_list,
|
||||
tone_list=tone_list,
|
||||
quant_levels=quant_levels,
|
||||
)
|
||||
|
||||
logits = super().forward(
|
||||
|
@ -336,9 +339,12 @@ def example_usage():
|
|||
device = "cuda"
|
||||
|
||||
# mamba seems to ONLY be usable as an AR (any NAR attempt lobotomizes it)
|
||||
"""
|
||||
if "mamba" in cfg.model.arch_type:
|
||||
cfg.model.prom_levels = 1
|
||||
cfg.model.resp_levels = 1
|
||||
"""
|
||||
cfg.model.loss_factors = {}
|
||||
|
||||
def tokenize(content):
|
||||
return torch.tensor( cfg.tokenizer.encode(content) )
|
||||
|
@ -390,7 +396,7 @@ def example_usage():
|
|||
"""
|
||||
|
||||
model = AR_NAR(**kwargs).to(device)
|
||||
steps = 250
|
||||
steps = 500
|
||||
|
||||
optimizer = cfg.hyperparameters.optimizer.lower() if cfg.cfg_path is not None else "prodigy"
|
||||
scheduler = cfg.hyperparameters.scheduler.lower() if cfg.cfg_path is not None else ""
|
||||
|
|
|
@ -506,6 +506,7 @@ class Base(nn.Module):
|
|||
self.langs_emb = None
|
||||
self.tones_emb = None
|
||||
self.tasks_emb = None
|
||||
self.rvq_level_emb = None
|
||||
|
||||
if self.version == 1: # legacy
|
||||
n_prom_tokens += (self.n_tasks - 1) # old models have the task tokens in the prom
|
||||
|
@ -533,6 +534,9 @@ class Base(nn.Module):
|
|||
if self.version >= 4:
|
||||
self.tones_emb = Embedding(self.n_tones, d_model) if self.n_tones > 0 else None
|
||||
|
||||
if self.version >= 5:
|
||||
self.rvq_level_emb = Embedding(self.n_resp_levels, d_model)
|
||||
|
||||
self.sep = nn.Parameter(torch.randn(d_model))
|
||||
|
||||
# ick, there has to be a better way
|
||||
|
@ -872,14 +876,21 @@ class Base(nn.Module):
|
|||
|
||||
lang_list: list[Tensor] | None = None,
|
||||
tone_list: list[Tensor] | None = None,
|
||||
|
||||
quant_levels: Tensor | None = None
|
||||
):
|
||||
device = text_list[0].device
|
||||
batch_size = len(text_list)
|
||||
|
||||
inputs = [ [] for _ in range(batch_size) ]
|
||||
for i in range(batch_size):
|
||||
quant_level = quant_levels[i] if quant_levels is not None else 0
|
||||
|
||||
if text_list is not None:
|
||||
inputs[i].append( ( "text", text_list[i] ) )
|
||||
|
||||
inputs[i].append( ( "quant_level", torch.Tensor([ quant_level ]).to(device=device, dtype=torch.int16) ) )
|
||||
|
||||
if proms_list is not None:
|
||||
inputs[i].append( ( "prom", proms_list[i] ) )
|
||||
if resps_list is not None:
|
||||
|
@ -897,11 +908,13 @@ class Base(nn.Module):
|
|||
x_list = []
|
||||
for batch_index, batch_input in enumerate(inputs):
|
||||
batch = []
|
||||
quant_level = quant_levels[batch_index] if quant_levels is not None else None
|
||||
quant_level = quant_levels[batch_index] if quant_levels is not None else 0
|
||||
for name, input in batch_input:
|
||||
embedding = None
|
||||
if name == "text":
|
||||
embedding = self.text_emb( input )
|
||||
elif name == "quant_level":
|
||||
embedding = self.rvq_level_emb( input )
|
||||
elif name == "lang":
|
||||
embedding = self.langs_emb( input )
|
||||
elif name == "prom":
|
||||
|
@ -929,12 +942,13 @@ class Base(nn.Module):
|
|||
# old, "naive" way, no loss factoring
|
||||
if not self.hyper_config.loss_factors:
|
||||
target_list = []
|
||||
for batch in inputs:
|
||||
for batch_index, batch in enumerate(inputs):
|
||||
target = []
|
||||
quant_level = quant_levels[batch_index] if quant_levels is not None else None
|
||||
for name, input in batch:
|
||||
if name == "prom":
|
||||
target.append( torch.full_like(input[..., 0], self.ignore_index) )
|
||||
elif name in ["text", "lang", "tone", "targ"]:
|
||||
elif name in ["text", "quant_level", "lang", "tone", "targ"]:
|
||||
target.append( input )
|
||||
|
||||
target_list.append( _join( target, torch.tensor(self.ignore_index, device=target[-1].device) ) )
|
||||
|
@ -980,6 +994,7 @@ class Base(nn.Module):
|
|||
input = input[:, quant_level]
|
||||
|
||||
seq_len = input.shape[0]
|
||||
|
||||
logit = logits[i][it:it+seq_len]
|
||||
it += seq_len + 1 # +1 to incorporate the separator
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user