come on guys... :((
This commit is contained in:
parent
fcfb3a1525
commit
4a1f3aba31
|
@ -870,7 +870,7 @@ class GaussianDiffusion:
|
|||
s_err = (target - model_output) ** 2
|
||||
if channel_balancing_fn is not None:
|
||||
s_err = channel_balancing_fn(s_err)
|
||||
terms["mse_by_batch"] = s_err.view(s_err.shape[0], -1).mean(dim=1)
|
||||
terms["mse_by_batch"] = s_err.reshape(s_err.shape[0], -1).mean(dim=1)
|
||||
terms["mse"] = mean_flat(s_err)
|
||||
terms["x_start_predicted"] = x_start_pred
|
||||
if "vb" in terms:
|
||||
|
|
Loading…
Reference in New Issue
Block a user