diff --git a/vall_e/config.py b/vall_e/config.py index 7f03e3f..629d9fd 100755 --- a/vall_e/config.py +++ b/vall_e/config.py @@ -213,7 +213,7 @@ class Model: attention: str = "auto" audio_embedding_sums: bool = True dropout: float = 0.1 # adjustable dropout value - loss_factors: dict = field(default_factory=lambda: {}) + loss_factors: dict = field(default_factory=lambda: { "text": 0.1, "prom": 0.1, "resp": 1.0 }) def get(self, name=None): return [ self ] if not name or self.name == name else [] @@ -495,11 +495,6 @@ class DeepSpeed: return ds_cfg -@dataclass() -class LossFactor: - text: float = 1.0 - resp: float = 1.0 - @dataclass() class Trainer: iterations: int = 100_000 diff --git a/vall_e/models/base.py b/vall_e/models/base.py index 97051a0..cd1dc69 100755 --- a/vall_e/models/base.py +++ b/vall_e/models/base.py @@ -888,27 +888,38 @@ class Base(nn.Module): def training_targets_split( self, inputs: list, + quant_levels: Tensor | None = None ): text_lists = [] + prom_lists = [] resp_lists = [] for bi in range(len(inputs)): text_batch = [] + prom_batch = [] resp_batch = [] for i in range(len(inputs[bi])): name, input = inputs[bi][i] device = input.device - if name in ["text", "lang" ]: + quant_level = quant_levels[bi] if quant_levels is not None else None + + if name in ["text" ]: text_batch.append( input ) + elif name == "prom" and (quant_level is None or quant_level == 0 or not self.config.audio_embedding_sums): + prom_batch.append( input[:, quant_level] if quant_level is not None else input ) elif name == "targ": resp_batch.append( input ) - text_lists.append( _join( text_batch, torch.tensor(self.ignore_index, device=device) ) ) - resp_lists.append( _join( resp_batch, torch.tensor(self.ignore_index, device=device) ) ) + if text_batch: + text_lists.append( _join( text_batch, torch.tensor(self.ignore_index, device=device) ) ) + if prom_batch: + prom_lists.append( _join( prom_batch, torch.tensor(self.ignore_index, device=device) ) ) + if resp_batch: + resp_lists.append( _join( resp_batch, torch.tensor(self.ignore_index, device=device) ) ) - return text_lists, resp_lists + return text_lists, prom_lists, resp_lists def forward( self, @@ -982,12 +993,32 @@ class Base(nn.Module): # precision = self.precision_metric( inputs, target ), ) # split our loss + # to-do: clean this up else: - target_text_list, target_resp_list = self.training_targets_split( inputs ) + target_text_list, target_prom_list, target_resp_list = self.training_targets_split( inputs, quant_levels ) + + logits_text = [] + logits_prom = [] + logits_resp = [] + + for i, logit in enumerate(logits): + text_len = target_text_list[i].shape[0] if target_text_list else 0 + prom_len = target_prom_list[i].shape[0] if target_prom_list else 0 + resp_len = target_resp_list[i].shape[0] if target_resp_list else 0 + + if target_text_list: + logit_text = logit[:text_len] + logits_text.append( logit_text ) + + # + 1 to include separator + if target_prom_list: + logit_prom = logit[text_len+1:text_len+1+prom_len] + logits_prom.append( logit_prom ) + + if target_resp_list: + logit_resp = logit[-resp_len:] + logits_resp.append( logit_resp ) - # grab respective slice of logits - logits_text = [ hi[:li.shape[0]] for hi, li in zip(logits, target_text_list) ] - logits_resp = [ hi[-li.shape[0]:] for hi, li in zip(logits, target_resp_list) ] # modify only for the AR so it can properly behave like a transformer for i in range(len(target_text_list)): @@ -996,25 +1027,37 @@ class Base(nn.Module): # shift the target so that token n... logits_text[i] = logits_text[i][..., :-1, :] + logits_prom[i] = logits_prom[i][..., :-1, :] logits_resp[i] = logits_resp[i][..., :-1, :] # predicts token n + 1 target_text_list[i] = target_text_list[i][..., 1:] + target_prom_list[i] = target_prom_list[i][..., 1:] target_resp_list[i] = target_resp_list[i][..., 1:] + + self.loss = dict() + self.stats = dict(acc = dict()) - target_text = torch.cat( target_text_list ).long() - target_resp = torch.cat( target_resp_list ).long() + loss_factor_text = self.loss_factor("text") + if loss_factor_text > 0.0 and target_text_list: + target_text = torch.cat( target_text_list ).long() + inputs_text = torch.cat( logits_text ) + self.loss["text"] = F.cross_entropy( inputs_text, target_text, ignore_index=self.ignore_index ) + self.stats["acc"]["text"] = self.accuracy_metric( inputs_text, target_text ) - inputs_text = torch.cat( logits_text ) - inputs_resp = torch.cat( logits_resp ) + loss_factor_prom = self.loss_factor("prom") + if loss_factor_prom > 0.0 and target_prom_list: + target_prom = torch.cat( target_prom_list ).long() + inputs_prom = torch.cat( logits_prom ) + self.loss["prom"] = F.cross_entropy( inputs_prom, target_prom, ignore_index=self.ignore_index ) + self.stats["acc"]["prom"] = self.accuracy_metric( inputs_prom, target_prom ) - self.loss = dict( - text = F.cross_entropy( inputs_text, target_text, ignore_index=self.ignore_index ), - resp = F.cross_entropy( inputs_resp, target_resp, ignore_index=self.ignore_index ), - ) - - for k in self.loss: - self.loss[k] *= self.loss_factor(k) + loss_factor_resp = self.loss_factor("resp") + if loss_factor_resp > 0.0 and target_resp_list: + target_resp = torch.cat( target_resp_list ).long() + inputs_resp = torch.cat( logits_resp ) + self.loss["resp"] = F.cross_entropy( inputs_resp, target_resp, ignore_index=self.ignore_index ) + self.stats["acc"]["resp"] = self.accuracy_metric( inputs_resp, target_resp ) # to-do: compute loss per individual batch to scale per RVQ level """ @@ -1023,13 +1066,6 @@ class Base(nn.Module): ... """ - self.stats = dict( - acc = dict( - text = self.accuracy_metric( inputs_text, target_text ), - resp = self.accuracy_metric( inputs_resp, target_resp ), - ), - ) - # include any additional losses (for example: MoE router) if aux_loss is not None: self.loss["aux_loss"] = aux_loss