Merge remote-tracking branch 'origin/gan_lab' into gan_lab
This commit is contained in:
commit
3cc56cd00b
|
@ -34,7 +34,7 @@ class BaseUnsupervisedImageDataset(data.Dataset):
|
||||||
for c in chunks:
|
for c in chunks:
|
||||||
c.reload(opt)
|
c.reload(opt)
|
||||||
else:
|
else:
|
||||||
chunks = [ChunkWithReference(opt, d) for d in os.scandir(path) if d.is_dir()]
|
chunks = [ChunkWithReference(opt, d) for d in sorted(os.scandir(path), key=lambda e: e.name) if d.is_dir()]
|
||||||
# Prune out chunks that have no images
|
# Prune out chunks that have no images
|
||||||
res = []
|
res = []
|
||||||
for c in chunks:
|
for c in chunks:
|
||||||
|
|
|
@ -23,6 +23,7 @@ class MultiFrameDataset(BaseUnsupervisedImageDataset):
|
||||||
frames_needed -= 1
|
frames_needed -= 1
|
||||||
search_idx -= 1
|
search_idx -= 1
|
||||||
else:
|
else:
|
||||||
|
search_idx += 1
|
||||||
break
|
break
|
||||||
|
|
||||||
# Now build num_frames starting from search_idx.
|
# Now build num_frames starting from search_idx.
|
||||||
|
@ -62,7 +63,7 @@ class MultiFrameDataset(BaseUnsupervisedImageDataset):
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
opt = {
|
opt = {
|
||||||
'name': 'amalgam',
|
'name': 'amalgam',
|
||||||
'paths': ['F:\\4k6k\\datasets\\ns_images\\vixen\\full_video_256_tiled_with_ref'],
|
'paths': ['/content/fullvideo_256_tiled_test'],
|
||||||
'weights': [1],
|
'weights': [1],
|
||||||
'target_size': 128,
|
'target_size': 128,
|
||||||
'force_multiple': 32,
|
'force_multiple': 32,
|
||||||
|
@ -77,13 +78,14 @@ if __name__ == '__main__':
|
||||||
ds = MultiFrameDataset(opt)
|
ds = MultiFrameDataset(opt)
|
||||||
import os
|
import os
|
||||||
os.makedirs("debug", exist_ok=True)
|
os.makedirs("debug", exist_ok=True)
|
||||||
for i in range(100000, len(ds)):
|
for i in [3]:
|
||||||
import random
|
import random
|
||||||
o = ds[random.randint(0, 1000000)]
|
o = ds[i]
|
||||||
k = 'GT'
|
k = 'GT'
|
||||||
v = o[k]
|
v = o[k]
|
||||||
if 'path' not in k and 'center' not in k:
|
if 'path' not in k and 'center' not in k:
|
||||||
fr, f, h, w = v.shape
|
fr, f, h, w = v.shape
|
||||||
for j in range(fr):
|
for j in range(fr):
|
||||||
import torchvision
|
import torchvision
|
||||||
torchvision.utils.save_image(v[j].unsqueeze(0), "debug/%i_%s_%i.png" % (i, k, j))
|
base=osp.basename(o["GT_path"])
|
||||||
|
torchvision.utils.save_image(v[j].unsqueeze(0), "debug/%i_%s_%i__%s.png" % (i, k, j, base))
|
||||||
|
|
|
@ -3,7 +3,8 @@ import os
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from apex import amp
|
from apex import amp
|
||||||
from torch.nn.parallel import DataParallel, DistributedDataParallel
|
from apex.parallel import DistributedDataParallel
|
||||||
|
from torch.nn.parallel import DataParallel
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
import models.lr_scheduler as lr_scheduler
|
import models.lr_scheduler as lr_scheduler
|
||||||
|
@ -107,9 +108,7 @@ class ExtensibleTrainer(BaseModel):
|
||||||
dnets = []
|
dnets = []
|
||||||
for anet in amp_nets:
|
for anet in amp_nets:
|
||||||
if opt['dist']:
|
if opt['dist']:
|
||||||
dnet = DistributedDataParallel(anet,
|
dnet = DistributedDataParallel(anet, delay_allreduce=True)
|
||||||
device_ids=[torch.cuda.current_device()],
|
|
||||||
find_unused_parameters=False)
|
|
||||||
else:
|
else:
|
||||||
dnet = DataParallel(anet)
|
dnet = DataParallel(anet)
|
||||||
if self.is_train:
|
if self.is_train:
|
||||||
|
@ -299,6 +298,8 @@ class ExtensibleTrainer(BaseModel):
|
||||||
def load(self):
|
def load(self):
|
||||||
for netdict in [self.netsG, self.netsD]:
|
for netdict in [self.netsG, self.netsD]:
|
||||||
for name, net in netdict.items():
|
for name, net in netdict.items():
|
||||||
|
if not self.opt['networks'][name]['trainable']:
|
||||||
|
continue
|
||||||
load_path = self.opt['path']['pretrain_model_%s' % (name,)]
|
load_path = self.opt['path']['pretrain_model_%s' % (name,)]
|
||||||
if load_path is not None:
|
if load_path is not None:
|
||||||
if self.rank <= 0:
|
if self.rank <= 0:
|
||||||
|
|
|
@ -476,6 +476,10 @@ class StackedSwitchGenerator5Layer(nn.Module):
|
||||||
|
|
||||||
def update_for_step(self, step, experiments_path='.'):
|
def update_for_step(self, step, experiments_path='.'):
|
||||||
if self.attentions:
|
if self.attentions:
|
||||||
|
# All-reduce the attention norm.
|
||||||
|
for sw in self.switches:
|
||||||
|
sw.switch.reduce_norm_params()
|
||||||
|
|
||||||
temp = max(1, 1 + self.init_temperature *
|
temp = max(1, 1 + self.init_temperature *
|
||||||
(self.final_temperature_step - step) / self.final_temperature_step)
|
(self.final_temperature_step - step) / self.final_temperature_step)
|
||||||
self.set_temperature(temp)
|
self.set_temperature(temp)
|
||||||
|
|
|
@ -2,7 +2,7 @@ import os
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from torch.nn.parallel import DistributedDataParallel
|
from apex.parallel import DistributedDataParallel
|
||||||
import utils.util
|
import utils.util
|
||||||
from apex import amp
|
from apex import amp
|
||||||
|
|
||||||
|
|
|
@ -7,6 +7,7 @@ import torch.nn.functional as F
|
||||||
import os
|
import os
|
||||||
import os.path as osp
|
import os.path as osp
|
||||||
import torchvision
|
import torchvision
|
||||||
|
import torch.distributed as dist
|
||||||
|
|
||||||
def create_teco_loss(opt, env):
|
def create_teco_loss(opt, env):
|
||||||
type = opt['type']
|
type = opt['type']
|
||||||
|
@ -123,6 +124,8 @@ class RecurrentImageGeneratorSequenceInjector(Injector):
|
||||||
return {self.output: results}
|
return {self.output: results}
|
||||||
|
|
||||||
def produce_teco_visual_debugs(self, gen_input, it):
|
def produce_teco_visual_debugs(self, gen_input, it):
|
||||||
|
if dist.get_rank() > 0:
|
||||||
|
return
|
||||||
base_path = osp.join(self.env['base_path'], "..", "visual_dbg", "teco_geninput", str(self.env['step']))
|
base_path = osp.join(self.env['base_path'], "..", "visual_dbg", "teco_geninput", str(self.env['step']))
|
||||||
os.makedirs(base_path, exist_ok=True)
|
os.makedirs(base_path, exist_ok=True)
|
||||||
torchvision.utils.save_image(gen_input[:, :3], osp.join(base_path, "%s_img.png" % (it,)))
|
torchvision.utils.save_image(gen_input[:, :3], osp.join(base_path, "%s_img.png" % (it,)))
|
||||||
|
@ -192,6 +195,8 @@ class TecoGanLoss(ConfigurableLoss):
|
||||||
return l_total
|
return l_total
|
||||||
|
|
||||||
def produce_teco_visual_debugs(self, sext, lbl, it):
|
def produce_teco_visual_debugs(self, sext, lbl, it):
|
||||||
|
if dist.get_rank() > 0:
|
||||||
|
return
|
||||||
base_path = osp.join(self.env['base_path'], "..", "visual_dbg", "teco_sext", str(self.env['step']), lbl)
|
base_path = osp.join(self.env['base_path'], "..", "visual_dbg", "teco_sext", str(self.env['step']), lbl)
|
||||||
os.makedirs(base_path, exist_ok=True)
|
os.makedirs(base_path, exist_ok=True)
|
||||||
lbls = ['img_a', 'img_b', 'img_c', 'flow_a', 'flow_b', 'flow_c']
|
lbls = ['img_a', 'img_b', 'img_c', 'flow_a', 'flow_b', 'flow_c']
|
||||||
|
@ -220,6 +225,8 @@ class PingPongLoss(ConfigurableLoss):
|
||||||
return l_total
|
return l_total
|
||||||
|
|
||||||
def produce_teco_visual_debugs(self, imglist):
|
def produce_teco_visual_debugs(self, imglist):
|
||||||
|
if dist.get_rank() > 0:
|
||||||
|
return
|
||||||
base_path = osp.join(self.env['base_path'], "..", "visual_dbg", "teco_pingpong", str(self.env['step']))
|
base_path = osp.join(self.env['base_path'], "..", "visual_dbg", "teco_pingpong", str(self.env['step']))
|
||||||
os.makedirs(base_path, exist_ok=True)
|
os.makedirs(base_path, exist_ok=True)
|
||||||
assert isinstance(imglist, list)
|
assert isinstance(imglist, list)
|
||||||
|
|
|
@ -1 +1 @@
|
||||||
Subproject commit 004dda04e39e91c109fdec87b8fb9524f653f6d6
|
Subproject commit a8c13a86ef22c5bd4e793e164fc5ebfceaad4b4b
|
Loading…
Reference in New Issue
Block a user