abstractify

This commit is contained in:
James Betker 2022-05-02 00:11:26 -06:00
parent ab219fbefb
commit e402089556
2 changed files with 2 additions and 2 deletions

View File

@ -46,7 +46,7 @@ def main():
dataset_opt['n_workers'] = 0 # Force num_workers=0 to make dataloader work in process. dataset_opt['n_workers'] = 0 # Force num_workers=0 to make dataloader work in process.
train_loader = create_dataloader(train_set, dataset_opt, opt, None) train_loader = create_dataloader(train_set, dataset_opt, opt, None)
if rank <= 0: if rank <= 0:
print('Number of train images: {:,d}, iters: {:,d}'.format( print('Number of training data elements: {:,d}, iters: {:,d}'.format(
len(train_set), train_size)) len(train_set), train_size))
assert train_loader is not None assert train_loader is not None

View File

@ -121,7 +121,7 @@ class Trainer:
shuffle = True shuffle = True
self.train_loader = create_dataloader(self.train_set, dataset_opt, opt, self.train_sampler, collate_fn=collate_fn, shuffle=shuffle) self.train_loader = create_dataloader(self.train_set, dataset_opt, opt, self.train_sampler, collate_fn=collate_fn, shuffle=shuffle)
if self.rank <= 0: if self.rank <= 0:
self.logger.info('Number of train images: {:,d}, iters: {:,d}'.format( self.logger.info('Number of training data elements: {:,d}, iters: {:,d}'.format(
len(self.train_set), train_size)) len(self.train_set), train_size))
self.logger.info('Total epochs needed: {:d} for iters {:,d}'.format( self.logger.info('Total epochs needed: {:d} for iters {:,d}'.format(
self.total_epochs, total_iters)) self.total_epochs, total_iters))