forked from mrq/DL-Art-School
Fix error & add nonfinite warning
This commit is contained in:
parent
5d5558893a
commit
79367f753d
|
@ -3,7 +3,6 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
from torchvision.models.resnet import BasicBlock, Bottleneck
|
||||
from torchvision.models.utils import load_state_dict_from_url
|
||||
import torchvision
|
||||
|
||||
|
||||
|
|
|
@ -2,7 +2,6 @@ import torch
|
|||
import torchvision
|
||||
from torch import Tensor
|
||||
import torch.nn as nn
|
||||
from torchvision.models.utils import load_state_dict_from_url
|
||||
from typing import Type, Any, Callable, Union, List, Optional, OrderedDict, Iterator
|
||||
|
||||
__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
|
||||
|
|
|
@ -3,7 +3,6 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
from torchvision.models.resnet import BasicBlock, Bottleneck
|
||||
from torchvision.models.utils import load_state_dict_from_url
|
||||
import torchvision
|
||||
|
||||
|
||||
|
|
|
@ -227,6 +227,8 @@ class ConfigurableStep(Module):
|
|||
new_state.update(lstate)
|
||||
else:
|
||||
l = loss(self.get_network_for_name(self.step_opt['training']), local_state)
|
||||
if not l.isfinite():
|
||||
print(f'!!Detected non-finite loss {loss_name}')
|
||||
total_loss += l * self.weights[loss_name]
|
||||
# Record metrics.
|
||||
if isinstance(l, torch.Tensor):
|
||||
|
|
Loading…
Reference in New Issue
Block a user