Merge remote-tracking branch 'origin/gan_lab' into gan_lab

This commit is contained in:
James Betker 2020-10-08 16:12:05 -06:00
commit 3cc56cd00b
7 changed files with 28 additions and 14 deletions

View File

@ -34,7 +34,7 @@ class BaseUnsupervisedImageDataset(data.Dataset):
for c in chunks: for c in chunks:
c.reload(opt) c.reload(opt)
else: else:
chunks = [ChunkWithReference(opt, d) for d in os.scandir(path) if d.is_dir()] chunks = [ChunkWithReference(opt, d) for d in sorted(os.scandir(path), key=lambda e: e.name) if d.is_dir()]
# Prune out chunks that have no images # Prune out chunks that have no images
res = [] res = []
for c in chunks: for c in chunks:

View File

@ -23,6 +23,7 @@ class MultiFrameDataset(BaseUnsupervisedImageDataset):
frames_needed -= 1 frames_needed -= 1
search_idx -= 1 search_idx -= 1
else: else:
search_idx += 1
break break
# Now build num_frames starting from search_idx. # Now build num_frames starting from search_idx.
@ -62,7 +63,7 @@ class MultiFrameDataset(BaseUnsupervisedImageDataset):
if __name__ == '__main__': if __name__ == '__main__':
opt = { opt = {
'name': 'amalgam', 'name': 'amalgam',
'paths': ['F:\\4k6k\\datasets\\ns_images\\vixen\\full_video_256_tiled_with_ref'], 'paths': ['/content/fullvideo_256_tiled_test'],
'weights': [1], 'weights': [1],
'target_size': 128, 'target_size': 128,
'force_multiple': 32, 'force_multiple': 32,
@ -77,13 +78,14 @@ if __name__ == '__main__':
ds = MultiFrameDataset(opt) ds = MultiFrameDataset(opt)
import os import os
os.makedirs("debug", exist_ok=True) os.makedirs("debug", exist_ok=True)
for i in range(100000, len(ds)): for i in [3]:
import random import random
o = ds[random.randint(0, 1000000)] o = ds[i]
k = 'GT' k = 'GT'
v = o[k] v = o[k]
if 'path' not in k and 'center' not in k: if 'path' not in k and 'center' not in k:
fr, f, h, w = v.shape fr, f, h, w = v.shape
for j in range(fr): for j in range(fr):
import torchvision import torchvision
torchvision.utils.save_image(v[j].unsqueeze(0), "debug/%i_%s_%i.png" % (i, k, j)) base=osp.basename(o["GT_path"])
torchvision.utils.save_image(v[j].unsqueeze(0), "debug/%i_%s_%i__%s.png" % (i, k, j, base))

View File

@ -3,7 +3,8 @@ import os
import torch import torch
from apex import amp from apex import amp
from torch.nn.parallel import DataParallel, DistributedDataParallel from apex.parallel import DistributedDataParallel
from torch.nn.parallel import DataParallel
import torch.nn as nn import torch.nn as nn
import models.lr_scheduler as lr_scheduler import models.lr_scheduler as lr_scheduler
@ -107,9 +108,7 @@ class ExtensibleTrainer(BaseModel):
dnets = [] dnets = []
for anet in amp_nets: for anet in amp_nets:
if opt['dist']: if opt['dist']:
dnet = DistributedDataParallel(anet, dnet = DistributedDataParallel(anet, delay_allreduce=True)
device_ids=[torch.cuda.current_device()],
find_unused_parameters=False)
else: else:
dnet = DataParallel(anet) dnet = DataParallel(anet)
if self.is_train: if self.is_train:
@ -299,6 +298,8 @@ class ExtensibleTrainer(BaseModel):
def load(self): def load(self):
for netdict in [self.netsG, self.netsD]: for netdict in [self.netsG, self.netsD]:
for name, net in netdict.items(): for name, net in netdict.items():
if not self.opt['networks'][name]['trainable']:
continue
load_path = self.opt['path']['pretrain_model_%s' % (name,)] load_path = self.opt['path']['pretrain_model_%s' % (name,)]
if load_path is not None: if load_path is not None:
if self.rank <= 0: if self.rank <= 0:

View File

@ -476,6 +476,10 @@ class StackedSwitchGenerator5Layer(nn.Module):
def update_for_step(self, step, experiments_path='.'): def update_for_step(self, step, experiments_path='.'):
if self.attentions: if self.attentions:
# All-reduce the attention norm.
for sw in self.switches:
sw.switch.reduce_norm_params()
temp = max(1, 1 + self.init_temperature * temp = max(1, 1 + self.init_temperature *
(self.final_temperature_step - step) / self.final_temperature_step) (self.final_temperature_step - step) / self.final_temperature_step)
self.set_temperature(temp) self.set_temperature(temp)

View File

@ -2,7 +2,7 @@ import os
from collections import OrderedDict from collections import OrderedDict
import torch import torch
import torch.nn as nn import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel from apex.parallel import DistributedDataParallel
import utils.util import utils.util
from apex import amp from apex import amp

View File

@ -7,6 +7,7 @@ import torch.nn.functional as F
import os import os
import os.path as osp import os.path as osp
import torchvision import torchvision
import torch.distributed as dist
def create_teco_loss(opt, env): def create_teco_loss(opt, env):
type = opt['type'] type = opt['type']
@ -123,6 +124,8 @@ class RecurrentImageGeneratorSequenceInjector(Injector):
return {self.output: results} return {self.output: results}
def produce_teco_visual_debugs(self, gen_input, it): def produce_teco_visual_debugs(self, gen_input, it):
if dist.get_rank() > 0:
return
base_path = osp.join(self.env['base_path'], "..", "visual_dbg", "teco_geninput", str(self.env['step'])) base_path = osp.join(self.env['base_path'], "..", "visual_dbg", "teco_geninput", str(self.env['step']))
os.makedirs(base_path, exist_ok=True) os.makedirs(base_path, exist_ok=True)
torchvision.utils.save_image(gen_input[:, :3], osp.join(base_path, "%s_img.png" % (it,))) torchvision.utils.save_image(gen_input[:, :3], osp.join(base_path, "%s_img.png" % (it,)))
@ -192,6 +195,8 @@ class TecoGanLoss(ConfigurableLoss):
return l_total return l_total
def produce_teco_visual_debugs(self, sext, lbl, it): def produce_teco_visual_debugs(self, sext, lbl, it):
if dist.get_rank() > 0:
return
base_path = osp.join(self.env['base_path'], "..", "visual_dbg", "teco_sext", str(self.env['step']), lbl) base_path = osp.join(self.env['base_path'], "..", "visual_dbg", "teco_sext", str(self.env['step']), lbl)
os.makedirs(base_path, exist_ok=True) os.makedirs(base_path, exist_ok=True)
lbls = ['img_a', 'img_b', 'img_c', 'flow_a', 'flow_b', 'flow_c'] lbls = ['img_a', 'img_b', 'img_c', 'flow_a', 'flow_b', 'flow_c']
@ -220,6 +225,8 @@ class PingPongLoss(ConfigurableLoss):
return l_total return l_total
def produce_teco_visual_debugs(self, imglist): def produce_teco_visual_debugs(self, imglist):
if dist.get_rank() > 0:
return
base_path = osp.join(self.env['base_path'], "..", "visual_dbg", "teco_pingpong", str(self.env['step'])) base_path = osp.join(self.env['base_path'], "..", "visual_dbg", "teco_pingpong", str(self.env['step']))
os.makedirs(base_path, exist_ok=True) os.makedirs(base_path, exist_ok=True)
assert isinstance(imglist, list) assert isinstance(imglist, list)

@ -1 +1 @@
Subproject commit 004dda04e39e91c109fdec87b8fb9524f653f6d6 Subproject commit a8c13a86ef22c5bd4e793e164fc5ebfceaad4b4b