I forgot the actual reason I was cleaning things up was to re-include the prom loss calculation (I realized the reason I did this was a prom embedding oversight; it seems to work now)
parent da8242d086
commit eafa622be2
@@ -204,7 +204,7 @@ class Model:
 	attention: str = "auto"
 	audio_embedding_sums: bool = True
 	dropout: float = 0.1 # adjustable dropout value
-	#loss_factors: dict = field(default_factory=lambda: { "text": 0.1, "prom": 0.0, "resp": 1.0 }) # disable it by default since it causes a little more harm than good
+	#loss_factors: dict = field(default_factory=lambda: { "text": 0.1, "prom": 1.0, "resp": 1.0 }) # disable it by default since it causes a little more harm than good
 	loss_factors: dict = field(default_factory=lambda: {})
 	capabilities: list = field(default_factory=lambda: ["ar", "nar"])
 	experimental: bool = False # for now it sets things to be HF compatible
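For context on what a non-empty loss_factors would do, here is a minimal sketch of per-component loss weighting. It is not code from this repo: the weighted_loss helper, the logits/targets-by-name dicts, and the -100 ignore_index default are assumptions used purely for illustration. An empty loss_factors (the new default) is read as weight 1.0 for every component, while a 0.0 factor like the old commented "prom": 0.0 removes that component from the loss entirely.

import torch
import torch.nn.functional as F

def weighted_loss(logits_by_name, targets_by_name, loss_factors, ignore_index=-100):
	# Hypothetical helper (not this repo's API): weight each component's
	# cross-entropy by its factor; an empty loss_factors dict leaves every
	# component at weight 1.0.
	losses = []
	for name, logits in logits_by_name.items():
		factor = loss_factors.get(name, 1.0)
		if factor == 0.0:
			continue  # a 0.0 factor (the old commented "prom": 0.0) drops that component
		loss = F.cross_entropy(logits, targets_by_name[name], ignore_index=ignore_index)
		losses.append(factor * loss)
	return torch.stack(losses).sum()

# Usage with fake per-component logits/targets:
vocab = 1024
logits = {name: torch.randn(5, vocab) for name in ("text", "prom", "resp")}
targets = {name: torch.randint(0, vocab, (5,)) for name in ("text", "prom", "resp")}
print(weighted_loss(logits, targets, {}))                                       # all components, weight 1.0
print(weighted_loss(logits, targets, {"text": 0.1, "prom": 0.0, "resp": 1.0}))  # old commented default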
@@ -692,7 +692,12 @@ class Base(nn.Module):
 			target = []
 			for name, input in batch:
 				if name == "prom":
-					target.append( torch.full_like(input[..., 0], self.ignore_index) )
+					# ignore prom, fill with mock tokens, because the prom embeddings don't directly map to tokens
+					if self.version < 4 or (self.version >= 5 and self.config.audio_embedding_sums):
+						target.append( torch.full_like(input[..., 0], self.ignore_index) )
+					# we *CAN* directly map to proms
+					else:
+						target.append( input if input.dim() == 1 else input[:, quant_level-1] )
 				elif name == "resp":
 					target.append( input if input.dim() == 1 else input[:, quant_level-1] )
 				elif name in ["text", "quant_level", "lang", "tone"]:
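To see why filling the prom target with self.ignore_index works as a "mock token" mask, here is a small self-contained check, not taken from the repo (the -100 value, the shapes, and the fake tensors are assumptions): positions whose target equals ignore_index are excluded from the cross-entropy, so masked prom frames leave the loss unchanged, while the direct-mapping branch simply selects the codebook for the current quant level.

import torch
import torch.nn.functional as F

ignore_index = -100                       # assumed; the repo uses self.ignore_index
n_frames, n_codebooks, vocab = 6, 8, 1024
quant_level = 2

# A fake multi-codebook prom of shape [frames, codebooks].
prom = torch.randint(0, vocab, (n_frames, n_codebooks))

# Masked path: same length as the prom, every position set to ignore_index.
masked_target = torch.full_like(prom[..., 0], ignore_index)

# Direct-mapping path: take the codebook for the current quant level.
mapped_target = prom if prom.dim() == 1 else prom[:, quant_level - 1]
assert mapped_target.shape == (n_frames,)

# Ignored positions contribute nothing: prom logits with masked targets,
# concatenated with resp logits and real targets, give the same loss as
# the resp part alone.
prom_logits = torch.randn(n_frames, vocab)
resp_logits = torch.randn(4, vocab)
resp_target = torch.randint(0, vocab, (4,))

combined = F.cross_entropy(
	torch.cat([prom_logits, resp_logits]),
	torch.cat([masked_target, resp_target]),
	ignore_index=ignore_index,
)
assert torch.allclose(combined, F.cross_entropy(resp_logits, resp_target))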