Merge remote-tracking branch 'origin/gan_lab' into gan_lab

James Betker 2020-10-08 16:12:05 -06:00
commit 3cc56cd00b
7 changed files with 28 additions and 14 deletions

View File

@@ -34,7 +34,7 @@ class BaseUnsupervisedImageDataset(data.Dataset):
                 for c in chunks:
                     c.reload(opt)
             else:
-                chunks = [ChunkWithReference(opt, d) for d in os.scandir(path) if d.is_dir()]
+                chunks = [ChunkWithReference(opt, d) for d in sorted(os.scandir(path), key=lambda e: e.name) if d.is_dir()]
             # Prune out chunks that have no images
             res = []
             for c in chunks:
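
The hunk above swaps the bare os.scandir() listing for one sorted by entry name, so chunk order (and therefore every index-based lookup into the dataset) is identical across runs and filesystems; os.scandir() itself makes no ordering guarantee. A minimal sketch of the same pattern, using a hypothetical directory path rather than anything from the repo:

import os

root = "/tmp/chunk_root"  # hypothetical dataset root, not a path from this project
os.makedirs(root, exist_ok=True)
# os.scandir() yields entries in arbitrary, filesystem-dependent order;
# sorting by name makes the derived chunk indices reproducible everywhere.
chunk_dirs = [d for d in sorted(os.scandir(root), key=lambda e: e.name) if d.is_dir()]
print([d.name for d in chunk_dirs])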

View File

@@ -23,6 +23,7 @@ class MultiFrameDataset(BaseUnsupervisedImageDataset):
                 frames_needed -= 1
                 search_idx -= 1
             else:
+                search_idx += 1
                 break
         # Now build num_frames starting from search_idx.
@@ -62,7 +63,7 @@ class MultiFrameDataset(BaseUnsupervisedImageDataset):
 if __name__ == '__main__':
     opt = {
         'name': 'amalgam',
-        'paths': ['F:\\4k6k\\datasets\\ns_images\\vixen\\full_video_256_tiled_with_ref'],
+        'paths': ['/content/fullvideo_256_tiled_test'],
         'weights': [1],
         'target_size': 128,
         'force_multiple': 32,
@@ -77,13 +78,14 @@ if __name__ == '__main__':
     ds = MultiFrameDataset(opt)
     import os
     os.makedirs("debug", exist_ok=True)
-    for i in range(100000, len(ds)):
+    for i in [3]:
         import random
-        o = ds[random.randint(0, 1000000)]
+        o = ds[i]
         k = 'GT'
         v = o[k]
         if 'path' not in k and 'center' not in k:
             fr, f, h, w = v.shape
             for j in range(fr):
                 import torchvision
-                torchvision.utils.save_image(v[j].unsqueeze(0), "debug/%i_%s_%i.png" % (i, k, j))
+                base=osp.basename(o["GT_path"])
+                torchvision.utils.save_image(v[j].unsqueeze(0), "debug/%i_%s_%i__%s.png" % (i, k, j, base))
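
The debug block above now dumps a single fixed sample and appends the source tile name (osp is os.path, imported at module level in the real script) to each frame's filename. A self-contained sketch of that dump pattern, using a random tensor in place of a real ds[i]['GT'] sample and a made-up GT_path:

import os
import os.path as osp
import torch
import torchvision

os.makedirs("debug", exist_ok=True)
v = torch.rand(4, 3, 128, 128)  # stand-in for o['GT']: (frames, channels, height, width)
base = osp.basename("/content/fullvideo_256_tiled_test/000001.jpg")  # hypothetical GT_path
i, k = 3, 'GT'
for j in range(v.shape[0]):
    torchvision.utils.save_image(v[j].unsqueeze(0), "debug/%i_%s_%i__%s.png" % (i, k, j, base))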

View File

@@ -3,7 +3,8 @@ import os
 import torch
 from apex import amp
-from torch.nn.parallel import DataParallel, DistributedDataParallel
+from apex.parallel import DistributedDataParallel
+from torch.nn.parallel import DataParallel
 import torch.nn as nn
 import models.lr_scheduler as lr_scheduler
@@ -107,9 +108,7 @@ class ExtensibleTrainer(BaseModel):
             dnets = []
             for anet in amp_nets:
                 if opt['dist']:
-                    dnet = DistributedDataParallel(anet,
-                                                   device_ids=[torch.cuda.current_device()],
-                                                   find_unused_parameters=False)
+                    dnet = DistributedDataParallel(anet, delay_allreduce=True)
                 else:
                     dnet = DataParallel(anet)
                 if self.is_train:
@@ -299,6 +298,8 @@ class ExtensibleTrainer(BaseModel):
     def load(self):
         for netdict in [self.netsG, self.netsD]:
             for name, net in netdict.items():
+                if not self.opt['networks'][name]['trainable']:
+                    continue
                 load_path = self.opt['path']['pretrain_model_%s' % (name,)]
                 if load_path is not None:
                     if self.rank <= 0:
@@ -313,4 +314,4 @@ class ExtensibleTrainer(BaseModel):
     def force_restore_swapout(self):
         # Legacy method. Do nothing.
-        pass
+        pass
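
The wrapping change above moves distributed training from torch.nn.parallel.DistributedDataParallel to apex.parallel.DistributedDataParallel with delay_allreduce=True, which defers the gradient all-reduce until the whole backward pass has finished instead of overlapping bucketed reductions with it. A rough sketch of the two wrappers side by side, assuming apex is installed, a CUDA device is available, and torch.distributed.init_process_group() has already been called; net is just a stand-in module:

import torch
import torch.nn as nn

net = nn.Linear(8, 8).cuda()  # stand-in for one of the amp-wrapped networks
use_apex = True               # assumption: apex is installed, as the trainer requires

if use_apex:
    from apex.parallel import DistributedDataParallel
    # apex DDP infers the device from the module; delay_allreduce=True waits for
    # the full backward pass before reducing gradients across ranks.
    dnet = DistributedDataParallel(net, delay_allreduce=True)
else:
    from torch.nn.parallel import DistributedDataParallel
    dnet = DistributedDataParallel(net,
                                   device_ids=[torch.cuda.current_device()],
                                   find_unused_parameters=False)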

View File

@@ -476,6 +476,10 @@ class StackedSwitchGenerator5Layer(nn.Module):
     def update_for_step(self, step, experiments_path='.'):
         if self.attentions:
+            # All-reduce the attention norm.
+            for sw in self.switches:
+                sw.switch.reduce_norm_params()
             temp = max(1, 1 + self.init_temperature *
                        (self.final_temperature_step - step) / self.final_temperature_step)
             self.set_temperature(temp)
@@ -496,4 +500,4 @@ class StackedSwitchGenerator5Layer(nn.Module):
         for i in range(len(means)):
             val["switch_%i_specificity" % (i,)] = means[i]
             val["switch_%i_histogram" % (i,)] = hists[i]
-        return val
+        return val
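
For context, update_for_step above also shows the switch temperature schedule: it decays linearly from 1 + init_temperature at step 0 down to 1 at final_temperature_step and stays clamped at 1 afterwards. A small sketch of just that schedule, with made-up hyperparameters:

def switch_temperature(step, init_temperature=10, final_temperature_step=50000):
    # Linear decay: 1 + init_temperature at step 0, 1 at final_temperature_step,
    # clamped to 1 for every later step.
    return max(1, 1 + init_temperature *
               (final_temperature_step - step) / final_temperature_step)

for s in (0, 25000, 50000, 100000):
    print(s, switch_temperature(s))  # 11.0, 6.0, 1, 1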

View File

@@ -2,7 +2,7 @@ import os
 from collections import OrderedDict
 import torch
 import torch.nn as nn
-from torch.nn.parallel import DistributedDataParallel
+from apex.parallel import DistributedDataParallel
 import utils.util
 from apex import amp

View File

@@ -7,6 +7,7 @@ import torch.nn.functional as F
 import os
 import os.path as osp
 import torchvision
+import torch.distributed as dist
 def create_teco_loss(opt, env):
     type = opt['type']
@@ -123,6 +124,8 @@ class RecurrentImageGeneratorSequenceInjector(Injector):
         return {self.output: results}
     def produce_teco_visual_debugs(self, gen_input, it):
+        if dist.get_rank() > 0:
+            return
         base_path = osp.join(self.env['base_path'], "..", "visual_dbg", "teco_geninput", str(self.env['step']))
         os.makedirs(base_path, exist_ok=True)
         torchvision.utils.save_image(gen_input[:, :3], osp.join(base_path, "%s_img.png" % (it,)))
@@ -192,6 +195,8 @@ class TecoGanLoss(ConfigurableLoss):
         return l_total
     def produce_teco_visual_debugs(self, sext, lbl, it):
+        if dist.get_rank() > 0:
+            return
         base_path = osp.join(self.env['base_path'], "..", "visual_dbg", "teco_sext", str(self.env['step']), lbl)
         os.makedirs(base_path, exist_ok=True)
         lbls = ['img_a', 'img_b', 'img_c', 'flow_a', 'flow_b', 'flow_c']
@@ -220,8 +225,10 @@ class PingPongLoss(ConfigurableLoss):
         return l_total
     def produce_teco_visual_debugs(self, imglist):
+        if dist.get_rank() > 0:
+            return
         base_path = osp.join(self.env['base_path'], "..", "visual_dbg", "teco_pingpong", str(self.env['step']))
         os.makedirs(base_path, exist_ok=True)
         assert isinstance(imglist, list)
         for i, img in enumerate(imglist):
-            torchvision.utils.save_image(img, osp.join(base_path, "%s.png" % (i, )))
+            torchvision.utils.save_image(img, osp.join(base_path, "%s.png" % (i, )))
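
Each produce_teco_visual_debugs addition above bails out on non-zero ranks so only one process writes the visual-debug images during multi-GPU training. One caveat worth noting: torch.distributed.get_rank() raises a RuntimeError when no process group has been initialized, so a guard like the sketch below (an illustration, not code from this repo) also checks dist.is_initialized() and treats single-process runs as rank 0:

import os
import os.path as osp
import torch.distributed as dist
import torchvision

def is_rank_0():
    # Non-distributed runs count as rank 0; calling dist.get_rank() without an
    # initialized default process group would raise instead of returning a rank.
    return not (dist.is_available() and dist.is_initialized()) or dist.get_rank() == 0

def produce_visual_debugs(imglist, out_dir):
    if not is_rank_0():
        return
    os.makedirs(out_dir, exist_ok=True)
    for i, img in enumerate(imglist):
        torchvision.utils.save_image(img, osp.join(out_dir, "%i.png" % (i,)))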

@@ -1 +1 @@
-Subproject commit 004dda04e39e91c109fdec87b8fb9524f653f6d6
+Subproject commit a8c13a86ef22c5bd4e793e164fc5ebfceaad4b4b