experimental "just have a token for what rvq level we're on" that seems to help all models (mamba almost works, but it might just have to be relegated as a pure AR model)

mrq 2024-06-04 23:23:31 -05:00
parent e0886c5a78
commit 9e3f2e300f
2 changed files with 26 additions and 5 deletions
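
The core change: the current RVQ (residual vector quantization) level gets its own learned embedding and is inserted into the input sequence as one extra token, so the shared AR+NAR model is explicitly told which codebook level it is predicting. A minimal sketch of the idea; the dimensions and tensors are made up for illustration, only the names mirror the diff below, and this is not the repo's actual module:

import torch
from torch import nn

d_model = 1024                                        # model width (illustrative)
n_resp_levels = 8                                     # number of RVQ levels (illustrative)
rvq_level_emb = nn.Embedding(n_resp_levels, d_model)  # mirrors self.rvq_level_emb below

quant_level = torch.tensor([3])                       # level currently being predicted
level_token = rvq_level_emb(quant_level)              # one extra token, shape (1, d_model)

text_embs = torch.randn(32, d_model)                  # stand-in for embedded text tokens
prom_embs = torch.randn(150, d_model)                 # stand-in for embedded prompt codes
x = torch.cat([ text_embs, level_token, prom_embs ])  # level token sits right after the text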

View File

@@ -167,7 +167,9 @@ class AR_NAR(Base):
resps_list=resps_list,
targ_list=targ_list,
lang_list=lang_list,
tone_list=tone_list
tone_list=tone_list,
quant_levels=quant_levels,
)
return super().forward(
@@ -193,6 +195,7 @@
resps_list=prev_list,
lang_list=lang_list,
tone_list=tone_list,
quant_levels=quant_levels,
)
logits = super().forward(
@@ -336,9 +339,12 @@ def example_usage():
device = "cuda"
# mamba seems to ONLY be used as an AR (any NAR attempts lobotomizes it)
"""
if "mamba" in cfg.model.arch_type:
cfg.model.prom_levels = 1
cfg.model.resp_levels = 1
"""
cfg.model.loss_factors = {}
def tokenize(content):
return torch.tensor( cfg.tokenizer.encode(content) )
@@ -390,7 +396,7 @@ def example_usage():
"""
model = AR_NAR(**kwargs).to(device)
steps = 250
steps = 500
optimizer = cfg.hyperparameters.optimizer.lower() if cfg.cfg_path is not None else "prodigy"
scheduler = cfg.hyperparameters.scheduler.lower() if cfg.cfg_path is not None else ""
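
In this first file, quant_levels is now threaded through both the training call and the NAR sampling path. A hedged sketch of how such a per-batch level tensor could be built during NAR inference (the loop bounds and names are assumptions, not the repo's exact loop):

import torch

batch_size = 4
n_resp_levels = 8
for level in range(1, n_resp_levels):  # level 0 is handled by the AR pass
    quant_levels = torch.full((batch_size,), level, dtype=torch.long)
    # logits = super().forward(..., quant_levels=quant_levels)  # as in the hunks above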

View File

@@ -506,6 +506,7 @@ class Base(nn.Module):
self.langs_emb = None
self.tones_emb = None
self.tasks_emb = None
self.rvq_level_emb = None
if self.version == 1: # legacy
n_prom_tokens += (self.n_tasks - 1) # old models have the task tokens in the prom
@@ -533,6 +534,9 @@
if self.version >= 4:
self.tones_emb = Embedding(self.n_tones, d_model) if self.n_tones > 0 else None
if self.version >= 5:
self.rvq_level_emb = Embedding(self.n_resp_levels, d_model)
self.sep = nn.Parameter(torch.randn(d_model))
# ick, there has to be a better way
@@ -872,14 +876,21 @@ class Base(nn.Module):
lang_list: list[Tensor] | None = None,
tone_list: list[Tensor] | None = None,
quant_levels: Tensor | None = None
):
device = text_list[0].device
batch_size = len(text_list)
inputs = [ [] for _ in range(batch_size) ]
for i in range(batch_size):
quant_level = quant_levels[i] if quant_levels is not None else 0
if text_list is not None:
inputs[i].append( ( "text", text_list[i] ) )
inputs[i].append( ( "quant_level", torch.Tensor([ quant_level ]).to(device=device, dtype=torch.int16) ) )
if proms_list is not None:
inputs[i].append( ( "prom", proms_list[i] ) )
if resps_list is not None:
@@ -897,11 +908,13 @@
x_list = []
for batch_index, batch_input in enumerate(inputs):
batch = []
quant_level = quant_levels[batch_index] if quant_levels is not None else None
quant_level = quant_levels[batch_index] if quant_levels is not None else 0
for name, input in batch_input:
embedding = None
if name == "text":
embedding = self.text_emb( input )
elif name == "quant_level":
embedding = self.rvq_level_emb( input )
elif name == "lang":
embedding = self.langs_emb( input )
elif name == "prom":
@@ -929,12 +942,13 @@
# old, "naive" way, no loss factoring
if not self.hyper_config.loss_factors:
target_list = []
for batch in inputs:
for batch_index, batch in enumerate(inputs):
target = []
quant_level = quant_levels[batch_index] if quant_levels is not None else None
for name, input in batch:
if name == "prom":
target.append( torch.full_like(input[..., 0], self.ignore_index) )
elif name in ["text", "lang", "tone", "targ"]:
elif name in ["text", "quant_level", "lang", "tone", "targ"]:
target.append( input )
target_list.append( _join( target, torch.tensor(self.ignore_index, device=target[-1].device) ) )
@@ -980,6 +994,7 @@
input = input[:, quant_level]
seq_len = input.shape[0]
logit = logits[i][it:it+seq_len]
it += seq_len + 1 # +1 to incorporate the separator
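
The +1 in this final hunk accounts for the separator embedding joined between input segments when walking the flattened logits. A small sketch of that offset bookkeeping, assuming one separator follows each segment (segment lengths are made up):

# illustrative segment lengths: text, quant_level token, prom codes, resp codes
segment_lengths = [ 32, 1, 150, 75 ]
it = 0
slices = []
for seq_len in segment_lengths:
    slices.append( (it, it + seq_len) )  # i.e. logit = logits[i][it:it+seq_len]
    it += seq_len + 1                    # +1 skips the separator after the segment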