diff --git a/vall_e/config.py b/vall_e/config.py
index d115912..2d4e28d 100755
--- a/vall_e/config.py
+++ b/vall_e/config.py
@@ -204,7 +204,7 @@ class Model:
 	attention: str = "auto"
 	audio_embedding_sums: bool = True
 	dropout: float = 0.1 # adjustable dropout value
-	#loss_factors: dict = field(default_factory=lambda: { "text": 0.1, "prom": 0.0, "resp": 1.0 }) # disable it by default since it causes a little more harm than good
+	#loss_factors: dict = field(default_factory=lambda: { "text": 0.1, "prom": 1.0, "resp": 1.0 }) # disable it by default since it causes a little more harm than good
 	loss_factors: dict = field(default_factory=lambda: {})
 	capabilities: list = field(default_factory=lambda: ["ar", "nar"])
 	experimental: bool = False # for now it sets things to be HF compatible
diff --git a/vall_e/models/base.py b/vall_e/models/base.py
index b1dec02..d00a6c9 100755
--- a/vall_e/models/base.py
+++ b/vall_e/models/base.py
@@ -692,7 +692,12 @@ class Base(nn.Module):
 		target = []
 		for name, input in batch:
 			if name == "prom":
-				target.append( torch.full_like(input[..., 0], self.ignore_index) )
+				# ignore prom, fill with mock tokens, because the prom embeddings don't directly map to tokens
+				if self.version < 4 or (self.version >= 5 and self.config.audio_embedding_sums):
+					target.append( torch.full_like(input[..., 0], self.ignore_index) )
+				# we *CAN* directly map to proms
+				else:
+					target.append( input if input.dim() == 1 else input[:, quant_level-1] )
 			elif name == "resp":
 				target.append( input if input.dim() == 1 else input[:, quant_level-1] )
 			elif name in ["text", "quant_level", "lang", "tone"]: