commit 48cd1054f9
parent 9e3f2e300f

madness
```diff
@@ -214,9 +214,11 @@ class Model:
 	attention: str = "auto"
 	audio_embedding_sums: bool = True
 	dropout: float = 0.1 # adjustable dropout value
-	loss_factors: dict = field(default_factory=lambda: { "text": 0.1, "prom": 0.0, "resp": 1.0 })
+	#loss_factors: dict = field(default_factory=lambda: { "text": 0.1, "prom": 0.0, "resp": 1.0 }) # disable it by default since it causes a little more harm than good
+	loss_factors: dict = field(default_factory=lambda: {})
+	capabilities: list = field(default_factory=lambda: ["ar", "nar"])
 	experimental: bool = False # for now it sets things to be HF compatible
 	kv_heads: int = 0
 
 	def get(self, name=None):
 		return [ self ] if not name or self.name == name else []
```
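The `loss_factors` default above moves from weighted per-stream factors to an empty dict. A minimal sketch (using a hypothetical `ModelSketch` stand-in, not the repo's `Model`) of why these dataclass fields use `default_factory` for mutable defaults:

```python
# Sketch only: dataclasses require per-instance construction of mutable defaults,
# which is why loss_factors/capabilities are declared via default_factory above.
from dataclasses import dataclass, field

@dataclass
class ModelSketch:  # hypothetical stand-in for the Model config
	loss_factors: dict = field(default_factory=lambda: {})
	capabilities: list = field(default_factory=lambda: ["ar", "nar"])

a, b = ModelSketch(), ModelSketch()
a.loss_factors["text"] = 0.1
assert b.loss_factors == {}  # instances do not share the dict
```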
```diff
@@ -15,6 +15,8 @@ from ..emb.qnt import trim
 class AR_NAR(Base):
 	@property
 	def causal(self):
+		if hasattr(self, "config") and self.config:
+			return "ar" in self.capabilities
 		return True
 
 	@property
```
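A quick behavioral sketch of the `causal` gate added above, with a hypothetical `Sketch` class rather than `AR_NAR` itself: with no config attached the model stays causal by default, while a configured NAR-only model reports non-causal.

```python
# Hypothetical stand-in for AR_NAR's capability gate; not the repo class.
class Sketch:
	def __init__(self, config=None, capabilities=("ar", "nar")):
		self.config = config
		self.capabilities = list(capabilities)

	@property
	def causal(self):
		# mirrors the hunk above: trust capabilities only when a config exists
		if hasattr(self, "config") and self.config:
			return "ar" in self.capabilities
		return True

assert Sketch().causal                                           # no config: assume causal
assert not Sketch(config=object(), capabilities=["nar"]).causal  # NAR-only: non-causal
```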
```diff
@@ -135,9 +137,9 @@ class AR_NAR(Base):
 					index = i
 				return int(index)
 
-			quant_levels = torch.Tensor([ generate(0, self.n_resp_levels) for _ in range(batch_size) ]).to(dtype=torch.int16)
+			quant_levels = torch.Tensor([ generate(0 if self.causal else 1, self.n_resp_levels) for _ in range(batch_size) ]).to(dtype=torch.int16)
 		else:
-			quant_levels = torch.randint(0, self.n_resp_levels, (batch_size,)) # randomly select a target RVQ-bin level (0 being AR, 1+ being NAR)
+			quant_levels = torch.randint(0 if self.causal else 1, self.n_resp_levels, (batch_size,)) # randomly select a target RVQ-bin level (0 being AR, 1+ being NAR)
 		"""
 		if cfg.model.p_ar_level == "auto" or cfg.model.p_ar_level is None:
 			quant_levels = torch.randint(0, self.n_resp_levels, (batch_size,)) # randomly select a target RVQ-bin level (0 being AR, 1+ being NAR)
```
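Both changes above share one idea: a model without an AR (`causal == False`) must never be assigned RVQ level 0, since level 0 is the AR's training target. A self-contained sketch of that sampling floor (names here are illustrative, not the repo's):

```python
import torch

def sample_quant_levels(causal: bool, n_resp_levels: int, batch_size: int) -> torch.Tensor:
	# level 0 trains the AR pass; levels 1+ train NAR passes
	lo = 0 if causal else 1
	return torch.randint(lo, n_resp_levels, (batch_size,))

levels = sample_quant_levels(causal=False, n_resp_levels=8, batch_size=16)
assert (levels >= 1).all()  # a NAR-only model never draws the AR level
```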
```diff
@@ -344,7 +346,7 @@ def example_usage():
 	cfg.model.prom_levels = 1
 	cfg.model.resp_levels = 1
 	"""
-	cfg.model.loss_factors = {}
+	# cfg.model.loss_factors = {}
 
 	def tokenize(content):
 		return torch.tensor( cfg.tokenizer.encode(content) )
```
```diff
@@ -396,7 +398,7 @@ def example_usage():
 	"""
 
 	model = AR_NAR(**kwargs).to(device)
-	steps = 500
+	steps = 200
 
 	optimizer = cfg.hyperparameters.optimizer.lower() if cfg.cfg_path is not None else "prodigy"
 	scheduler = cfg.hyperparameters.scheduler.lower() if cfg.cfg_path is not None else ""
```
```diff
@@ -468,7 +470,11 @@ def example_usage():
 			return
 
 	engine.eval()
-	resps_list = engine(text_list, proms_list, max_steps=steps, sampling_temperature=0.95 )
+	if "ar" in cfg.model.capabilities:
+		resps_list = engine(text_list, proms_list, max_steps=steps, sampling_temperature=0.95 )
+	else:
+		resps_list = [ qnt[:, 0].to( device ) ]
+
 	if cfg.model.max_levels > 1:
 		resps_list = [r.unsqueeze(-1) for r in resps_list]
 		resps_list = engine( text_list, proms_list, resps_list=resps_list, sampling_temperature=0.2 )
```
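Since a NAR-only model cannot generate the first RVQ level itself, the eval path above seeds `resps_list` from the reference quantized audio instead. A rough sketch of that fallback, with `qnt` faked as a random [timesteps, levels] tensor (the real `engine` call is elided):

```python
import torch

capabilities = ["nar"]                 # assumed NAR-only configuration
qnt = torch.randint(0, 1024, (75, 8))  # stand-in quantized reference audio

if "ar" in capabilities:
	resps_list = None  # would come from engine(text_list, proms_list, ...)
else:
	resps_list = [ qnt[:, 0] ]  # seed with the reference's first RVQ level

assert resps_list[0].shape == (75,)
```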
```diff
@@ -492,7 +498,7 @@ def example_usage():
 			'module': model.state_dict()
 		}, f"./data/{cfg.model.arch_type}.pth" )
 
-	sample("init", 5)
+	#sample("init", 5)
 	train()
 	sample("final")
 
```
```diff
@@ -498,7 +498,6 @@ class Base(nn.Module):
 		self.l_padding = l_padding
 
 		# +1 to include the stop token
 		# to-do: undo this dogshit mistake; tasks tokens should be delegated to its own embedding
-		n_prom_tokens = n_tokens
 		n_resp_tokens = n_tokens + (1 if self.causal else 0) # AR requires a stop token to... know when to stop
 
```
```diff
@@ -1009,11 +1008,14 @@ class Base(nn.Module):
 					"logits": [],
 				}
 
-			info[name]["targets"].append( input.contiguous() )
-			info[name]["logits"].append( logit.contiguous() )
+			info[name]["targets"].append( input ) # input.contiguous()
+			info[name]["logits"].append( logit ) # logit.contiguous()
 
 		for name, batch in info.items():
+			loss_factor = self.loss_factor(name)
 			if name not in ["text", "prom", "resp"]:
 				continue
 
+			if loss_factor == 0.0:
+				continue
 
```
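Hoisting `loss_factor` above the stream filter lets zero-weighted streams be skipped before any loss is computed. A hedged sketch of the short-circuit, assuming (as the empty default earlier suggests) that an unset factor falls back to 1.0:

```python
import torch
import torch.nn.functional as F

loss_factors = {"prom": 0.0, "resp": 1.0}  # illustrative weights, not the defaults

def loss_factor(name: str) -> float:
	return loss_factors.get(name, 1.0)  # assumption: unset streams weigh 1.0

losses = {}
batches = {
	"prom": (torch.randn(6, 10), torch.randint(0, 10, (6,))),
	"resp": (torch.randn(6, 10), torch.randint(0, 10, (6,))),
}
for name, (logits, targets) in batches.items():
	factor = loss_factor(name)
	if factor == 0.0:
		continue  # skip before the cross_entropy, as in the hunk above
	losses[name] = F.cross_entropy(logits, targets) * factor

assert "prom" not in losses and "resp" in losses
```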
```diff
@@ -1021,7 +1023,11 @@ class Base(nn.Module):
 			inputs = torch.cat( batch["logits"] )
 
 			self.loss[name] = F.cross_entropy( inputs, targets, ignore_index=self.ignore_index ) * loss_factor
-			self.stats["acc"][name] = self.accuracy_metric( inputs, targets )
+			try:
+				self.stats["acc"][name] = self.accuracy_metric( inputs, targets )
+			except Exception as e:
+				print( name, inputs.shape, targets.shape, e )
+				pass
 
 		# to-do: compute loss per individual batch to scale per RVQ level
 		"""
```
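The try/except above trades a hard crash for a shape printout when the accuracy metric chokes. A minimal reproduction using a simple argmax accuracy in place of `self.accuracy_metric` (hypothetical; the repo's actual metric is not shown here):

```python
import torch

def accuracy(logits: torch.Tensor, targets: torch.Tensor) -> float:
	return (logits.argmax(dim=-1) == targets).float().mean().item()

stats = {"acc": {}}
name = "resp"
logits = torch.randn(5, 10)
targets = torch.randint(0, 10, (4,))  # deliberately mismatched length

try:
	stats["acc"][name] = accuracy(logits, targets)
except Exception as e:
	# training continues; the offending stream and shapes are logged instead
	print(name, logits.shape, targets.shape, e)
```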
```diff
@@ -434,7 +434,7 @@ def example_usage():
 			stats = {"step": i}
 
 			batch_size = len(text_list)
-			quant_levels = None if cfg.model.interleave else torch.randint(0, cfg.model.max_levels, (batch_size,))
+			quant_levels = None if cfg.model.interleave else torch.randint(0 if "ar" in cfg.model.capabilities else 1, cfg.model.max_levels, (batch_size,))
 			if quant_levels is not None:
 				resps_list = [ [] if l == 0 else resp for l, resp in zip(quant_levels, resp_list) ]
 			else:
```
```diff
@@ -32,7 +32,7 @@ def train_feeder(engine, batch):
 		quant_levels = None
 		resps_list = [ resp for resp in batch["resps"] ]
 	else:
-		quant_levels = torch.randint(0, cfg.model.max_levels, (batch_size,))
+		quant_levels = torch.randint(0 if "ar" in cfg.model.capabilities else 1, cfg.model.max_levels, (batch_size,))
 		resps_list = [ [] if l == 0 else resp for l, resp in zip(quant_levels, batch["resps"]) ]
 
 	input_ids, attention_mask = fold_inputs(
```