Merge remote-tracking branch 'origin/gan_lab' into gan_lab

commit 3cc56cd00b
@@ -34,7 +34,7 @@ class BaseUnsupervisedImageDataset(data.Dataset):
                 for c in chunks:
                     c.reload(opt)
             else:
-                chunks = [ChunkWithReference(opt, d) for d in os.scandir(path) if d.is_dir()]
+                chunks = [ChunkWithReference(opt, d) for d in sorted(os.scandir(path), key=lambda e: e.name) if d.is_dir()]
             # Prune out chunks that have no images
             res = []
             for c in chunks:
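
Note on the hunk above: os.scandir() returns directory entries in an arbitrary, filesystem-dependent order, so sorting by entry name makes the chunk list deterministic across runs and machines. A minimal sketch of that pattern (the helper name is illustrative, not code from this repo):

    import os

    def list_chunk_dirs(path):
        # sorted() pins the order; plain os.scandir() can differ per filesystem,
        # which would shuffle dataset indices between machines or restarts.
        return [d.path for d in sorted(os.scandir(path), key=lambda e: e.name) if d.is_dir()]
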
@@ -23,6 +23,7 @@ class MultiFrameDataset(BaseUnsupervisedImageDataset):
                 frames_needed -= 1
                 search_idx -= 1
             else:
+                search_idx += 1
                 break

         # Now build num_frames starting from search_idx.
@@ -62,7 +63,7 @@ class MultiFrameDataset(BaseUnsupervisedImageDataset):
 if __name__ == '__main__':
     opt = {
         'name': 'amalgam',
-        'paths': ['F:\\4k6k\\datasets\\ns_images\\vixen\\full_video_256_tiled_with_ref'],
+        'paths': ['/content/fullvideo_256_tiled_test'],
         'weights': [1],
         'target_size': 128,
         'force_multiple': 32,
@@ -77,13 +78,14 @@ if __name__ == '__main__':
     ds = MultiFrameDataset(opt)
     import os
     os.makedirs("debug", exist_ok=True)
-    for i in range(100000, len(ds)):
+    for i in [3]:
         import random
-        o = ds[random.randint(0, 1000000)]
+        o = ds[i]
         k = 'GT'
         v = o[k]
         if 'path' not in k and 'center' not in k:
             fr, f, h, w = v.shape
             for j in range(fr):
                 import torchvision
-                torchvision.utils.save_image(v[j].unsqueeze(0), "debug/%i_%s_%i.png" % (i, k, j))
+                base=osp.basename(o["GT_path"])
+                torchvision.utils.save_image(v[j].unsqueeze(0), "debug/%i_%s_%i__%s.png" % (i, k, j, base))
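
Note on the hunks above: the __main__ block now spot-checks a fixed dataset index and tags each dumped frame with the source file's basename. A hedged sketch of the same debugging loop against any map-style dataset whose items expose a frame tensor under 'GT' and its path under 'GT_path' (those key names follow the diff; everything else is illustrative):

    import os
    import os.path as osp
    import torchvision

    os.makedirs("debug", exist_ok=True)
    for i in [3]:                                  # fixed index -> reproducible spot-check
        sample = ds[i]                             # `ds` is assumed to be a torch Dataset
        frames = sample['GT']                      # shape: (frames, channels, h, w)
        base = osp.basename(sample['GT_path'])
        for j in range(frames.shape[0]):
            torchvision.utils.save_image(frames[j].unsqueeze(0),
                                         "debug/%i_GT_%i__%s.png" % (i, j, base))
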
@@ -3,7 +3,8 @@ import os

 import torch
 from apex import amp
-from torch.nn.parallel import DataParallel, DistributedDataParallel
+from apex.parallel import DistributedDataParallel
+from torch.nn.parallel import DataParallel
 import torch.nn as nn

 import models.lr_scheduler as lr_scheduler
@@ -107,9 +108,7 @@ class ExtensibleTrainer(BaseModel):
             dnets = []
             for anet in amp_nets:
                 if opt['dist']:
-                    dnet = DistributedDataParallel(anet,
-                                                   device_ids=[torch.cuda.current_device()],
-                                                   find_unused_parameters=False)
+                    dnet = DistributedDataParallel(anet, delay_allreduce=True)
                 else:
                     dnet = DataParallel(anet)
                 if self.is_train:
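
Note on the hunk above: apex's DistributedDataParallel infers devices from the wrapped module, so the device_ids/find_unused_parameters arguments of the torch wrapper go away, and delay_allreduce=True defers all gradient all-reduces until backward() completes. A minimal sketch of the wrapping order, assuming NVIDIA apex is installed and torch.distributed is already initialized (the model and optimizer here are placeholders):

    import torch.nn as nn
    import torch.optim as optim
    from apex import amp
    from apex.parallel import DistributedDataParallel

    model = nn.Linear(16, 16).cuda()               # placeholder network
    optimizer = optim.Adam(model.parameters())
    model, optimizer = amp.initialize(model, optimizer, opt_level="O1")
    # delay_allreduce trades comm/compute overlap for tolerance of params
    # that do not receive gradients every iteration.
    model = DistributedDataParallel(model, delay_allreduce=True)
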
@@ -299,6 +298,8 @@ class ExtensibleTrainer(BaseModel):
     def load(self):
         for netdict in [self.netsG, self.netsD]:
             for name, net in netdict.items():
+                if not self.opt['networks'][name]['trainable']:
+                    continue
                 load_path = self.opt['path']['pretrain_model_%s' % (name,)]
                 if load_path is not None:
                     if self.rank <= 0:
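
Note on the hunk above: load() now skips any network whose config entry is marked as not trainable, so frozen or auxiliary networks no longer need a matching pretrain_model_<name> checkpoint path. A condensed sketch of the pattern (the load_network helper is a placeholder for whatever checkpoint loader the trainer actually uses):

    def load(self):
        for netdict in [self.netsG, self.netsD]:
            for name, net in netdict.items():
                if not self.opt['networks'][name]['trainable']:
                    continue  # frozen nets keep their construction-time weights
                load_path = self.opt['path']['pretrain_model_%s' % (name,)]
                if load_path is not None:
                    self.load_network(load_path, net)  # placeholder helper
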
@@ -313,4 +314,4 @@ class ExtensibleTrainer(BaseModel):

     def force_restore_swapout(self):
         # Legacy method. Do nothing.
-        pass
+        pass
@@ -476,6 +476,10 @@ class StackedSwitchGenerator5Layer(nn.Module):

     def update_for_step(self, step, experiments_path='.'):
         if self.attentions:
+            # All-reduce the attention norm.
+            for sw in self.switches:
+                sw.switch.reduce_norm_params()
+
             temp = max(1, 1 + self.init_temperature *
                        (self.final_temperature_step - step) / self.final_temperature_step)
             self.set_temperature(temp)
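
Note on the hunk above: update_for_step() now all-reduces the switch attention norms across ranks before recomputing the softmax temperature, which anneals linearly from roughly 1 + init_temperature at step 0 down to 1 at final_temperature_step. The schedule in isolation (the default values here are illustrative, not taken from the config):

    def switch_temperature(step, init_temperature=10, final_temperature_step=50000):
        # linear decay, clamped so the temperature never drops below 1
        return max(1, 1 + init_temperature *
                   (final_temperature_step - step) / final_temperature_step)
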
@@ -496,4 +500,4 @@ class StackedSwitchGenerator5Layer(nn.Module):
         for i in range(len(means)):
             val["switch_%i_specificity" % (i,)] = means[i]
             val["switch_%i_histogram" % (i,)] = hists[i]
-        return val
+        return val
@@ -2,7 +2,7 @@ import os
 from collections import OrderedDict
 import torch
 import torch.nn as nn
-from torch.nn.parallel import DistributedDataParallel
+from apex.parallel import DistributedDataParallel
 import utils.util
 from apex import amp

@@ -7,6 +7,7 @@ import torch.nn.functional as F
 import os
 import os.path as osp
 import torchvision
+import torch.distributed as dist

 def create_teco_loss(opt, env):
     type = opt['type']
@@ -123,6 +124,8 @@ class RecurrentImageGeneratorSequenceInjector(Injector):
         return {self.output: results}

     def produce_teco_visual_debugs(self, gen_input, it):
+        if dist.get_rank() > 0:
+            return
         base_path = osp.join(self.env['base_path'], "..", "visual_dbg", "teco_geninput", str(self.env['step']))
         os.makedirs(base_path, exist_ok=True)
         torchvision.utils.save_image(gen_input[:, :3], osp.join(base_path, "%s_img.png" % (it,)))
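
Note on the hunks above and below: each visual-debug helper now early-returns on non-zero ranks so that only rank 0 writes PNGs under visual_dbg/. A small sketch of that guard; the dist.is_initialized() check is an extra safeguard for single-process runs and is not part of this diff:

    import torch.distributed as dist

    def should_write_debug_images():
        # rank 0 only; also allow plain single-process (non-distributed) runs
        return (not dist.is_initialized()) or dist.get_rank() == 0
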
@@ -192,6 +195,8 @@ class TecoGanLoss(ConfigurableLoss):
         return l_total

     def produce_teco_visual_debugs(self, sext, lbl, it):
+        if dist.get_rank() > 0:
+            return
         base_path = osp.join(self.env['base_path'], "..", "visual_dbg", "teco_sext", str(self.env['step']), lbl)
         os.makedirs(base_path, exist_ok=True)
         lbls = ['img_a', 'img_b', 'img_c', 'flow_a', 'flow_b', 'flow_c']
@@ -220,8 +225,10 @@ class PingPongLoss(ConfigurableLoss):
         return l_total

     def produce_teco_visual_debugs(self, imglist):
+        if dist.get_rank() > 0:
+            return
         base_path = osp.join(self.env['base_path'], "..", "visual_dbg", "teco_pingpong", str(self.env['step']))
         os.makedirs(base_path, exist_ok=True)
         assert isinstance(imglist, list)
         for i, img in enumerate(imglist):
-            torchvision.utils.save_image(img, osp.join(base_path, "%s.png" % (i, )))
+            torchvision.utils.save_image(img, osp.join(base_path, "%s.png" % (i, )))
@@ -1 +1 @@
-Subproject commit 004dda04e39e91c109fdec87b8fb9524f653f6d6
+Subproject commit a8c13a86ef22c5bd4e793e164fc5ebfceaad4b4b