Misc changes
This commit is contained in:
parent
604763be68
commit
c0bb123504
|
@ -34,6 +34,7 @@ class TorchCustomTrace:
|
||||||
# the backwards pass.
|
# the backwards pass.
|
||||||
def mem_forward_hook(self, module: torch.nn.Module, inputs, outputs, trace: str, mod_id: str):
|
def mem_forward_hook(self, module: torch.nn.Module, inputs, outputs, trace: str, mod_id: str):
|
||||||
mod = self.modules[mod_id]
|
mod = self.modules[mod_id]
|
||||||
|
'''
|
||||||
for li in inputs:
|
for li in inputs:
|
||||||
if type(li) == torch.Tensor:
|
if type(li) == torch.Tensor:
|
||||||
li = [li]
|
li = [li]
|
||||||
|
@ -48,7 +49,8 @@ class TorchCustomTrace:
|
||||||
self.module_map_by_inputs[o.data_ptr()].append(mod)
|
self.module_map_by_inputs[o.data_ptr()].append(mod)
|
||||||
else:
|
else:
|
||||||
self.module_map_by_inputs[o.data_ptr()] = [mod]
|
self.module_map_by_inputs[o.data_ptr()] = [mod]
|
||||||
# print(trace, [i.data_ptr() for i in inputs], [o.data_ptr() for o in outputs])
|
'''
|
||||||
|
print(trace)
|
||||||
|
|
||||||
def mem_backward_hook(self, inputs, outputs, op):
|
def mem_backward_hook(self, inputs, outputs, op):
|
||||||
if len(inputs) == 0:
|
if len(inputs) == 0:
|
||||||
|
@ -90,7 +92,7 @@ class TorchCustomTrace:
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../options/use_video_upsample.yml')
|
parser.add_argument('-opt', type=str, help='Path to options YAML file.', default='../options/debug.yml')
|
||||||
opt = option.parse(parser.parse_args().opt, is_train=False)
|
opt = option.parse(parser.parse_args().opt, is_train=False)
|
||||||
opt = option.dict_to_nonedict(opt)
|
opt = option.dict_to_nonedict(opt)
|
||||||
|
|
||||||
|
|
|
@ -1,8 +1,8 @@
|
||||||
from torch.utils.tensorboard import SummaryWriter
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
writer = SummaryWriter("../experiments/train_div2k_feat_rg2_more_stuff")
|
writer = SummaryWriter("../experiments/train_div2k_feat_nsgen_r3/recovered_tb")
|
||||||
f = open("../experiments/train_div2k_feat_rg2_more_stuff/console_output")
|
f = open("../experiments/train_div2k_feat_nsgen_r3/console.txt", encoding="utf8")
|
||||||
console = f.readlines()
|
console = f.readlines()
|
||||||
search_terms = [
|
search_terms = [
|
||||||
("iter", ", iter: ", ", lr:"),
|
("iter", ", iter: ", ", lr:"),
|
||||||
|
|
|
@ -31,7 +31,7 @@ def init_dist(backend='nccl', **kwargs):
|
||||||
def main():
|
def main():
|
||||||
#### options
|
#### options
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_imgset_residualgenerator_fast_specific.yml')
|
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_div2k_feat_resgen2_lr.yml')
|
||||||
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
|
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
|
||||||
help='job launcher')
|
help='job launcher')
|
||||||
parser.add_argument('--local_rank', type=int, default=0)
|
parser.add_argument('--local_rank', type=int, default=0)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user