I forgot the actual reason I was cleaning things up was to re-include the prom loss calculation (I realized the reason I had disabled it was a prom embedding oversight; it seems to work now)

This commit is contained in:
mrq 2024-06-07 20:29:25 -05:00
parent da8242d086
commit eafa622be2
2 changed files with 7 additions and 2 deletions

@@ -204,7 +204,7 @@ class Model:
 	attention: str = "auto"
 	audio_embedding_sums: bool = True
 	dropout: float = 0.1 # adjustable dropout value
-	#loss_factors: dict = field(default_factory=lambda: { "text": 0.1, "prom": 0.0, "resp": 1.0 }) # disable it by default since it causes a little more harm than good
+	#loss_factors: dict = field(default_factory=lambda: { "text": 0.1, "prom": 1.0, "resp": 1.0 }) # disable it by default since it causes a little more harm than good
 	loss_factors: dict = field(default_factory=lambda: {})
 	capabilities: list = field(default_factory=lambda: ["ar", "nar"])
 	experimental: bool = False # for now it sets things to be HF compatible
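
For context, loss_factors maps an input segment name ("text", "prom", "resp") to a weight on that segment's loss; the empty default leaves every segment unweighted. Below is a minimal sketch of how such per-segment weighting could be applied. The weighted_loss helper is hypothetical and only illustrates the idea; it is not this repo's actual code.

import torch.nn.functional as F

IGNORE_INDEX = -100  # assumed ignore value; matches cross_entropy's default ignore_index

def weighted_loss(logits_per_segment, targets_per_segment, loss_factors):
	# Hypothetical helper: weight each named segment's cross-entropy by its factor.
	total = 0.0
	for name, logits in logits_per_segment.items():
		# an empty dict means "no per-segment weighting": treat every segment equally
		factor = loss_factors.get(name, 0.0) if loss_factors else 1.0
		if factor == 0.0:
			continue  # a zero (or missing) factor drops that segment from the loss
		total = total + factor * F.cross_entropy(
			logits, targets_per_segment[name], ignore_index=IGNORE_INDEX
		)
	return total

Under that reading, re-enabling the commented-out example with "prom": 1.0 instead of 0.0 would weight the prompt loss equally with the response loss rather than dropping it.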

@@ -692,7 +692,12 @@ class Base(nn.Module):
 		target = []
 		for name, input in batch:
 			if name == "prom":
-				target.append( torch.full_like(input[..., 0], self.ignore_index) )
+				# ignore prom, fill with mock tokens, because the prom embeddings don't directly map to tokens
+				if self.version < 4 or (self.version >= 5 and self.config.audio_embedding_sums):
+					target.append( torch.full_like(input[..., 0], self.ignore_index) )
+				# we *CAN* directly map to proms
+				else:
+					target.append( input if input.dim() == 1 else input[:, quant_level-1] )
 			elif name == "resp":
 				target.append( input if input.dim() == 1 else input[:, quant_level-1] )
 			elif name in ["text", "quant_level", "lang", "tone"]:
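
The new branch decides whether the prompt can contribute to the loss at all: when the prompt's codebook levels are summed into a single embedding (audio_embedding_sums), a prompt position no longer maps back to one token, so its targets are filled with ignore_index and excluded; otherwise the codes at the current quantization level can serve as targets directly. A rough, self-contained illustration of the two cases follows, with the version checks omitted and the tensor sizes and flag value made up for the example.

import torch

ignore_index = -100
quant_level = 2
audio_embedding_sums = True                # assumed flag value for the example
prom = torch.randint(0, 1024, (50, 8))     # [timesteps, codebooks] toy prompt codes

if audio_embedding_sums:
	# summed codebook embeddings: no single token per position, so mask the prompt out of the loss
	target = torch.full_like(prom[..., 0], ignore_index)
else:
	# per-level embeddings: the codes at this quant level are valid targets
	target = prom if prom.dim() == 1 else prom[:, quant_level - 1]

print(target.shape)  # torch.Size([50]) either way

Either way the target has one entry per prompt timestep, so it stays aligned with the logits; the only difference is whether those entries are real codes or ignore_index placeholders.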