diff --git a/vall_e/models/base.py b/vall_e/models/base.py index 3e808b4..c7964e9 100755 --- a/vall_e/models/base.py +++ b/vall_e/models/base.py @@ -905,9 +905,9 @@ class Base(nn.Module): quant_level = quant_levels[bi] if quant_levels is not None else None - if name in ["text" ]: + if name == "text": text_batch.append( input ) - elif name == "prom": # and (quant_level is None or quant_level == 0) and not self.config.audio_embedding_sums: + elif name == "prom": prom_batch.append( input[:, quant_level] if quant_level is not None else input ) elif name == "targ": resp_batch.append( input ) @@ -1001,23 +1001,15 @@ class Base(nn.Module): logits_prom = [] logits_resp = [] + # trim logits to each section for i, logit in enumerate(logits): - text_len = target_text_list[i].shape[0] if target_text_list else 0 - prom_len = target_prom_list[i].shape[0] if target_prom_list else 0 - resp_len = target_resp_list[i].shape[0] if target_resp_list else 0 + text_len = target_text_list[i].shape[0] + prom_len = target_prom_list[i].shape[0] + resp_len = target_resp_list[i].shape[0] - if target_text_list: - logit_text = logit[:text_len] - logits_text.append( logit_text ) - - # + 1 to include separator - if target_prom_list: - logit_prom = logit[text_len+1:text_len+1+prom_len] - logits_prom.append( logit_prom ) - - if target_resp_list: - logit_resp = logit[-resp_len:] - logits_resp.append( logit_resp ) + logits_text.append( logit[:text_len] ) + logits_prom.append( logit[text_len+1:text_len+1+prom_len] ) # + 1 to include separator + logits_resp.append( logit[-resp_len:] ) # modify only for the AR so it can properly behave like a transformer