This commit is contained in:
mrq 2024-06-04 23:48:51 -05:00
parent 9e3f2e300f
commit 48cd1054f9
5 changed files with 28 additions and 14 deletions

View File

@ -214,9 +214,11 @@ class Model:
attention: str = "auto"
audio_embedding_sums: bool = True
dropout: float = 0.1 # adjustable dropout value
loss_factors: dict = field(default_factory=lambda: { "text": 0.1, "prom": 0.0, "resp": 1.0 })
kv_heads: int = 0
#loss_factors: dict = field(default_factory=lambda: { "text": 0.1, "prom": 0.0, "resp": 1.0 }) # disable it by default since it causes a little more harm than good
loss_factors: dict = field(default_factory=lambda: {})
capabilities: list = field(default_factory=lambda: ["ar", "nar"])
experimental: bool = False # for now it sets things to be HF compatible
kv_heads: int = 0
def get(self, name=None):
return [ self ] if not name or self.name == name else []

View File

@ -15,6 +15,8 @@ from ..emb.qnt import trim
class AR_NAR(Base):
@property
def causal(self):
if hasattr(self, "config") and self.config:
return "ar" in self.capabilities
return True
@property
@ -135,9 +137,9 @@ class AR_NAR(Base):
index = i
return int(index)
quant_levels = torch.Tensor([ generate(0, self.n_resp_levels) for _ in range(batch_size) ]).to(dtype=torch.int16)
quant_levels = torch.Tensor([ generate(0 if self.causal else 1, self.n_resp_levels) for _ in range(batch_size) ]).to(dtype=torch.int16)
else:
quant_levels = torch.randint(0, self.n_resp_levels, (batch_size,)) # randomly select a target RVQ-bin level (0 being AR, 1+ being NAR)
quant_levels = torch.randint(0 if self.causal else 1, self.n_resp_levels, (batch_size,)) # randomly select a target RVQ-bin level (0 being AR, 1+ being NAR)
"""
if cfg.model.p_ar_level == "auto" or cfg.model.p_ar_level is None:
quant_levels = torch.randint(0, self.n_resp_levels, (batch_size,)) # randomly select a target RVQ-bin level (0 being AR, 1+ being NAR)
@ -344,7 +346,7 @@ def example_usage():
cfg.model.prom_levels = 1
cfg.model.resp_levels = 1
"""
cfg.model.loss_factors = {}
# cfg.model.loss_factors = {}
def tokenize(content):
return torch.tensor( cfg.tokenizer.encode(content) )
@ -396,7 +398,7 @@ def example_usage():
"""
model = AR_NAR(**kwargs).to(device)
steps = 500
steps = 200
optimizer = cfg.hyperparameters.optimizer.lower() if cfg.cfg_path is not None else "prodigy"
scheduler = cfg.hyperparameters.scheduler.lower() if cfg.cfg_path is not None else ""
@ -468,7 +470,11 @@ def example_usage():
return
engine.eval()
resps_list = engine(text_list, proms_list, max_steps=steps, sampling_temperature=0.95 )
if "ar" in cfg.model.capabilities:
resps_list = engine(text_list, proms_list, max_steps=steps, sampling_temperature=0.95 )
else:
resps_list = [ qnt[:, 0].to( device ) ]
if cfg.model.max_levels > 1:
resps_list = [r.unsqueeze(-1) for r in resps_list]
resps_list = engine( text_list, proms_list, resps_list=resps_list, sampling_temperature=0.2 )
@ -492,7 +498,7 @@ def example_usage():
'module': model.state_dict()
}, f"./data/{cfg.model.arch_type}.pth" )
sample("init", 5)
#sample("init", 5)
train()
sample("final")

View File

@ -498,7 +498,6 @@ class Base(nn.Module):
self.l_padding = l_padding
# +1 to include the stop token
# to-do: undo this dogshit mistake; tasks tokens should be delegated to its own embedding
n_prom_tokens = n_tokens
n_resp_tokens = n_tokens + (1 if self.causal else 0) # AR requires a stop token to... know when to stop
@ -1009,11 +1008,14 @@ class Base(nn.Module):
"logits": [],
}
info[name]["targets"].append( input.contiguous() )
info[name]["logits"].append( logit.contiguous() )
info[name]["targets"].append( input ) # input.contiguous()
info[name]["logits"].append( logit ) # logit.contiguous()
for name, batch in info.items():
loss_factor = self.loss_factor(name)
if name not in ["text", "prom", "resp"]:
continue
if loss_factor == 0.0:
continue
@ -1021,7 +1023,11 @@ class Base(nn.Module):
inputs = torch.cat( batch["logits"] )
self.loss[name] = F.cross_entropy( inputs, targets, ignore_index=self.ignore_index ) * loss_factor
self.stats["acc"][name] = self.accuracy_metric( inputs, targets )
try:
self.stats["acc"][name] = self.accuracy_metric( inputs, targets )
except Exception as e:
print( name, inputs.shape, targets.shape, e )
pass
# to-do: compute loss per individual batch to scale per RVQ level
"""

View File

@ -434,7 +434,7 @@ def example_usage():
stats = {"step": i}
batch_size = len(text_list)
quant_levels = None if cfg.model.interleave else torch.randint(0, cfg.model.max_levels, (batch_size,))
quant_levels = None if cfg.model.interleave else torch.randint(0 if "ar" in cfg.model.capabilities else 1, cfg.model.max_levels, (batch_size,))
if quant_levels is not None:
resps_list = [ [] if l == 0 else resp for l, resp in zip(quant_levels, resp_list) ]
else:

View File

@ -32,7 +32,7 @@ def train_feeder(engine, batch):
quant_levels = None
resps_list = [ resp for resp in batch["resps"] ]
else:
quant_levels = torch.randint(0, cfg.model.max_levels, (batch_size,))
quant_levels = torch.randint(0 if "ar" in cfg.model.capabilities else 1, cfg.model.max_levels, (batch_size,))
resps_list = [ [] if l == 0 else resp for l, resp in zip(quant_levels, batch["resps"]) ]
input_ids, attention_mask = fold_inputs(