Support inference across batches, support inference on cpu, checkpoint
This is a checkpoint of a set of long tests with reduced-complexity networks. Some takeaways: 1) A full GAN using the resnet discriminator does appear to converge, but the quality is capped. 2) Likewise, a combination GAN/feature loss does not converge. The feature loss is optimized but the model appears unable to fight the discriminator, so the G-loss steadily increases. Going forwards, I want to try some bigger models. In particular, I want to change the generator to increase complexity and capacity. I also want to add skip connections between the disc and generator.
This commit is contained in:
parent
9c7debe75c
commit
44b89330c2
|
@ -3,6 +3,7 @@
|
||||||
<component name="NewModuleRootManager">
|
<component name="NewModuleRootManager">
|
||||||
<content url="file://$MODULE_DIR$">
|
<content url="file://$MODULE_DIR$">
|
||||||
<sourceFolder url="file://$MODULE_DIR$/codes" isTestSource="false" />
|
<sourceFolder url="file://$MODULE_DIR$/codes" isTestSource="false" />
|
||||||
|
<excludeFolder url="file://$MODULE_DIR$/codes/temp" />
|
||||||
<excludeFolder url="file://$MODULE_DIR$/datasets" />
|
<excludeFolder url="file://$MODULE_DIR$/datasets" />
|
||||||
<excludeFolder url="file://$MODULE_DIR$/experiments" />
|
<excludeFolder url="file://$MODULE_DIR$/experiments" />
|
||||||
<excludeFolder url="file://$MODULE_DIR$/results" />
|
<excludeFolder url="file://$MODULE_DIR$/results" />
|
||||||
|
|
|
@ -21,7 +21,8 @@ def create_dataloader(dataset, dataset_opt, opt=None, sampler=None):
|
||||||
num_workers=num_workers, sampler=sampler, drop_last=True,
|
num_workers=num_workers, sampler=sampler, drop_last=True,
|
||||||
pin_memory=False)
|
pin_memory=False)
|
||||||
else:
|
else:
|
||||||
return torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, num_workers=0,
|
batch_size = dataset_opt['batch_size'] or 1
|
||||||
|
return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=0,
|
||||||
pin_memory=False)
|
pin_memory=False)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -158,10 +158,7 @@ class SRGANModel(BaseModel):
|
||||||
|
|
||||||
self.fake_H = []
|
self.fake_H = []
|
||||||
for var_L, var_H, var_ref, pix in zip(self.var_L, self.var_H, self.var_ref, self.pix):
|
for var_L, var_H, var_ref, pix in zip(self.var_L, self.var_H, self.var_ref, self.pix):
|
||||||
if step > self.D_init_iters:
|
|
||||||
fake_H = self.netG(var_L)
|
fake_H = self.netG(var_L)
|
||||||
else:
|
|
||||||
fake_H = pix
|
|
||||||
self.fake_H.append(fake_H.detach())
|
self.fake_H.append(fake_H.detach())
|
||||||
|
|
||||||
l_g_total = 0
|
l_g_total = 0
|
||||||
|
|
|
@ -26,7 +26,7 @@ class SRModel(BaseModel):
|
||||||
self.netG = networks.define_G(opt).to(self.device)
|
self.netG = networks.define_G(opt).to(self.device)
|
||||||
if opt['dist']:
|
if opt['dist']:
|
||||||
self.netG = DistributedDataParallel(self.netG, device_ids=[torch.cuda.current_device()])
|
self.netG = DistributedDataParallel(self.netG, device_ids=[torch.cuda.current_device()])
|
||||||
else:
|
elif opt['gpu_ids'] is not None:
|
||||||
self.netG = DataParallel(self.netG)
|
self.netG = DataParallel(self.netG)
|
||||||
# print network
|
# print network
|
||||||
self.print_network()
|
self.print_network()
|
||||||
|
|
|
@ -10,6 +10,7 @@ def parse(opt_path, is_train=True):
|
||||||
with open(opt_path, mode='r') as f:
|
with open(opt_path, mode='r') as f:
|
||||||
opt = yaml.load(f, Loader=Loader)
|
opt = yaml.load(f, Loader=Loader)
|
||||||
# export CUDA_VISIBLE_DEVICES
|
# export CUDA_VISIBLE_DEVICES
|
||||||
|
if 'gpu_ids' in opt.keys():
|
||||||
gpu_list = ','.join(str(x) for x in opt['gpu_ids'])
|
gpu_list = ','.join(str(x) for x in opt['gpu_ids'])
|
||||||
os.environ['CUDA_VISIBLE_DEVICES'] = gpu_list
|
os.environ['CUDA_VISIBLE_DEVICES'] = gpu_list
|
||||||
print('export CUDA_VISIBLE_DEVICES=' + gpu_list)
|
print('export CUDA_VISIBLE_DEVICES=' + gpu_list)
|
||||||
|
|
|
@ -4,23 +4,23 @@ model: sr
|
||||||
distortion: sr
|
distortion: sr
|
||||||
scale: 4
|
scale: 4
|
||||||
crop_border: ~ # crop border when evaluation. If None(~), crop the scale pixels
|
crop_border: ~ # crop border when evaluation. If None(~), crop the scale pixels
|
||||||
gpu_ids: [0]
|
#gpu_ids: [0]
|
||||||
|
|
||||||
datasets:
|
datasets:
|
||||||
test_1: # the 1st test dataset
|
test_1: # the 1st test dataset
|
||||||
name: set5
|
name: set5
|
||||||
mode: LQ
|
mode: LQ
|
||||||
dataroot_LQ: ../datasets/upsample_tests
|
batch_size: 1
|
||||||
|
dataroot_LQ: E:\4k6k\datasets\adrianna\full_extract
|
||||||
|
|
||||||
#### network structures
|
#### network structures
|
||||||
network_G:
|
network_G:
|
||||||
which_model_G: RRDBNet
|
which_model_G: RRDBNet
|
||||||
in_nc: 3
|
in_nc: 3
|
||||||
out_nc: 3
|
out_nc: 3
|
||||||
nf: 64
|
nf: 48
|
||||||
nb: 23
|
nb: 23
|
||||||
upscale: 4
|
|
||||||
|
|
||||||
#### path
|
#### path
|
||||||
path:
|
path:
|
||||||
pretrain_model_G: ../experiments/ESRGANx4_blacked_for_adrianna/models/19500_G.pth
|
pretrain_model_G: ../experiments/rrdb_blacked_gan_g.pth
|
|
@ -5,7 +5,7 @@ model: srgan
|
||||||
distortion: sr
|
distortion: sr
|
||||||
scale: 4
|
scale: 4
|
||||||
gpu_ids: [0]
|
gpu_ids: [0]
|
||||||
amp_opt_level: O0
|
amp_opt_level: O1
|
||||||
|
|
||||||
#### datasets
|
#### datasets
|
||||||
datasets:
|
datasets:
|
||||||
|
@ -17,7 +17,7 @@ datasets:
|
||||||
doCrop: false
|
doCrop: false
|
||||||
use_shuffle: true
|
use_shuffle: true
|
||||||
n_workers: 12 # per GPU
|
n_workers: 12 # per GPU
|
||||||
batch_size: 64
|
batch_size: 40
|
||||||
target_size: 256
|
target_size: 256
|
||||||
color: RGB
|
color: RGB
|
||||||
val:
|
val:
|
||||||
|
@ -40,18 +40,18 @@ network_D:
|
||||||
|
|
||||||
#### path
|
#### path
|
||||||
path:
|
path:
|
||||||
pretrain_model_G: ../experiments/blacked_fix_and_upconv_gan_only/models/7000_G.pth
|
pretrain_model_G: ../experiments/rrdb_blacked_gan_g.pth
|
||||||
pretrain_model_D: ../experiments/blacked_fix_and_upconv_gan_only/models/7000_D.pth
|
pretrain_model_D: ~
|
||||||
strict_load: true
|
strict_load: true
|
||||||
resume_state: ~
|
resume_state: ../experiments/blacked_fix_and_upconv/training_state/9500.state
|
||||||
|
|
||||||
#### training settings: learning rate scheme, loss
|
#### training settings: learning rate scheme, loss
|
||||||
train:
|
train:
|
||||||
lr_G: !!float 5e-5
|
lr_G: !!float 1e-5
|
||||||
weight_decay_G: 0
|
weight_decay_G: 0
|
||||||
beta1_G: 0.9
|
beta1_G: 0.9
|
||||||
beta2_G: 0.99
|
beta2_G: 0.99
|
||||||
lr_D: !!float 8e-5
|
lr_D: !!float 4e-5
|
||||||
weight_decay_D: 0
|
weight_decay_D: 0
|
||||||
beta1_D: 0.9
|
beta1_D: 0.9
|
||||||
beta2_D: 0.99
|
beta2_D: 0.99
|
||||||
|
@ -61,20 +61,20 @@ train:
|
||||||
warmup_iter: -1 # no warm up
|
warmup_iter: -1 # no warm up
|
||||||
lr_steps: [5000, 20000, 40000, 60000]
|
lr_steps: [5000, 20000, 40000, 60000]
|
||||||
lr_gamma: 0.5
|
lr_gamma: 0.5
|
||||||
mega_batch_factor: 8
|
mega_batch_factor: 4
|
||||||
|
|
||||||
pixel_criterion: l1
|
pixel_criterion: l1
|
||||||
pixel_weight: !!float 1e-2
|
pixel_weight: !!float 1e-2
|
||||||
feature_criterion: l1
|
feature_criterion: l1
|
||||||
feature_weight: 0
|
feature_weight: 0
|
||||||
feature_weight_decay: .9
|
feature_weight_decay: .9
|
||||||
feature_weight_decay_steps: 500
|
feature_weight_decay_steps: 501
|
||||||
feature_weight_minimum: .1
|
feature_weight_minimum: 0
|
||||||
gan_type: gan # gan | ragan
|
gan_type: gan # gan | ragan
|
||||||
gan_weight: 1
|
gan_weight: 1
|
||||||
|
|
||||||
D_update_ratio: 1
|
D_update_ratio: 1
|
||||||
D_init_iters: -1
|
D_init_iters: 997
|
||||||
|
|
||||||
manual_seed: 10
|
manual_seed: 10
|
||||||
val_freq: !!float 5e2
|
val_freq: !!float 5e2
|
||||||
|
|
|
@ -15,7 +15,7 @@ if __name__ == "__main__":
|
||||||
#### options
|
#### options
|
||||||
want_just_images = True
|
want_just_images = True
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument('-opt', type=str, help='Path to options YMAL file.', default='options/test/test_ESRGAN_vrp.yml')
|
parser.add_argument('-opt', type=str, help='Path to options YMAL file.', default='options/test/test_ESRGAN_adrianna_full.yml')
|
||||||
opt = option.parse(parser.parse_args().opt, is_train=False)
|
opt = option.parse(parser.parse_args().opt, is_train=False)
|
||||||
opt = option.dict_to_nonedict(opt)
|
opt = option.dict_to_nonedict(opt)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user