Mods to support contrastive learning on audio files
This commit is contained in:
parent
341f28dd82
commit
5037220ac7
|
@ -74,6 +74,8 @@ def create_dataset(dataset_opt, return_collate=False):
|
|||
from data.audio.gpt_tts_dataset import GptTtsDataset as D
|
||||
from data.audio.gpt_tts_dataset import GptTtsCollater as C
|
||||
collate = C(dataset_opt)
|
||||
elif mode == 'wavfile_clips':
|
||||
from data.audio.wavfile_dataset import WavfileDataset as D
|
||||
else:
|
||||
raise NotImplementedError('Dataset [{:s}] is not recognized.'.format(mode))
|
||||
dataset = D(dataset_opt)
|
||||
|
|
81
codes/data/audio/wavfile_dataset.py
Normal file
81
codes/data/audio/wavfile_dataset.py
Normal file
|
@ -0,0 +1,81 @@
|
|||
import os
|
||||
import random
|
||||
|
||||
import torch
|
||||
import torch.utils.data
|
||||
from tqdm import tqdm
|
||||
|
||||
from data.util import get_image_paths, is_wav_file
|
||||
from models.tacotron2.taco_utils import load_wav_to_torch
|
||||
|
||||
|
||||
class WavfileDataset(torch.utils.data.Dataset):
|
||||
|
||||
def __init__(self, opt):
|
||||
self.path = os.path.dirname(opt['path'])
|
||||
cache_path = os.path.join(self.path, 'cache.pth')
|
||||
if os.path.exists(cache_path):
|
||||
self.audiopaths = torch.load(cache_path)
|
||||
else:
|
||||
print("Building cache..")
|
||||
self.audiopaths = get_image_paths('img', opt['path'], qualifier=is_wav_file)[0]
|
||||
torch.save(self.audiopaths, cache_path)
|
||||
self.max_wav_value = 32768.0
|
||||
self.sampling_rate = 24000
|
||||
self.window = 2 * self.sampling_rate
|
||||
|
||||
def get_audio_for_index(self, index):
|
||||
audiopath = self.audiopaths[index]
|
||||
filename = os.path.join(self.path, audiopath)
|
||||
audio, sampling_rate = load_wav_to_torch(filename)
|
||||
if sampling_rate != self.sampling_rate:
|
||||
raise ValueError(f"Input sampling rate does not match specified rate {self.sampling_rate}")
|
||||
audio_norm = audio / self.max_wav_value
|
||||
audio_norm = torch.autograd.Variable(audio_norm, requires_grad=False)
|
||||
return audio_norm, audiopath
|
||||
|
||||
def __getitem__(self, index):
|
||||
clip1, clip2 = None, None
|
||||
|
||||
while clip1 is None and clip2 is None:
|
||||
# Split audio_norm into two tensors of equal size.
|
||||
audio_norm, filename = self.get_audio_for_index(index)
|
||||
if audio_norm.shape[0] < self.window * 2:
|
||||
# Try next index. This adds a bit of bias and ideally we'd filter the dataset rather than do this.
|
||||
index = (index + 1) % len(self)
|
||||
continue
|
||||
j = random.randint(0, audio_norm.shape[0] - self.window)
|
||||
clip1 = audio_norm[j:j+self.window]
|
||||
j = random.randint(0, audio_norm.shape[0]-self.window)
|
||||
clip2 = audio_norm[j:j+self.window]
|
||||
|
||||
return {
|
||||
'clip1': clip1.unsqueeze(0),
|
||||
'clip2': clip2.unsqueeze(0),
|
||||
'path': filename,
|
||||
}
|
||||
|
||||
def __len__(self):
|
||||
return len(self.audiopaths)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
params = {
|
||||
'mode': 'wavfile_clips',
|
||||
'path': 'E:\\audio\\LibriTTS\\train-other-500',
|
||||
'phase': 'train',
|
||||
'n_workers': 0,
|
||||
'batch_size': 16,
|
||||
}
|
||||
from data import create_dataset, create_dataloader, util
|
||||
|
||||
ds, c = create_dataset(params, return_collate=True)
|
||||
dl = create_dataloader(ds, params, collate_fn=c)
|
||||
i = 0
|
||||
m = []
|
||||
max_text = 0
|
||||
max_mel = 0
|
||||
for b in tqdm(dl):
|
||||
pass
|
||||
m=torch.stack(m)
|
||||
print(m.mean(), m.std())
|
|
@ -39,14 +39,17 @@ def cv2torch(cv, batchify=True):
|
|||
def is_image_file(filename):
|
||||
return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
|
||||
|
||||
def is_wav_file(filename):
|
||||
return filename.endswith('.wav')
|
||||
|
||||
def _get_paths_from_images(path):
|
||||
|
||||
def _get_paths_from_images(path, qualifier=is_image_file):
|
||||
"""get image path list from image folder"""
|
||||
assert os.path.isdir(path), '{:s} is not a valid directory'.format(path)
|
||||
images = []
|
||||
for dirpath, _, fnames in sorted(os.walk(path)):
|
||||
for fname in sorted(fnames):
|
||||
if is_image_file(fname) and 'ref.jpg' not in fname:
|
||||
if qualifier(fname) and 'ref.jpg' not in fname:
|
||||
img_path = os.path.join(dirpath, fname)
|
||||
images.append(img_path)
|
||||
if not images:
|
||||
|
@ -64,7 +67,7 @@ def _get_paths_from_lmdb(dataroot):
|
|||
return paths, sizes
|
||||
|
||||
|
||||
def get_image_paths(data_type, dataroot, weights=[]):
|
||||
def get_image_paths(data_type, dataroot, weights=[], qualifier=is_image_file):
|
||||
"""get image path list
|
||||
support lmdb or image files"""
|
||||
paths, sizes = None, None
|
||||
|
@ -82,11 +85,11 @@ def get_image_paths(data_type, dataroot, weights=[]):
|
|||
if weights:
|
||||
extends = weights[i]
|
||||
for j in range(extends):
|
||||
paths.extend(_get_paths_from_images(r))
|
||||
paths.extend(_get_paths_from_images(r, qualifier))
|
||||
paths = sorted(paths)
|
||||
sizes = len(paths)
|
||||
else:
|
||||
paths = sorted(_get_paths_from_images(dataroot))
|
||||
paths = sorted(_get_paths_from_images(dataroot, qualifier))
|
||||
sizes = len(paths)
|
||||
else:
|
||||
raise NotImplementedError('data_type [{:s}] is not recognized.'.format(data_type))
|
||||
|
@ -117,9 +120,9 @@ def read_img(env, path, size=None, rgb=False):
|
|||
stream = open(path, "rb")
|
||||
bytes = bytearray(stream.read())
|
||||
img = cv2.imdecode(np.asarray(bytes, dtype=np.uint8), cv2.IMREAD_UNCHANGED)
|
||||
elif env is 'lmdb':
|
||||
elif env == 'lmdb':
|
||||
img = _read_img_lmdb(env, path, size)
|
||||
elif env is 'buffer':
|
||||
elif env == 'buffer':
|
||||
img = cv2.imdecode(path, cv2.IMREAD_UNCHANGED)
|
||||
else:
|
||||
raise NotImplementedError("Unsupported env: %s" % (env,))
|
||||
|
|
387
codes/models/audio_resnet.py
Normal file
387
codes/models/audio_resnet.py
Normal file
|
@ -0,0 +1,387 @@
|
|||
import torch
|
||||
from torch import Tensor
|
||||
import torch.nn as nn
|
||||
|
||||
from trainer.networks import register_model
|
||||
from utils.util import opt_get
|
||||
from typing import Type, Any, Callable, Union, List, Optional
|
||||
|
||||
|
||||
__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
|
||||
'resnet152', 'resnext50_32x4d', 'resnext101_32x8d',
|
||||
'wide_resnet50_2', 'wide_resnet101_2']
|
||||
|
||||
|
||||
def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv1d:
|
||||
"""3x3 convolution with padding"""
|
||||
return nn.Conv1d(in_planes, out_planes, kernel_size=5, stride=stride,
|
||||
padding=2, groups=groups, bias=False, dilation=dilation)
|
||||
|
||||
|
||||
def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv1d:
|
||||
"""1x1 convolution"""
|
||||
return nn.Conv1d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
|
||||
|
||||
|
||||
class BasicBlock(nn.Module):
|
||||
expansion: int = 1
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
inplanes: int,
|
||||
planes: int,
|
||||
stride: int = 1,
|
||||
downsample: Optional[nn.Module] = None,
|
||||
groups: int = 1,
|
||||
base_width: int = 64,
|
||||
dilation: int = 1,
|
||||
norm_layer: Optional[Callable[..., nn.Module]] = None
|
||||
) -> None:
|
||||
super(BasicBlock, self).__init__()
|
||||
if norm_layer is None:
|
||||
norm_layer = nn.BatchNorm1d
|
||||
if groups != 1 or base_width != 64:
|
||||
raise ValueError('BasicBlock only supports groups=1 and base_width=64')
|
||||
if dilation > 1:
|
||||
raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
|
||||
# Both self.conv1 and self.downsample layers downsample the input when stride != 1
|
||||
self.conv1 = conv3x3(inplanes, planes, stride)
|
||||
self.bn1 = norm_layer(planes)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
self.conv2 = conv3x3(planes, planes)
|
||||
self.bn2 = norm_layer(planes)
|
||||
self.downsample = downsample
|
||||
self.stride = stride
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
identity = x
|
||||
|
||||
out = self.conv1(x)
|
||||
out = self.bn1(out)
|
||||
out = self.relu(out)
|
||||
|
||||
out = self.conv2(out)
|
||||
out = self.bn2(out)
|
||||
|
||||
if self.downsample is not None:
|
||||
identity = self.downsample(x)
|
||||
|
||||
out += identity
|
||||
out = self.relu(out)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class Bottleneck(nn.Module):
|
||||
# Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2)
|
||||
# while original implementation places the stride at the first 1x1 convolution(self.conv1)
|
||||
# according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385.
|
||||
# This variant is also known as ResNet V1.5 and improves accuracy according to
|
||||
# https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.
|
||||
|
||||
expansion: int = 4
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
inplanes: int,
|
||||
planes: int,
|
||||
stride: int = 1,
|
||||
downsample: Optional[nn.Module] = None,
|
||||
groups: int = 1,
|
||||
base_width: int = 64,
|
||||
dilation: int = 1,
|
||||
norm_layer: Optional[Callable[..., nn.Module]] = None
|
||||
) -> None:
|
||||
super(Bottleneck, self).__init__()
|
||||
if norm_layer is None:
|
||||
norm_layer = nn.BatchNorm1d
|
||||
width = int(planes * (base_width / 64.)) * groups
|
||||
# Both self.conv2 and self.downsample layers downsample the input when stride != 1
|
||||
self.conv1 = conv1x1(inplanes, width)
|
||||
self.bn1 = norm_layer(width)
|
||||
self.conv2 = conv3x3(width, width, stride, groups, dilation)
|
||||
self.bn2 = norm_layer(width)
|
||||
self.conv3 = conv1x1(width, planes * self.expansion)
|
||||
self.bn3 = norm_layer(planes * self.expansion)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
self.downsample = downsample
|
||||
self.stride = stride
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
identity = x
|
||||
|
||||
out = self.conv1(x)
|
||||
out = self.bn1(out)
|
||||
out = self.relu(out)
|
||||
|
||||
out = self.conv2(out)
|
||||
out = self.bn2(out)
|
||||
out = self.relu(out)
|
||||
|
||||
out = self.conv3(out)
|
||||
out = self.bn3(out)
|
||||
|
||||
if self.downsample is not None:
|
||||
identity = self.downsample(x)
|
||||
|
||||
out += identity
|
||||
out = self.relu(out)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class ResNet(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
block: Type[Union[BasicBlock, Bottleneck]],
|
||||
layers: List[int],
|
||||
num_classes: int = 1000,
|
||||
zero_init_residual: bool = False,
|
||||
groups: int = 1,
|
||||
width_per_group: int = 64,
|
||||
replace_stride_with_dilation: Optional[List[bool]] = None,
|
||||
norm_layer: Optional[Callable[..., nn.Module]] = None
|
||||
) -> None:
|
||||
super(ResNet, self).__init__()
|
||||
if norm_layer is None:
|
||||
norm_layer = nn.BatchNorm1d
|
||||
self._norm_layer = norm_layer
|
||||
|
||||
self.inplanes = 64
|
||||
self.dilation = 1
|
||||
if replace_stride_with_dilation is None:
|
||||
# each element in the tuple indicates if we should replace
|
||||
# the 2x2 stride with a dilated convolution instead
|
||||
replace_stride_with_dilation = [False, False, False]
|
||||
if len(replace_stride_with_dilation) != 3:
|
||||
raise ValueError("replace_stride_with_dilation should be None "
|
||||
"or a 3-element tuple, got {}".format(replace_stride_with_dilation))
|
||||
self.groups = groups
|
||||
self.base_width = width_per_group
|
||||
self.conv1 = nn.Conv1d(1, self.inplanes, kernel_size=7, stride=4, padding=3,
|
||||
bias=False)
|
||||
self.bn1 = norm_layer(self.inplanes)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
self.maxpool = nn.MaxPool1d(kernel_size=5, stride=4, padding=2)
|
||||
self.layer1 = self._make_layer(block, 64, layers[0])
|
||||
self.layer2 = self._make_layer(block, 128, layers[1], stride=4,
|
||||
dilate=replace_stride_with_dilation[0])
|
||||
self.layer3 = self._make_layer(block, 256, layers[2], stride=4,
|
||||
dilate=replace_stride_with_dilation[1])
|
||||
self.layer4 = self._make_layer(block, 512, layers[3], stride=4,
|
||||
dilate=replace_stride_with_dilation[2])
|
||||
self.avgpool = nn.AdaptiveAvgPool1d(1)
|
||||
self.fc = nn.Linear(512 * block.expansion, num_classes)
|
||||
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv1d):
|
||||
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
|
||||
elif isinstance(m, (nn.BatchNorm1d, nn.GroupNorm)):
|
||||
nn.init.constant_(m.weight, 1)
|
||||
nn.init.constant_(m.bias, 0)
|
||||
|
||||
# Zero-initialize the last BN in each residual branch,
|
||||
# so that the residual branch starts with zeros, and each residual block behaves like an identity.
|
||||
# This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
|
||||
if zero_init_residual:
|
||||
for m in self.modules():
|
||||
if isinstance(m, Bottleneck):
|
||||
nn.init.constant_(m.bn3.weight, 0) # type: ignore[arg-type]
|
||||
elif isinstance(m, BasicBlock):
|
||||
nn.init.constant_(m.bn2.weight, 0) # type: ignore[arg-type]
|
||||
|
||||
def _make_layer(self, block: Type[Union[BasicBlock, Bottleneck]], planes: int, blocks: int,
|
||||
stride: int = 1, dilate: bool = False) -> nn.Sequential:
|
||||
norm_layer = self._norm_layer
|
||||
downsample = None
|
||||
previous_dilation = self.dilation
|
||||
if dilate:
|
||||
self.dilation *= stride
|
||||
stride = 1
|
||||
if stride != 1 or self.inplanes != planes * block.expansion:
|
||||
downsample = nn.Sequential(
|
||||
conv1x1(self.inplanes, planes * block.expansion, stride),
|
||||
norm_layer(planes * block.expansion),
|
||||
)
|
||||
|
||||
layers = []
|
||||
layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
|
||||
self.base_width, previous_dilation, norm_layer))
|
||||
self.inplanes = planes * block.expansion
|
||||
for _ in range(1, blocks):
|
||||
layers.append(block(self.inplanes, planes, groups=self.groups,
|
||||
base_width=self.base_width, dilation=self.dilation,
|
||||
norm_layer=norm_layer))
|
||||
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
def _forward_impl(self, x: Tensor) -> Tensor:
|
||||
# See note [TorchScript super()]
|
||||
x = self.conv1(x)
|
||||
x = self.bn1(x)
|
||||
x = self.relu(x)
|
||||
x = self.maxpool(x)
|
||||
|
||||
x = self.layer1(x)
|
||||
x = self.layer2(x)
|
||||
x = self.layer3(x)
|
||||
x = self.layer4(x)
|
||||
|
||||
x = self.avgpool(x)
|
||||
x = torch.flatten(x, 1)
|
||||
x = self.fc(x)
|
||||
|
||||
return x
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
return self._forward_impl(x)
|
||||
|
||||
|
||||
def _resnet(
|
||||
arch: str,
|
||||
block: Type[Union[BasicBlock, Bottleneck]],
|
||||
layers: List[int],
|
||||
pretrained: bool,
|
||||
progress: bool,
|
||||
**kwargs: Any
|
||||
) -> ResNet:
|
||||
model = ResNet(block, layers, **kwargs)
|
||||
return model
|
||||
|
||||
|
||||
def resnet18(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
|
||||
r"""ResNet-18 model from
|
||||
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_.
|
||||
|
||||
Args:
|
||||
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
||||
progress (bool): If True, displays a progress bar of the download to stderr
|
||||
"""
|
||||
return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress,
|
||||
**kwargs)
|
||||
|
||||
|
||||
def resnet34(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
|
||||
r"""ResNet-34 model from
|
||||
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_.
|
||||
|
||||
Args:
|
||||
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
||||
progress (bool): If True, displays a progress bar of the download to stderr
|
||||
"""
|
||||
return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress,
|
||||
**kwargs)
|
||||
|
||||
|
||||
def resnet50(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
|
||||
r"""ResNet-50 model from
|
||||
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_.
|
||||
|
||||
Args:
|
||||
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
||||
progress (bool): If True, displays a progress bar of the download to stderr
|
||||
"""
|
||||
return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress,
|
||||
**kwargs)
|
||||
|
||||
|
||||
def resnet101(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
|
||||
r"""ResNet-101 model from
|
||||
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_.
|
||||
|
||||
Args:
|
||||
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
||||
progress (bool): If True, displays a progress bar of the download to stderr
|
||||
"""
|
||||
return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress,
|
||||
**kwargs)
|
||||
|
||||
|
||||
def resnet152(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
|
||||
r"""ResNet-152 model from
|
||||
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_.
|
||||
|
||||
Args:
|
||||
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
||||
progress (bool): If True, displays a progress bar of the download to stderr
|
||||
"""
|
||||
return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress,
|
||||
**kwargs)
|
||||
|
||||
|
||||
def resnext50_32x4d(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
|
||||
r"""ResNeXt-50 32x4d model from
|
||||
`"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_.
|
||||
|
||||
Args:
|
||||
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
||||
progress (bool): If True, displays a progress bar of the download to stderr
|
||||
"""
|
||||
kwargs['groups'] = 32
|
||||
kwargs['width_per_group'] = 4
|
||||
return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3],
|
||||
pretrained, progress, **kwargs)
|
||||
|
||||
|
||||
def resnext101_32x8d(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
|
||||
r"""ResNeXt-101 32x8d model from
|
||||
`"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_.
|
||||
|
||||
Args:
|
||||
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
||||
progress (bool): If True, displays a progress bar of the download to stderr
|
||||
"""
|
||||
kwargs['groups'] = 32
|
||||
kwargs['width_per_group'] = 8
|
||||
return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3],
|
||||
pretrained, progress, **kwargs)
|
||||
|
||||
|
||||
def wide_resnet50_2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
|
||||
r"""Wide ResNet-50-2 model from
|
||||
`"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_.
|
||||
|
||||
The model is the same as ResNet except for the bottleneck number of channels
|
||||
which is twice larger in every block. The number of channels in outer 1x1
|
||||
convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
|
||||
channels, and in Wide ResNet-50-2 has 2048-1024-2048.
|
||||
|
||||
Args:
|
||||
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
||||
progress (bool): If True, displays a progress bar of the download to stderr
|
||||
"""
|
||||
kwargs['width_per_group'] = 64 * 2
|
||||
return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3],
|
||||
pretrained, progress, **kwargs)
|
||||
|
||||
|
||||
def wide_resnet101_2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
|
||||
r"""Wide ResNet-101-2 model from
|
||||
`"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_.
|
||||
|
||||
The model is the same as ResNet except for the bottleneck number of channels
|
||||
which is twice larger in every block. The number of channels in outer 1x1
|
||||
convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
|
||||
channels, and in Wide ResNet-50-2 has 2048-1024-2048.
|
||||
|
||||
Args:
|
||||
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
||||
progress (bool): If True, displays a progress bar of the download to stderr
|
||||
"""
|
||||
kwargs['width_per_group'] = 64 * 2
|
||||
return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3],
|
||||
pretrained, progress, **kwargs)
|
||||
|
||||
|
||||
@register_model
|
||||
def register_audio_resnet(opt_net, opt):
|
||||
type = opt_net['type']
|
||||
fn = globals()[type]
|
||||
return fn(**opt_get(opt_net, ['kwargs'], {}))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
m = resnet34()
|
||||
o = m(torch.randn((1,1,48000)))
|
||||
print(o.shape)
|
|
@ -300,7 +300,7 @@ class Trainer:
|
|||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_gpt_tts_lj.yml')
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_byol_audio_clips.yml')
|
||||
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
|
||||
parser.add_argument('--local_rank', type=int, default=0)
|
||||
args = parser.parse_args()
|
||||
|
|
|
@ -4,6 +4,8 @@ from collections import defaultdict
|
|||
import torch
|
||||
from torch.optim.lr_scheduler import _LRScheduler
|
||||
|
||||
from utils.util import opt_get
|
||||
|
||||
|
||||
def get_scheduler_for_name(name, optimizers, scheduler_opt):
|
||||
schedulers = []
|
||||
|
@ -19,7 +21,7 @@ def get_scheduler_for_name(name, optimizers, scheduler_opt):
|
|||
gamma=scheduler_opt['lr_gamma'],
|
||||
clear_state=scheduler_opt['clear_state'],
|
||||
force_lr=scheduler_opt['force_lr'],
|
||||
warmup_steps=scheduler_opt['warmup_steps'])
|
||||
warmup_steps=opt_get(scheduler_opt, ['warmup_steps'], 0))
|
||||
elif name == 'ProgressiveMultiStepLR':
|
||||
sched = ProgressiveMultiStepLR(o, scheduler_opt['gen_lr_steps'],
|
||||
scheduler_opt['progressive_starts'],
|
||||
|
|
Loading…
Reference in New Issue
Block a user