have eval dataloader use eval batch size for BatchedOrderedSampler
parent 1a392b69f6
commit 591d3ac848
@@ -547,7 +547,7 @@ class Dataset(_Dataset):
 		if self.sampler_type == "path":
 			if self.sampler_order == "duration" and cfg.dataset.sample_max_duration_batch > 0:
-				self.sampler = BatchedOrderedSampler( self.duration_buckets, cfg.dataset.sample_max_duration_batch, cfg.hyperparameters.batch_size )
+				self.sampler = BatchedOrderedSampler( self.duration_buckets, cfg.dataset.sample_max_duration_batch, cfg.hyperparameters.batch_size if training else cfg.evaluation.batch_size )
 			else:
 				self.sampler = OrderedSampler( len(self) )
 			self.samplers = {}
 
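For context, the sampler constructed here packs dataset indices into batches by duration bucket, bounded by a total-duration budget and a batch-size cap; before this change the evaluation dataloader reused the training batch size for that cap. Below is a minimal sketch of such a duration-bucketed batch sampler, assuming only the positional signature seen in the call above (buckets, max duration, max batch size); the body is illustrative, not the repository's implementation.

import random
from torch.utils.data import Sampler

class BatchedOrderedSampler(Sampler):
	# Sketch only: yields lists of dataset indices, packing each batch until
	# either the total-duration budget or the batch-size cap would be exceeded.
	def __init__(self, buckets, max_duration=0, max_batch_size=0, shuffle=False):
		self.batches = []
		self.shuffle = shuffle
		current_batch = []
		current_duration = 0
		# assumed layout: buckets = { duration: [dataset indices], ... },
		# walked in ascending duration order
		for duration in sorted(buckets.keys()):
			for index in buckets[duration]:
				over_duration = max_duration > 0 and current_duration + duration > max_duration
				over_size = max_batch_size > 0 and len(current_batch) >= max_batch_size
				if current_batch and (over_duration or over_size):
					self.batches.append(current_batch)
					current_batch, current_duration = [], 0
				current_batch.append(index)
				current_duration += duration
		if current_batch:
			self.batches.append(current_batch)

	def __len__(self):
		return len(self.batches)

	def __iter__(self):
		if self.shuffle:
			random.shuffle(self.batches)
		return iter(self.batches)

With the change above, evaluation packs its batches against cfg.evaluation.batch_size instead of inheriting cfg.hyperparameters.batch_size from training.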
@@ -278,7 +278,7 @@ class AR_NAR(Base):
 		stop_token = self.stop_token
 		task_list = [ "tts" for _ in range(batch_size) ]
 
-		recurrent_state = [] if cfg.inference.recurrent_forward else None
+		state = None
 		mirostat = [
 			{"n": 1024, "tau": sampling_mirostat_tau, "eta": sampling_mirostat_eta, "max_surprise": sampling_mirostat_eta * 2, "error_surprise": 0, "running_total_surprise": 0}
 		] * batch_size if sampling_mirostat_tau > 0.0 else None
@@ -308,15 +308,15 @@ class AR_NAR(Base):
 				quant_levels=[ 0 for _ in range( max( batch_size, sampling_beam_width ) ) ]
 			)
 
-			if recurrent_state is not None:
-				logits, recurrent_state = super().forward(
+			if state is not None:
+				logits, state = super().forward(
 					inputs=inputs,
-					state=recurrent_state,
+					state=state,
 				)
 			else:
 				logits = super().forward(
 					inputs=inputs,
-					state=recurrent_state,
+					state=state,
 				)
 
 			r = super().sample(
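The remaining hunks rename the per-call recurrent_state variable to state and drop the cfg.inference.recurrent_forward gate on its initialization, while keeping the same control flow: an optional state object is threaded through repeated forward calls, and a step that receives one is expected to return an updated one. A hedged sketch of that pattern follows; step_fn and the tuple return stand in for super().forward and are assumptions, not the project's API.

from typing import Any, Optional, Tuple

def forward_with_optional_state(step_fn, inputs: Any, state: Optional[Any] = None) -> Tuple[Any, Optional[Any]]:
	# step_fn stands in for super().forward: when a state object is passed in,
	# the step is expected to return (logits, new_state) so the caller can feed
	# the new state into the next call; otherwise it returns logits alone.
	if state is not None:
		logits, state = step_fn(inputs=inputs, state=state)
	else:
		logits = step_fn(inputs=inputs, state=state)
	return logits, state

Each decoding step would then do logits, state = forward_with_optional_state(step_fn, inputs, state), carrying whatever state the previous step returned into the next one.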