come on guys... :((

This commit is contained in:
James Betker 2022-06-21 20:12:54 -06:00
parent fcfb3a1525
commit 4a1f3aba31

View File

@ -870,7 +870,7 @@ class GaussianDiffusion:
s_err = (target - model_output) ** 2 s_err = (target - model_output) ** 2
if channel_balancing_fn is not None: if channel_balancing_fn is not None:
s_err = channel_balancing_fn(s_err) s_err = channel_balancing_fn(s_err)
terms["mse_by_batch"] = s_err.view(s_err.shape[0], -1).mean(dim=1) terms["mse_by_batch"] = s_err.reshape(s_err.shape[0], -1).mean(dim=1)
terms["mse"] = mean_flat(s_err) terms["mse"] = mean_flat(s_err)
terms["x_start_predicted"] = x_start_pred terms["x_start_predicted"] = x_start_pred
if "vb" in terms: if "vb" in terms: