diff --git a/codes/models/ExtensibleTrainer.py b/codes/models/ExtensibleTrainer.py
index 30a7b46b..68fca1e6 100644
--- a/codes/models/ExtensibleTrainer.py
+++ b/codes/models/ExtensibleTrainer.py
@@ -30,7 +30,9 @@ class ExtensibleTrainer(BaseModel):
         self.env = {'device': self.device,
                     'rank': self.rank,
                     'opt': opt,
-                    'step': 0}
+                    'step': 0,
+                    'dist': opt['dist']
+                    }

         if opt['path']['models'] is not None:
             self.env['base_path'] = os.path.join(opt['path']['models'])
diff --git a/codes/models/archs/SPSR_arch.py b/codes/models/archs/SPSR_arch.py
index fdef6cd3..9d2ec8ea 100644
--- a/codes/models/archs/SPSR_arch.py
+++ b/codes/models/archs/SPSR_arch.py
@@ -94,12 +94,12 @@ class SPSRNet(nn.Module):

         n_upscale = int(math.log(upscale, 2))
-        self.scale=n_upscale
+        self.scale=upscale
         if upscale == 3:
             n_upscale = 1

         fea_conv = ConvGnLelu(in_nc, nf//2, kernel_size=7, norm=False, activation=False)
-        self.ref_conv = ConvGnLelu(in_nc, nf//2, stride=n_upscale, kernel_size=7, norm=False, activation=False)
+        self.ref_conv = ConvGnLelu(in_nc, nf//2, stride=upscale, kernel_size=7, norm=False, activation=False)
         self.join_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=False)
         rb_blocks = [RRDB(nf) for _ in range(nb)]
@@ -118,7 +118,7 @@ class SPSRNet(nn.Module):
                                              *upsampler, self.HR_conv0_new)

         self.b_fea_conv = ConvGnLelu(in_nc, nf//2, kernel_size=3, norm=False, activation=False)
-        self.b_ref_conv = ConvGnLelu(in_nc, nf//2, stride=n_upscale, kernel_size=3, norm=False, activation=False)
+        self.b_ref_conv = ConvGnLelu(in_nc, nf//2, stride=upscale, kernel_size=3, norm=False, activation=False)
         self.b_join_conv = ConvGnLelu(nf, nf, kernel_size=3, norm=False, activation=False)

         self.b_concat_1 = ConvGnLelu(2 * nf, nf, kernel_size=3, norm=False, activation=False)
@@ -665,4 +665,4 @@ class SwitchedSpsr(nn.Module):
         for i in range(len(means)):
             val["switch_%i_specificity" % (i,)] = means[i]
             val["switch_%i_histogram" % (i,)] = hists[i]
-        return val
\ No newline at end of file
+        return val
diff --git a/codes/models/steps/steps.py b/codes/models/steps/steps.py
index a6b570e2..d925361e 100644
--- a/codes/models/steps/steps.py
+++ b/codes/models/steps/steps.py
@@ -126,49 +126,48 @@ class ConfigurableStep(Module):
         self.env['current_step_optimizers'] = self.optimizers
         self.env['training'] = train

-        with self.get_network_for_name(self.get_networks_trained()[0]).join():
-            # Inject in any extra dependencies.
-            for inj in self.injectors:
-                # Don't do injections tagged with eval unless we are not in train mode.
-                if train and 'eval' in inj.opt.keys() and inj.opt['eval']:
-                    continue
-                # Likewise, don't do injections tagged with train unless we are not in eval.
-                if not train and 'train' in inj.opt.keys() and inj.opt['train']:
-                    continue
-                # Don't do injections tagged with 'after' or 'before' when we are out of spec.
-                if 'after' in inj.opt.keys() and self.env['step'] < inj.opt['after'] or \
-                   'before' in inj.opt.keys() and self.env['step'] > inj.opt['before']:
-                    continue
-                injected = inj(local_state)
-                local_state.update(injected)
-                new_state.update(injected)
+        # Inject in any extra dependencies.
+        for inj in self.injectors:
+            # Don't do injections tagged with eval unless we are not in train mode.
+            if train and 'eval' in inj.opt.keys() and inj.opt['eval']:
+                continue
+            # Likewise, don't do injections tagged with train unless we are not in eval.
+            if not train and 'train' in inj.opt.keys() and inj.opt['train']:
+                continue
+            # Don't do injections tagged with 'after' or 'before' when we are out of spec.
+            if 'after' in inj.opt.keys() and self.env['step'] < inj.opt['after'] or \
+               'before' in inj.opt.keys() and self.env['step'] > inj.opt['before']:
+                continue
+            injected = inj(local_state)
+            local_state.update(injected)
+            new_state.update(injected)

-            if train and len(self.losses) > 0:
-                # Finally, compute the losses.
-                total_loss = 0
-                for loss_name, loss in self.losses.items():
-                    # Some losses only activate after a set number of steps. For example, proto-discriminator losses can
-                    # be very disruptive to a generator.
-                    if 'after' in loss.opt.keys() and loss.opt['after'] > self.env['step']:
-                        continue
-                    l = loss(self.training_net, local_state)
-                    total_loss += l * self.weights[loss_name]
-                    # Record metrics.
-                    if isinstance(l, torch.Tensor):
-                        self.loss_accumulator.add_loss(loss_name, l)
-                    for n, v in loss.extra_metrics():
-                        self.loss_accumulator.add_loss("%s_%s" % (loss_name, n), v)
-                    loss.clear_metrics()
+        if train and len(self.losses) > 0:
+            # Finally, compute the losses.
+            total_loss = 0
+            for loss_name, loss in self.losses.items():
+                # Some losses only activate after a set number of steps. For example, proto-discriminator losses can
+                # be very disruptive to a generator.
+                if 'after' in loss.opt.keys() and loss.opt['after'] > self.env['step']:
+                    continue
+                l = loss(self.training_net, local_state)
+                total_loss += l * self.weights[loss_name]
+                # Record metrics.
+                if isinstance(l, torch.Tensor):
+                    self.loss_accumulator.add_loss(loss_name, l)
+                for n, v in loss.extra_metrics():
+                    self.loss_accumulator.add_loss("%s_%s" % (loss_name, n), v)
+                loss.clear_metrics()

-                # In some cases, the loss could not be set (e.g. all losses have 'after')
-                if isinstance(total_loss, torch.Tensor):
-                    self.loss_accumulator.add_loss("%s_total" % (self.get_training_network_name(),), total_loss)
-                    # Scale the loss down by the accumulation factor.
-                    total_loss = total_loss / self.env['mega_batch_factor']
+            # In some cases, the loss could not be set (e.g. all losses have 'after')
+            if isinstance(total_loss, torch.Tensor):
+                self.loss_accumulator.add_loss("%s_total" % (self.get_training_network_name(),), total_loss)
+                # Scale the loss down by the accumulation factor.
+                total_loss = total_loss / self.env['mega_batch_factor']

-                    # Get dem grads!
-                    self.scaler.scale(total_loss).backward()
-                    self.grads_generated = True
+                # Get dem grads!
+                self.scaler.scale(total_loss).backward()
+                self.grads_generated = True

         # Detach all state variables. Within the step, gradients can flow. Once these variables leave the step
         # we must release the gradients.
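Note on the retained logic above: steps.py still scales the accumulated loss by `mega_batch_factor` before calling `scaler.scale(total_loss).backward()`. The standalone sketch below illustrates that gradient-accumulation-plus-AMP pattern in isolation; the network, optimizer, and the value of `mega_batch_factor` are placeholders and are not part of this patch.

    # Illustrative sketch only (not part of the patch): gradient accumulation with
    # torch.cuda.amp, mirroring the loss scaling kept in steps.py above.
    import torch
    from torch.cuda.amp import GradScaler, autocast

    net = torch.nn.Linear(8, 1).cuda()               # placeholder network
    optimizer = torch.optim.Adam(net.parameters())   # placeholder optimizer
    scaler = GradScaler()
    mega_batch_factor = 2                            # micro-batches accumulated per optimizer step

    optimizer.zero_grad()
    for micro_batch in torch.randn(mega_batch_factor, 4, 8).cuda():
        with autocast():
            loss = net(micro_batch).mean()
        # Scale the loss down by the accumulation factor so the summed gradients
        # match what a single full-sized batch would produce.
        loss = loss / mega_batch_factor
        scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()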
diff --git a/codes/multi_modal_train.py b/codes/multi_modal_train.py
index 23d1c379..c82787fa 100644
--- a/codes/multi_modal_train.py
+++ b/codes/multi_modal_train.py
@@ -15,15 +15,29 @@ import yaml
 import train
 import utils.options as option
 from utils.util import OrderedYaml
+import torch

 def main(master_opt, launcher):
     trainers = []
     all_networks = {}
     shared_networks = []
+    if launcher != 'none':
+        train.init_dist('nccl')
     for i, sub_opt in enumerate(master_opt['trainer_options']):
         sub_opt_parsed = option.parse(sub_opt, is_train=True)
         trainer = train.Trainer()
+
+        #### distributed training settings
+        if launcher == 'none':  # disabled distributed training
+            sub_opt_parsed['dist'] = False
+            trainer.rank = -1
+            print('Disabled distributed training.')
+        else:
+            sub_opt_parsed['dist'] = True
+            trainer.world_size = torch.distributed.get_world_size()
+            trainer.rank = torch.distributed.get_rank()
+
         trainer.init(sub_opt_parsed, launcher, all_networks)
         train_gen = trainer.create_training_generator(i)
         model = next(train_gen)
@@ -44,6 +58,7 @@ if __name__ == '__main__':
     parser = argparse.ArgumentParser()
     parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_exd_imgset_chained_structured_trans_invariance.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
+    parser.add_argument('--local_rank', type=int, default=0)
     args = parser.parse_args()

     Loader, Dumper = OrderedYaml()
diff --git a/codes/train.py b/codes/train.py
index e4c9e884..852621c3 100644
--- a/codes/train.py
+++ b/codes/train.py
@@ -13,43 +13,26 @@ from data import create_dataloader, create_dataset
 from models.ExtensibleTrainer import ExtensibleTrainer
 from time import time

-class Trainer:
-    def init_dist(self, backend, **kwargs):
-        # These packages have globals that screw with Windows, so only import them if needed.
-        import torch.distributed as dist
-        import torch.multiprocessing as mp
+def init_dist(backend, **kwargs):
+    # These packages have globals that screw with Windows, so only import them if needed.
+    import torch.distributed as dist
+    import torch.multiprocessing as mp

-        """initialization for distributed training"""
-        if mp.get_start_method(allow_none=True) != 'spawn':
-            mp.set_start_method('spawn')
-        self.rank = int(os.environ['RANK'])
-        num_gpus = torch.cuda.device_count()
-        torch.cuda.set_device(self.rank % num_gpus)
-        dist.init_process_group(backend=backend, **kwargs)
+    """initialization for distributed training"""
+    if mp.get_start_method(allow_none=True) != 'spawn':
+        mp.set_start_method('spawn')
+    rank = int(os.environ['RANK'])
+    num_gpus = torch.cuda.device_count()
+    torch.cuda.set_device(rank % num_gpus)
+    dist.init_process_group(backend=backend, **kwargs)

+
+class Trainer:
     def init(self, opt, launcher, all_networks={}):
         self._profile = False
         self.val_compute_psnr = opt['eval']['compute_psnr'] if 'compute_psnr' in opt['eval'] else True
         self.val_compute_fea = opt['eval']['compute_fea'] if 'compute_fea' in opt['eval'] else True

-        #### distributed training settings
-        if len(opt['gpu_ids']) == 1 and torch.cuda.device_count() > 1:
-            gpu = input(
-                'I noticed you have multiple GPUs. Starting two jobs on the same GPU sucks. Please confirm which GPU'
-                'you want to use. Press enter to use the specified one [%s]' % (opt['gpu_ids']))
-            if gpu:
-                opt['gpu_ids'] = [int(gpu)]
-        if launcher == 'none':  # disabled distributed training
-            opt['dist'] = False
-            self.rank = -1
-            print('Disabled distributed training.')
-
-        else:
-            opt['dist'] = True
-            self.init_dist('nccl')
-            world_size = torch.distributed.get_world_size()
-            self.rank = torch.distributed.get_rank()
-
         #### loading resume state if exists
         if opt['path'].get('resume_state', None):
             # distributed resuming: all load into default GPU
@@ -117,7 +100,7 @@ class Trainer:
                 total_iters = int(opt['train']['niter'])
                 self.total_epochs = int(math.ceil(total_iters / train_size))
                 if opt['dist']:
-                    self.train_sampler = DistIterSampler(self.train_set, world_size, self.rank, dataset_ratio)
+                    self.train_sampler = DistIterSampler(self.train_set, self.world_size, self.rank, dataset_ratio)
                     self.total_epochs = int(math.ceil(total_iters / (train_size * dataset_ratio)))
                 else:
                     self.train_sampler = None
@@ -288,5 +271,18 @@ if __name__ == '__main__':
     args = parser.parse_args()
     opt = option.parse(args.opt, is_train=True)
     trainer = Trainer()
+
+    #### distributed training settings
+    if args.launcher == 'none':  # disabled distributed training
+        opt['dist'] = False
+        trainer.rank = -1
+        print('Disabled distributed training.')
+
+    else:
+        opt['dist'] = True
+        init_dist('nccl')
+        trainer.world_size = torch.distributed.get_world_size()
+        trainer.rank = torch.distributed.get_rank()
+
     trainer.init(opt, args.launcher)
     trainer.do_training()
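For reference on driving the reworked entry points: the module-level `init_dist('nccl')` relies on torch.distributed's default env:// rendezvous, so each process must see RANK, WORLD_SIZE, MASTER_ADDR and MASTER_PORT in its environment (PyTorch's distributed launcher sets these and also passes `--local_rank`, which is why that argument was added to the parsers). The sketch below spells out that contract for a single process; the address, port, and other values are illustrative, not defined by this patch.

    # Minimal single-process sketch (illustrative values): the environment contract
    # that train.init_dist('nccl') depends on. A launcher such as
    # `python -m torch.distributed.launch` normally sets these per process.
    import os
    import torch
    import train  # codes/train.py from this patch

    os.environ.setdefault('MASTER_ADDR', '127.0.0.1')  # rendezvous address (illustrative)
    os.environ.setdefault('MASTER_PORT', '29500')      # rendezvous port (illustrative)
    os.environ.setdefault('RANK', '0')                 # global rank of this process
    os.environ.setdefault('WORLD_SIZE', '1')           # total number of processes

    train.init_dist('nccl')  # picks the CUDA device from RANK and calls init_process_group
    print(torch.distributed.get_rank(), torch.distributed.get_world_size())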