Misc changes

This commit is contained in:
James Betker 2020-07-01 11:28:23 -06:00
parent 604763be68
commit c0bb123504
3 changed files with 7 additions and 5 deletions

View File

@ -34,6 +34,7 @@ class TorchCustomTrace:
# the backwards pass.
def mem_forward_hook(self, module: torch.nn.Module, inputs, outputs, trace: str, mod_id: str):
mod = self.modules[mod_id]
'''
for li in inputs:
if type(li) == torch.Tensor:
li = [li]
@ -48,7 +49,8 @@ class TorchCustomTrace:
self.module_map_by_inputs[o.data_ptr()].append(mod)
else:
self.module_map_by_inputs[o.data_ptr()] = [mod]
# print(trace, [i.data_ptr() for i in inputs], [o.data_ptr() for o in outputs])
'''
print(trace)
def mem_backward_hook(self, inputs, outputs, op):
if len(inputs) == 0:
@ -90,7 +92,7 @@ class TorchCustomTrace:
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../options/use_video_upsample.yml')
parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../options/debug.yml')
opt = option.parse(parser.parse_args().opt, is_train=False)
opt = option.dict_to_nonedict(opt)

View File

@ -1,8 +1,8 @@
from torch.utils.tensorboard import SummaryWriter
if __name__ == "__main__":
writer = SummaryWriter("../experiments/train_div2k_feat_rg2_more_stuff")
f = open("../experiments/train_div2k_feat_rg2_more_stuff/console_output")
writer = SummaryWriter("../experiments/train_div2k_feat_nsgen_r3/recovered_tb")
f = open("../experiments/train_div2k_feat_nsgen_r3/console.txt", encoding="utf8")
console = f.readlines()
search_terms = [
("iter", ", iter: ", ", lr:"),

View File

@ -31,7 +31,7 @@ def init_dist(backend='nccl', **kwargs):
def main():
#### options
parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imgset_residualgenerator_fast_specific.yml')
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_div2k_feat_resgen2_lr.yml')
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
help='job launcher')
parser.add_argument('--local_rank', type=int, default=0)