From fba29d7dcc56194eda39fbe622031f58c22eb345 Mon Sep 17 00:00:00 2001 From: James Betker Date: Thu, 8 Oct 2020 11:20:05 -0600 Subject: [PATCH 1/4] Move to apex distributeddataparallel and add switch all_reduce Torch's distributed_data_parallel is missing "delay_allreduce", which is necessary to get gradient checkpointing to work with recurrent models. --- codes/models/ExtensibleTrainer.py | 9 ++++----- codes/models/archs/StructuredSwitchedGenerator.py | 6 +++++- codes/models/base_model.py | 2 +- 3 files changed, 10 insertions(+), 7 deletions(-) diff --git a/codes/models/ExtensibleTrainer.py b/codes/models/ExtensibleTrainer.py index 4f2b7636..079da6e1 100644 --- a/codes/models/ExtensibleTrainer.py +++ b/codes/models/ExtensibleTrainer.py @@ -3,7 +3,8 @@ import os import torch from apex import amp -from torch.nn.parallel import DataParallel, DistributedDataParallel +from apex.parallel import DistributedDataParallel +from torch.nn.parallel import DataParallel import torch.nn as nn import models.lr_scheduler as lr_scheduler @@ -107,9 +108,7 @@ class ExtensibleTrainer(BaseModel): dnets = [] for anet in amp_nets: if opt['dist']: - dnet = DistributedDataParallel(anet, - device_ids=[torch.cuda.current_device()], - find_unused_parameters=False) + dnet = DistributedDataParallel(anet, delay_allreduce=True) else: dnet = DataParallel(anet) if self.is_train: @@ -313,4 +312,4 @@ class ExtensibleTrainer(BaseModel): def force_restore_swapout(self): # Legacy method. Do nothing. - pass \ No newline at end of file + pass diff --git a/codes/models/archs/StructuredSwitchedGenerator.py b/codes/models/archs/StructuredSwitchedGenerator.py index b5b37500..8c74b313 100644 --- a/codes/models/archs/StructuredSwitchedGenerator.py +++ b/codes/models/archs/StructuredSwitchedGenerator.py @@ -476,6 +476,10 @@ class StackedSwitchGenerator5Layer(nn.Module): def update_for_step(self, step, experiments_path='.'): if self.attentions: + # All-reduce the attention norm. + for sw in self.switches: + sw.switch.reduce_norm_params() + temp = max(1, 1 + self.init_temperature * (self.final_temperature_step - step) / self.final_temperature_step) self.set_temperature(temp) @@ -496,4 +500,4 @@ class StackedSwitchGenerator5Layer(nn.Module): for i in range(len(means)): val["switch_%i_specificity" % (i,)] = means[i] val["switch_%i_histogram" % (i,)] = hists[i] - return val \ No newline at end of file + return val diff --git a/codes/models/base_model.py b/codes/models/base_model.py index 0b092520..2672bfa5 100644 --- a/codes/models/base_model.py +++ b/codes/models/base_model.py @@ -2,7 +2,7 @@ import os from collections import OrderedDict import torch import torch.nn as nn -from torch.nn.parallel import DistributedDataParallel +from apex.parallel import DistributedDataParallel import utils.util from apex import amp From b36ba0460c9fb23ee81085ec0659c0c9f78d4257 Mon Sep 17 00:00:00 2001 From: James Betker Date: Thu, 8 Oct 2020 12:21:04 -0600 Subject: [PATCH 2/4] Fix multi-frame dataset OBO error --- codes/data/base_unsupervised_image_dataset.py | 2 +- codes/data/multi_frame_dataset.py | 10 ++++++---- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/codes/data/base_unsupervised_image_dataset.py b/codes/data/base_unsupervised_image_dataset.py index 24a2c54d..1cf3db13 100644 --- a/codes/data/base_unsupervised_image_dataset.py +++ b/codes/data/base_unsupervised_image_dataset.py @@ -34,7 +34,7 @@ class BaseUnsupervisedImageDataset(data.Dataset): for c in chunks: c.reload(opt) else: - chunks = [ChunkWithReference(opt, d) for d in os.scandir(path) if d.is_dir()] + chunks = [ChunkWithReference(opt, d) for d in sorted(os.scandir(path), key=lambda e: e.name) if d.is_dir()] # Prune out chunks that have no images res = [] for c in chunks: diff --git a/codes/data/multi_frame_dataset.py b/codes/data/multi_frame_dataset.py index fb37ae3d..17cc43fb 100644 --- a/codes/data/multi_frame_dataset.py +++ b/codes/data/multi_frame_dataset.py @@ -23,6 +23,7 @@ class MultiFrameDataset(BaseUnsupervisedImageDataset): frames_needed -= 1 search_idx -= 1 else: + search_idx += 1 break # Now build num_frames starting from search_idx. @@ -62,7 +63,7 @@ class MultiFrameDataset(BaseUnsupervisedImageDataset): if __name__ == '__main__': opt = { 'name': 'amalgam', - 'paths': ['F:\\4k6k\\datasets\\ns_images\\vixen\\full_video_256_tiled_with_ref'], + 'paths': ['/content/fullvideo_256_tiled_test'], 'weights': [1], 'target_size': 128, 'force_multiple': 32, @@ -77,13 +78,14 @@ if __name__ == '__main__': ds = MultiFrameDataset(opt) import os os.makedirs("debug", exist_ok=True) - for i in range(100000, len(ds)): + for i in [3]: import random - o = ds[random.randint(0, 1000000)] + o = ds[i] k = 'GT' v = o[k] if 'path' not in k and 'center' not in k: fr, f, h, w = v.shape for j in range(fr): import torchvision - torchvision.utils.save_image(v[j].unsqueeze(0), "debug/%i_%s_%i.png" % (i, k, j)) \ No newline at end of file + base=osp.basename(o["GT_path"]) + torchvision.utils.save_image(v[j].unsqueeze(0), "debug/%i_%s_%i__%s.png" % (i, k, j, base)) From 1eb516d686e4fd72df3bba05ee4f28f620fe6f56 Mon Sep 17 00:00:00 2001 From: James Betker Date: Thu, 8 Oct 2020 14:32:45 -0600 Subject: [PATCH 3/4] Fix more distributed bugs --- codes/models/ExtensibleTrainer.py | 2 ++ codes/models/steps/tecogan_losses.py | 9 ++++++++- 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/codes/models/ExtensibleTrainer.py b/codes/models/ExtensibleTrainer.py index 079da6e1..88e34671 100644 --- a/codes/models/ExtensibleTrainer.py +++ b/codes/models/ExtensibleTrainer.py @@ -298,6 +298,8 @@ class ExtensibleTrainer(BaseModel): def load(self): for netdict in [self.netsG, self.netsD]: for name, net in netdict.items(): + if not self.opt['networks'][name]['trainable']: + continue load_path = self.opt['path']['pretrain_model_%s' % (name,)] if load_path is not None: if self.rank <= 0: diff --git a/codes/models/steps/tecogan_losses.py b/codes/models/steps/tecogan_losses.py index ea8b1017..1c033c60 100644 --- a/codes/models/steps/tecogan_losses.py +++ b/codes/models/steps/tecogan_losses.py @@ -7,6 +7,7 @@ import torch.nn.functional as F import os import os.path as osp import torchvision +import torch.distributed as dist def create_teco_loss(opt, env): type = opt['type'] @@ -123,6 +124,8 @@ class RecurrentImageGeneratorSequenceInjector(Injector): return {self.output: results} def produce_teco_visual_debugs(self, gen_input, it): + if dist.get_rank() > 0: + return base_path = osp.join(self.env['base_path'], "..", "visual_dbg", "teco_geninput", str(self.env['step'])) os.makedirs(base_path, exist_ok=True) torchvision.utils.save_image(gen_input[:, :3], osp.join(base_path, "%s_img.png" % (it,))) @@ -192,6 +195,8 @@ class TecoGanLoss(ConfigurableLoss): return l_total def produce_teco_visual_debugs(self, sext, lbl, it): + if dist.get_rank() > 0: + return base_path = osp.join(self.env['base_path'], "..", "visual_dbg", "teco_sext", str(self.env['step']), lbl) os.makedirs(base_path, exist_ok=True) lbls = ['img_a', 'img_b', 'img_c', 'flow_a', 'flow_b', 'flow_c'] @@ -220,8 +225,10 @@ class PingPongLoss(ConfigurableLoss): return l_total def produce_teco_visual_debugs(self, imglist): + if dist.get_rank() > 0: + return base_path = osp.join(self.env['base_path'], "..", "visual_dbg", "teco_pingpong", str(self.env['step'])) os.makedirs(base_path, exist_ok=True) assert isinstance(imglist, list) for i, img in enumerate(imglist): - torchvision.utils.save_image(img, osp.join(base_path, "%s.png" % (i, ))) \ No newline at end of file + torchvision.utils.save_image(img, osp.join(base_path, "%s.png" % (i, ))) From 856ef4d21de07dfa68aa0da07d5b1565b2e7e8f4 Mon Sep 17 00:00:00 2001 From: James Betker Date: Thu, 8 Oct 2020 16:10:23 -0600 Subject: [PATCH 4/4] Update switched_conv --- codes/switched_conv | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/codes/switched_conv b/codes/switched_conv index 004dda04..a8c13a86 160000 --- a/codes/switched_conv +++ b/codes/switched_conv @@ -1 +1 @@ -Subproject commit 004dda04e39e91c109fdec87b8fb9524f653f6d6 +Subproject commit a8c13a86ef22c5bd4e793e164fc5ebfceaad4b4b