Fix error & add nonfinite warning
This commit is contained in:
parent
5d5558893a
commit
79367f753d
|
@ -3,7 +3,6 @@
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from torchvision.models.resnet import BasicBlock, Bottleneck
|
from torchvision.models.resnet import BasicBlock, Bottleneck
|
||||||
from torchvision.models.utils import load_state_dict_from_url
|
|
||||||
import torchvision
|
import torchvision
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -2,7 +2,6 @@ import torch
|
||||||
import torchvision
|
import torchvision
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from torchvision.models.utils import load_state_dict_from_url
|
|
||||||
from typing import Type, Any, Callable, Union, List, Optional, OrderedDict, Iterator
|
from typing import Type, Any, Callable, Union, List, Optional, OrderedDict, Iterator
|
||||||
|
|
||||||
__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
|
__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
|
||||||
|
|
|
@ -3,7 +3,6 @@
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from torchvision.models.resnet import BasicBlock, Bottleneck
|
from torchvision.models.resnet import BasicBlock, Bottleneck
|
||||||
from torchvision.models.utils import load_state_dict_from_url
|
|
||||||
import torchvision
|
import torchvision
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -227,6 +227,8 @@ class ConfigurableStep(Module):
|
||||||
new_state.update(lstate)
|
new_state.update(lstate)
|
||||||
else:
|
else:
|
||||||
l = loss(self.get_network_for_name(self.step_opt['training']), local_state)
|
l = loss(self.get_network_for_name(self.step_opt['training']), local_state)
|
||||||
|
if not l.isfinite():
|
||||||
|
print(f'!!Detected non-finite loss {loss_name}')
|
||||||
total_loss += l * self.weights[loss_name]
|
total_loss += l * self.weights[loss_name]
|
||||||
# Record metrics.
|
# Record metrics.
|
||||||
if isinstance(l, torch.Tensor):
|
if isinstance(l, torch.Tensor):
|
||||||
|
|
Loading…
Reference in New Issue
Block a user