Fix error & add nonfinite warning

This commit is contained in:
James Betker 2021-11-09 23:58:41 -07:00
parent 5d5558893a
commit 79367f753d
4 changed files with 2 additions and 3 deletions

View File

@ -3,7 +3,6 @@
import torch
import torch.nn as nn
from torchvision.models.resnet import BasicBlock, Bottleneck
from torchvision.models.utils import load_state_dict_from_url
import torchvision

View File

@ -2,7 +2,6 @@ import torch
import torchvision
from torch import Tensor
import torch.nn as nn
from torchvision.models.utils import load_state_dict_from_url
from typing import Type, Any, Callable, Union, List, Optional, OrderedDict, Iterator
__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',

View File

@ -3,7 +3,6 @@
import torch
import torch.nn as nn
from torchvision.models.resnet import BasicBlock, Bottleneck
from torchvision.models.utils import load_state_dict_from_url
import torchvision

View File

@ -227,6 +227,8 @@ class ConfigurableStep(Module):
new_state.update(lstate)
else:
l = loss(self.get_network_for_name(self.step_opt['training']), local_state)
if not l.isfinite():
print(f'!!Detected non-finite loss {loss_name}')
total_loss += l * self.weights[loss_name]
# Record metrics.
if isinstance(l, torch.Tensor):