have eval dataloader use eval batch size for batchedordersampler

This commit is contained in:
mrq 2024-06-28 22:44:00 -05:00
parent 1a392b69f6
commit 591d3ac848
2 changed files with 6 additions and 6 deletions

View File

@ -547,7 +547,7 @@ class Dataset(_Dataset):
if self.sampler_type == "path":
if self.sampler_order == "duration" and cfg.dataset.sample_max_duration_batch > 0:
self.sampler = BatchedOrderedSampler( self.duration_buckets, cfg.dataset.sample_max_duration_batch, cfg.hyperparameters.batch_size )
self.sampler = BatchedOrderedSampler( self.duration_buckets, cfg.dataset.sample_max_duration_batch, cfg.hyperparameters.batch_size if training else cfg.evaluation.batch_size )
else:
self.sampler = OrderedSampler( len(self) )
self.samplers = {}

View File

@ -278,7 +278,7 @@ class AR_NAR(Base):
stop_token = self.stop_token
task_list = [ "tts" for _ in range(batch_size) ]
recurrent_state = [] if cfg.inference.recurrent_forward else None
state = None
mirostat = [
{"n": 1024, "tau": sampling_mirostat_tau, "eta": sampling_mirostat_eta, "max_surprise": sampling_mirostat_eta * 2, "error_surprise": 0, "running_total_surprise": 0}
] * batch_size if sampling_mirostat_tau > 0.0 else None
@ -308,15 +308,15 @@ class AR_NAR(Base):
quant_levels=[ 0 for _ in range( max( batch_size, sampling_beam_width ) ) ]
)
if recurrent_state is not None:
logits, recurrent_state = super().forward(
if state is not None:
logits, state = super().forward(
inputs=inputs,
state=recurrent_state,
state=state,
)
else:
logits = super().forward(
inputs=inputs,
state=recurrent_state,
state=state,
)
r = super().sample(