From 591d3ac848f12bded77246d513139537863aa150 Mon Sep 17 00:00:00 2001 From: mrq Date: Fri, 28 Jun 2024 22:44:00 -0500 Subject: [PATCH] have eval dataloader use eval batch size for batchedordersampler --- vall_e/data.py | 2 +- vall_e/models/ar_nar.py | 10 +++++----- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/vall_e/data.py b/vall_e/data.py index 6ef6392..a494023 100755 --- a/vall_e/data.py +++ b/vall_e/data.py @@ -547,7 +547,7 @@ class Dataset(_Dataset): if self.sampler_type == "path": if self.sampler_order == "duration" and cfg.dataset.sample_max_duration_batch > 0: - self.sampler = BatchedOrderedSampler( self.duration_buckets, cfg.dataset.sample_max_duration_batch, cfg.hyperparameters.batch_size ) + self.sampler = BatchedOrderedSampler( self.duration_buckets, cfg.dataset.sample_max_duration_batch, cfg.hyperparameters.batch_size if training else cfg.evaluation.batch_size ) else: self.sampler = OrderedSampler( len(self) ) self.samplers = {} diff --git a/vall_e/models/ar_nar.py b/vall_e/models/ar_nar.py index 72b9831..3112ff1 100644 --- a/vall_e/models/ar_nar.py +++ b/vall_e/models/ar_nar.py @@ -278,7 +278,7 @@ class AR_NAR(Base): stop_token = self.stop_token task_list = [ "tts" for _ in range(batch_size) ] - recurrent_state = [] if cfg.inference.recurrent_forward else None + state = None mirostat = [ {"n": 1024, "tau": sampling_mirostat_tau, "eta": sampling_mirostat_eta, "max_surprise": sampling_mirostat_eta * 2, "error_surprise": 0, "running_total_surprise": 0} ] * batch_size if sampling_mirostat_tau > 0.0 else None @@ -308,15 +308,15 @@ class AR_NAR(Base): quant_levels=[ 0 for _ in range( max( batch_size, sampling_beam_width ) ) ] ) - if recurrent_state is not None: - logits, recurrent_state = super().forward( + if state is not None: + logits, state = super().forward( inputs=inputs, - state=recurrent_state, + state=state, ) else: logits = super().forward( inputs=inputs, - state=recurrent_state, + state=state, ) r = super().sample(