diff --git a/codes/models/steps/progressive_zoom.py b/codes/models/steps/progressive_zoom.py index f7661af8..c516e0bb 100644 --- a/codes/models/steps/progressive_zoom.py +++ b/codes/models/steps/progressive_zoom.py @@ -85,9 +85,12 @@ class ProgressiveGeneratorInjector(Injector): recurrent_hq = base_hq_out recurrent = base_recurrent for link in chain: # Remember, `link` is a MultiscaleTreeNode. + top = int(link.top * h) + left = int(link.left * w) + recurrent = torch.nn.functional.interpolate(recurrent[:, :, top:top+h//2, left:left+w//2], scale_factor=2, mode="nearest") if self.feed_gen_output_into_input: - top = int(link.top * 2 * h) - left = int(link.left * 2 * w) + top *= 2 + left *= 2 lq_input = recurrent_hq[:, :, top:top+h, left:left+w] else: lq_input = lq_inputs[:, link.index]