I forgot the actual reason I was cleaning things up was to re-include the prom loss calculation (I realized I had only disabled it because of a prom embedding oversight; it seems to work now)
parent da8242d086
commit eafa622be2
@@ -204,7 +204,7 @@ class Model:
 	attention: str = "auto"
 	audio_embedding_sums: bool = True
 	dropout: float = 0.1 # adjustable dropout value
-	#loss_factors: dict = field(default_factory=lambda: { "text": 0.1, "prom": 0.0, "resp": 1.0 }) # disable it by default since it causes a little more harm than good
+	#loss_factors: dict = field(default_factory=lambda: { "text": 0.1, "prom": 1.0, "resp": 1.0 }) # disable it by default since it causes a little more harm than good
 	loss_factors: dict = field(default_factory=lambda: {})
 	capabilities: list = field(default_factory=lambda: ["ar", "nar"])
 	experimental: bool = False # for now it sets things to be HF compatible
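For context on the commented-out default: loss_factors weights each input's contribution to the combined loss, and flipping "prom" from 0.0 to 1.0 in the example reflects that the prom term is being re-included. Below is a minimal sketch of how such per-input factors could combine component losses; the function and argument names are purely illustrative assumptions, not the repo's actual API.

# Minimal sketch of per-input loss weighting, assuming loss_factors maps input
# names to weights and logits/targets are already gathered per input name.
import torch
import torch.nn.functional as F

def weighted_loss(logits_by_name, targets_by_name, loss_factors, ignore_index=-100):
	total = torch.zeros(())
	for name, logits in logits_by_name.items():
		factor = loss_factors.get(name, 1.0)
		if factor == 0.0:
			continue  # a zero factor drops that component entirely
		# positions whose target equals ignore_index contribute nothing
		loss = F.cross_entropy(logits, targets_by_name[name], ignore_index=ignore_index)
		total = total + factor * loss
	return total

# e.g. the default hinted at in the commented-out line above:
# weighted_loss(logits, targets, {"text": 0.1, "prom": 1.0, "resp": 1.0})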
@@ -692,7 +692,12 @@ class Base(nn.Module):
 				target = []
 				for name, input in batch:
 					if name == "prom":
-						target.append( torch.full_like(input[..., 0], self.ignore_index) )
+						# ignore prom, fill with mock tokens, because the prom embeddings don't directly map to tokens
+						if self.version < 4 or (self.version >= 5 and self.config.audio_embedding_sums):
+							target.append( torch.full_like(input[..., 0], self.ignore_index) )
+						# we *CAN* directly map to proms
+						else:
+							target.append( input if input.dim() == 1 else input[:, quant_level-1] )
 					elif name == "resp":
 						target.append( input if input.dim() == 1 else input[:, quant_level-1] )
 					elif name in ["text", "quant_level", "lang", "tone"]:
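The gist of the new branch: with audio_embedding_sums the prom embedding at each step is a sum over codebooks, so there is no single token id the logits can be scored against and the target is filled with ignore_index, which cross-entropy skips; without summing, the codebook column for the current quant level maps directly to tokens. A toy sketch of the two cases follows; the shapes, the 1024-entry codebook, and the explicit ignore_index are assumptions for demonstration, not the model's real tensors.

# Toy illustration of the two prom-target branches above.
import torch
import torch.nn.functional as F

ignore_index = -100
quant_level = 2
prom = torch.randint(0, 1024, (6, 8))  # [timesteps, codebooks] of token ids

# summed-embedding case: no single token id per position, so mask the whole span
masked_target = torch.full_like(prom[..., 0], ignore_index)

# per-codebook case: the column for the current quant level is a valid target
direct_target = prom if prom.dim() == 1 else prom[:, quant_level - 1]

logits = torch.randn(6, 1024)
# cross_entropy skips any position whose target equals ignore_index, so a
# masked prom span adds nothing to the loss while direct targets are scored
loss = F.cross_entropy(logits, direct_target, ignore_index=ignore_index)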