come on guys... :((

This commit is contained in:
James Betker 2022-06-21 20:12:54 -06:00
parent fcfb3a1525
commit 4a1f3aba31

View File

@ -870,7 +870,7 @@ class GaussianDiffusion:
s_err = (target - model_output) ** 2
if channel_balancing_fn is not None:
s_err = channel_balancing_fn(s_err)
terms["mse_by_batch"] = s_err.view(s_err.shape[0], -1).mean(dim=1)
terms["mse_by_batch"] = s_err.reshape(s_err.shape[0], -1).mean(dim=1)
terms["mse"] = mean_flat(s_err)
terms["x_start_predicted"] = x_start_pred
if "vb" in terms: