From b0158a61d55f07662fd46ed5a443aac117313d54 Mon Sep 17 00:00:00 2001
From: mrq
Date: Fri, 7 Jun 2024 20:34:36 -0500
Subject: [PATCH] fixed some logic errors with training (grabbing wrong quant level...)

---
 vall_e/models/base.py | 18 ++++++++++++------
 1 file changed, 12 insertions(+), 6 deletions(-)

diff --git a/vall_e/models/base.py b/vall_e/models/base.py
index d00a6c9..12c4726 100755
--- a/vall_e/models/base.py
+++ b/vall_e/models/base.py
@@ -664,10 +664,12 @@ class Base(nn.Module):
 			elif name == "lang" and self.langs_emb is not None:
 				embedding = self.langs_emb( input )
 			elif name == "prom":
+				# get RVQ level 0, or up to targetted RVQ level inference
 				embedding = self.proms_emb( input if input.dim() == 1 or quant_level == 0 else input[:, :quant_level], quant_level if self.version >= 5 else None )
 			elif name == "tone" and self.tones_emb is not None:
 				embedding = self.tones_emb( input )
 			elif name == "resp":
+				# get RVQ level 0, or up to targetted RVQ level inference
 				embedding = self.resps_emb( input if input.dim() == 1 or quant_level == 0 else input[:, :quant_level], quant_level )
 			else:
 				continue
@@ -688,7 +690,10 @@ class Base(nn.Module):
 		# old, "naive" way, no loss factoring
 		if not self.config.loss_factors:
 			target_list = []
+
 			for batch_index, batch in enumerate(inputs):
+				quant_level = quant_levels[batch_index]
+				prev_quant_level = 0 if quant_level == 0 else quant_level - 1
 				target = []
 				for name, input in batch:
 					if name == "prom":
@@ -697,9 +702,9 @@ class Base(nn.Module):
 							target.append( torch.full_like(input[..., 0], self.ignore_index) )
 						# we *CAN* directly map to proms
 						else:
-							target.append( input if input.dim() == 1 else input[:, quant_level-1] )
+							target.append( input if input.dim() == 1 else input[:, prev_quant_level] )
 					elif name == "resp":
-						target.append( input if input.dim() == 1 else input[:, quant_level-1] )
+						target.append( input if input.dim() == 1 else input[:, quant_level] )
 					elif name in ["text", "quant_level", "lang", "tone"]:
 						target.append( input )
@@ -729,7 +734,7 @@ class Base(nn.Module):
 				)
 			else:
 				self.loss = dict(
-					nll = sum([ F.cross_entropy( inputs, targets, ignore_index=self.ignore_index ) * loss_factor for targets, inputs in zip( target_list, logits ) ]) / batch_size
+					nll = sum([ F.cross_entropy( inputs, targets, ignore_index=self.ignore_index ) for targets, inputs in zip( target_list, logits ) ]) / batch_size
 				)
 				self.stats = dict(
 					acc = sum( [ self.accuracy_metric( inputs, targets ) for targets, inputs in zip( target_list, logits ) ] ) / batch_size
@@ -752,7 +757,8 @@ class Base(nn.Module):
 		batch_size = len( inputs )
 		for i, batch in enumerate( inputs ):
-			quant_level = quant_levels[i] if quant_levels is not None else None
+			quant_level = quant_levels[i]
+			prev_quant_level = 0 if quant_level == 0 else quant_level - 1
 			it = 0
 			for name, input in batch:
@@ -760,8 +766,8 @@ class Base(nn.Module):
 				if name == "resp":
 					input = input if input.dim() == 1 else input[:, quant_level]
 				# select prom level
-				elif name == "prom" and quant_level is not None:
-					input = input[:, quant_level]
+				elif name == "prom":
+					input = input[:, prev_quant_level]
 				seq_len = input.shape[0]
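
Note (editor's addition, not part of the patch): the sketch below restates the level-selection rule this commit applies in both the target-building and loss-factor paths, i.e. "prom" codes are taken from the previous RVQ level while "resp" codes are taken from the currently targeted level. It is a minimal illustration only; it assumes 2D code tensors shaped [timesteps, rvq_levels], and the helper name select_codes is hypothetical and does not exist in vall_e/models/base.py.

    import torch

    def select_codes(name: str, codes: torch.Tensor, quant_level: int) -> torch.Tensor:
        # Hypothetical helper restating the corrected rule from the patch.
        # codes: a 1D token sequence, or a 2D tensor of shape [timesteps, rvq_levels].
        prev_quant_level = 0 if quant_level == 0 else quant_level - 1
        if codes.dim() == 1:
            return codes                       # already a single level of codes
        if name == "prom":
            return codes[:, prev_quant_level]  # prompt audio: previous RVQ level
        if name == "resp":
            return codes[:, quant_level]       # response audio: targeted RVQ level
        return codes                           # text / lang / tone pass through unchanged

Before this commit, both branches indexed with quant_level-1 (or the raw quant_level for proms), which underflows at level 0 and pairs the response against the wrong level; clamping to a prev_quant_level of 0 and splitting the prom/resp cases is what the diff above does.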