added a flag (thanks gannybal)
This commit is contained in:
parent
0f04206aa2
commit
71cc43e65c
|
@ -20,7 +20,7 @@ def create_dataloader(dataset, dataset_opt, opt=None, sampler=None, collate_fn=N
|
||||||
batch_size = dataset_opt['batch_size']
|
batch_size = dataset_opt['batch_size']
|
||||||
return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle,
|
return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle,
|
||||||
num_workers=num_workers, sampler=sampler, drop_last=True,
|
num_workers=num_workers, sampler=sampler, drop_last=True,
|
||||||
pin_memory=pin_memory, collate_fn=collate_fn)
|
pin_memory=pin_memory, collate_fn=collate_fn, persistent_workers=True)
|
||||||
else:
|
else:
|
||||||
batch_size = dataset_opt['batch_size'] or 1
|
batch_size = dataset_opt['batch_size'] or 1
|
||||||
return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=0,
|
return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=0,
|
||||||
|
|
Loading…
Reference in New Issue
Block a user