@@ -12,7 +12,7 @@ from data.data_sampler import DistIterSampler
 from trainer.eval.evaluator import create_evaluator
 
 from utils import util, options as option
-from data import create_dataloader, create_dataset
+from data import create_dataloader, create_dataset, get_dataset_debugger
 from trainer.ExtensibleTrainer import ExtensibleTrainer
 from time import time
 from datetime import datetime
@@ -110,6 +110,9 @@ class Trainer:
         for phase, dataset_opt in opt['datasets'].items():
             if phase == 'train':
                 self.train_set, collate_fn = create_dataset(dataset_opt, return_collate=True)
+                self.dataset_debugger = get_dataset_debugger(dataset_opt)
+                if self.dataset_debugger is not None and resume_state is not None:
+                    self.dataset_debugger.load_state(opt_get(resume_state, ['dataset_debugger_state'], {}))
                 train_size = int(math.ceil(len(self.train_set) / dataset_opt['batch_size']))
                 total_iters = int(opt['train']['niter'])
                 self.total_epochs = int(math.ceil(total_iters / train_size))
@@ -187,8 +190,12 @@ class Trainer:
         _t = time()
 
         #### log
+        if self.dataset_debugger is not None:
+            self.dataset_debugger.update(train_data)
         if self.current_step % opt['logger']['print_freq'] == 0 and self.rank <= 0:
             logs = self.model.get_current_log(self.current_step)
+            if self.dataset_debugger is not None:
+                logs.update(self.dataset_debugger.get_debugging_map())
             message = '[epoch:{:3d}, iter:{:8,d}, lr:('.format(self.epoch, self.current_step)
             for v in self.model.get_current_learning_rate():
                 message += '{:.3e},'.format(v)
@@ -213,7 +220,10 @@ class Trainer:
             if self.rank <= 0:
                 self.logger.info('Saving models and training states.')
                 self.model.save(self.current_step)
-                self.model.save_training_state(self.epoch, self.current_step)
+                state = {'epoch': self.epoch, 'iter': self.current_step}
+                if self.dataset_debugger is not None:
+                    state['dataset_debugger_state'] = self.dataset_debugger.get_state()
+                self.model.save_training_state(state)
             if 'alt_path' in opt['path'].keys():
                 import shutil
                 print("Synchronizing tb_logger to alt_path..")
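Note: the hunks above pin down the interface that get_dataset_debugger(dataset_opt) must satisfy. It returns either None or an object exposing update(train_data), get_debugging_map(), get_state(), and load_state(state); the final hunk also changes self.model.save_training_state to accept a single state dict, so extras such as 'dataset_debugger_state' can be checkpointed alongside epoch and iter. Below is a minimal sketch of a conforming debugger, assuming a simple batches-seen counter; the class name and the tracked statistic are illustrative, not the repository's actual implementation.

    # Hypothetical sketch: an object conforming to the call sites in the diff.
    # Only the four method names and the None-or-object contract of
    # get_dataset_debugger come from the diff; everything else is assumed.
    class SketchDatasetDebugger:
        def __init__(self):
            self.batches_seen = 0  # assumed statistic, purely illustrative

        def update(self, train_data):
            # Invoked once per training step with the fetched batch.
            self.batches_seen += 1

        def get_debugging_map(self):
            # Merged into the printed logs every print_freq steps.
            return {'dataset_batches_seen': self.batches_seen}

        def get_state(self):
            # Stored in the checkpoint under 'dataset_debugger_state'.
            return {'batches_seen': self.batches_seen}

        def load_state(self, state):
            # Restored on resume; callers pass {} when no saved state exists.
            self.batches_seen = state.get('batches_seen', 0)

Because the resume path feeds load_state with opt_get(resume_state, ['dataset_debugger_state'], {}), a conforming implementation should tolerate an empty dict, as the .get(..., 0) fallback above does.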