This commit is contained in:
James Betker 2020-07-06 20:44:07 -06:00
parent 6beefa6d0c
commit 60c6352843
2 changed files with 12 additions and 10 deletions

View File

@ -1,12 +1,13 @@
from torch.utils.tensorboard import SummaryWriter from torch.utils.tensorboard import SummaryWriter
if __name__ == "__main__": if __name__ == "__main__":
writer = SummaryWriter("../experiments/train_div2k_feat_nsgen_r3/recovered_tb") writer = SummaryWriter("../experiments/recovered_tb")
f = open("../experiments/train_div2k_feat_nsgen_r3/console.txt", encoding="utf8") f = open("../experiments/recovered_tb.txt", encoding="utf8")
console = f.readlines() console = f.readlines()
search_terms = [ search_terms = [
("iter", ", iter: ", ", lr:"), ("iter", ", iter: ", ", lr:"),
("l_g_total", " l_g_total: ", " switch_temperature:") ("l_g_total", " l_g_total: ", " switch_temperature:"),
("l_d_fake", "l_d_fake: ", " D_fake:")
] ]
iter = 0 iter = 0
for line in console: for line in console:

View File

@ -1,5 +1,6 @@
import torch import torch
from torch import nn from torch import nn
import models.archs.SRG1_arch as srg1
import models.archs.SwitchedResidualGenerator_arch as srg import models.archs.SwitchedResidualGenerator_arch as srg
import models.archs.NestedSwitchGenerator as nsg import models.archs.NestedSwitchGenerator as nsg
import functools import functools
@ -96,15 +97,15 @@ if __name__ == "__main__":
torch.randn(1, 3, 64, 64), torch.randn(1, 3, 64, 64),
device='cuda') device='cuda')
''' '''
test_stability(functools.partial(srg.ConfigurableSwitchedResidualGenerator2, test_stability(functools.partial(srg1.ConfigurableSwitchedResidualGenerator,
switch_filters=[16,16,16,16,16], switch_filters=[32,32,32,32],
switch_growths=[32,32,32,32,32], switch_growths=[16,16,16,16],
switch_reductions=[1,1,1,1,1], switch_reductions=[4,3,2,1],
switch_processing_layers=[5,5,5,5,5], switch_processing_layers=[3,3,4,5],
trans_counts=[8,8,8,8,8], trans_counts=[16,16,16,16,16],
trans_kernel_sizes=[3,3,3,3,3], trans_kernel_sizes=[3,3,3,3,3],
trans_layers=[3,3,3,3,3], trans_layers=[3,3,3,3,3],
transformation_filters=64, trans_filters_mid=[24,24,24,24,24],
initial_temp=10), initial_temp=10),
torch.randn(1, 3, 64, 64), torch.randn(1, 3, 64, 64),
device='cuda') device='cuda')