have eval dataloader use eval batch size for batchedordersampler

This commit is contained in:
mrq 2024-06-28 22:44:00 -05:00
parent 1a392b69f6
commit 591d3ac848
2 changed files with 6 additions and 6 deletions

View File

@@ -547,7 +547,7 @@ class Dataset(_Dataset):
 		if self.sampler_type == "path":
 			if self.sampler_order == "duration" and cfg.dataset.sample_max_duration_batch > 0:
-				self.sampler = BatchedOrderedSampler( self.duration_buckets, cfg.dataset.sample_max_duration_batch, cfg.hyperparameters.batch_size )
+				self.sampler = BatchedOrderedSampler( self.duration_buckets, cfg.dataset.sample_max_duration_batch, cfg.hyperparameters.batch_size if training else cfg.evaluation.batch_size )
 			else:
 				self.sampler = OrderedSampler( len(self) )
 			self.samplers = {}

View File

@@ -278,7 +278,7 @@ class AR_NAR(Base):
 		stop_token = self.stop_token
 		task_list = [ "tts" for _ in range(batch_size) ]
-		recurrent_state = [] if cfg.inference.recurrent_forward else None
+		state = None
 		mirostat = [
 			{"n": 1024, "tau": sampling_mirostat_tau, "eta": sampling_mirostat_eta, "max_surprise": sampling_mirostat_eta * 2, "error_surprise": 0, "running_total_surprise": 0}
 		] * batch_size if sampling_mirostat_tau > 0.0 else None
@@ -308,15 +308,15 @@ class AR_NAR(Base):
 				quant_levels=[ 0 for _ in range( max( batch_size, sampling_beam_width ) ) ]
 			)
-			if recurrent_state is not None:
-				logits, recurrent_state = super().forward(
+			if state is not None:
+				logits, state = super().forward(
 					inputs=inputs,
-					state=recurrent_state,
+					state=state,
 				)
 			else:
 				logits = super().forward(
 					inputs=inputs,
-					state=recurrent_state,
+					state=state,
 				)
r = super().sample( r = super().sample(