diff --git a/codes/models/diffusion/gaussian_diffusion.py b/codes/models/diffusion/gaussian_diffusion.py index 472b210b..9c17e951 100644 --- a/codes/models/diffusion/gaussian_diffusion.py +++ b/codes/models/diffusion/gaussian_diffusion.py @@ -870,7 +870,7 @@ class GaussianDiffusion: s_err = (target - model_output) ** 2 if channel_balancing_fn is not None: s_err = channel_balancing_fn(s_err) - terms["mse_by_batch"] = s_err.view(s_err.shape[0], -1).mean(dim=1) + terms["mse_by_batch"] = s_err.reshape(s_err.shape[0], -1).mean(dim=1) terms["mse"] = mean_flat(s_err) terms["x_start_predicted"] = x_start_pred if "vb" in terms: