diff --git a/vall_e/models/base.py b/vall_e/models/base.py index cd1dc69..3e808b4 100755 --- a/vall_e/models/base.py +++ b/vall_e/models/base.py @@ -907,7 +907,7 @@ class Base(nn.Module): if name in ["text" ]: text_batch.append( input ) - elif name == "prom" and (quant_level is None or quant_level == 0 or not self.config.audio_embedding_sums): + elif name == "prom": # and (quant_level is None or quant_level == 0) and not self.config.audio_embedding_sums: prom_batch.append( input[:, quant_level] if quant_level is not None else input ) elif name == "targ": resp_batch.append( input )