Push correct patch of recurrent embedding to upstream image, rather than whole thing

This commit is contained in:
James Betker 2020-10-18 22:39:52 -06:00
parent 7df378a944
commit 668cafa798

View File

@ -85,9 +85,12 @@ class ProgressiveGeneratorInjector(Injector):
recurrent_hq = base_hq_out
recurrent = base_recurrent
for link in chain: # Remember, `link` is a MultiscaleTreeNode.
top = int(link.top * h)
left = int(link.left * w)
recurrent = torch.nn.functional.interpolate(recurrent[:, :, top:top+h//2, left:left+w//2], scale_factor=2, mode="nearest")
if self.feed_gen_output_into_input:
top = int(link.top * 2 * h)
left = int(link.left * 2 * w)
top *= 2
left *= 2
lq_input = recurrent_hq[:, :, top:top+h, left:left+w]
else:
lq_input = lq_inputs[:, link.index]