torchscale/torchscale/component/droppath.py
2022-11-26 09:01:02 -08:00

20 lines
558 B
Python

# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
import torch.nn as nn
from timm.models.layers import drop_path
class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    Thin ``nn.Module`` wrapper around timm's ``drop_path`` function, so the
    drop probability is stored on the module and dropping is only active
    while the module is in training mode (``self.training``).

    Args:
        drop_prob: Probability of dropping the path for a sample. ``None``
            (the default) or ``0`` disables dropping, making the module an
            identity in both training and eval mode.
    """

    def __init__(self, drop_prob=None):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        """Apply stochastic depth to ``x``.

        Returns ``x`` itself when dropping is disabled (``drop_prob`` is
        ``None``/0) or the module is in eval mode; otherwise returns the
        result of timm's ``drop_path(x, drop_prob, training)``.
        """
        # timm's drop_path only short-circuits on ``drop_prob == 0.``; with the
        # default ``None`` in training mode it would fail on ``1 - None``.
        if not self.drop_prob or not self.training:
            return x
        return drop_path(x, self.drop_prob, self.training)

    def extra_repr(self):
        # Appended inside the module's repr, e.g. ``DropPath(p=0.1)``.
        return "p={}".format(self.drop_prob)