From 5037220ac7640edf779639b56c20ddc59029075b Mon Sep 17 00:00:00 2001
From: James Betker
Date: Thu, 5 Aug 2021 05:57:04 -0600
Subject: [PATCH] Mods to support contrastive learning on audio files

---
 codes/data/__init__.py              |   2 +
 codes/data/audio/wavfile_dataset.py |  81 ++++++
 codes/data/util.py                  |  17 +-
 codes/models/audio_resnet.py        | 387 ++++++++++++++++++++++++++++
 codes/train.py                      |   2 +-
 codes/trainer/lr_scheduler.py       |   4 +-
 6 files changed, 484 insertions(+), 9 deletions(-)
 create mode 100644 codes/data/audio/wavfile_dataset.py
 create mode 100644 codes/models/audio_resnet.py

diff --git a/codes/data/__init__.py b/codes/data/__init__.py
index 4130c525..ea4bb4e0 100644
--- a/codes/data/__init__.py
+++ b/codes/data/__init__.py
@@ -74,6 +74,8 @@ def create_dataset(dataset_opt, return_collate=False):
         from data.audio.gpt_tts_dataset import GptTtsDataset as D
         from data.audio.gpt_tts_dataset import GptTtsCollater as C
         collate = C(dataset_opt)
+    elif mode == 'wavfile_clips':
+        from data.audio.wavfile_dataset import WavfileDataset as D
     else:
         raise NotImplementedError('Dataset [{:s}] is not recognized.'.format(mode))
     dataset = D(dataset_opt)
diff --git a/codes/data/audio/wavfile_dataset.py b/codes/data/audio/wavfile_dataset.py
new file mode 100644
index 00000000..2be61a0b
--- /dev/null
+++ b/codes/data/audio/wavfile_dataset.py
@@ -0,0 +1,81 @@
+import os
+import random
+
+import torch
+import torch.utils.data
+from tqdm import tqdm
+
+from data.util import get_image_paths, is_wav_file
+from models.tacotron2.taco_utils import load_wav_to_torch
+
+
+class WavfileDataset(torch.utils.data.Dataset):
+
+    def __init__(self, opt):
+        self.path = os.path.dirname(opt['path'])
+        cache_path = os.path.join(self.path, 'cache.pth')
+        if os.path.exists(cache_path):
+            self.audiopaths = torch.load(cache_path)
+        else:
+            print("Building cache..")
+            self.audiopaths = get_image_paths('img', opt['path'], qualifier=is_wav_file)[0]
+            torch.save(self.audiopaths, cache_path)
+        self.max_wav_value = 32768.0
+        self.sampling_rate = 24000
+        self.window = 2 * self.sampling_rate
+
+    def get_audio_for_index(self, index):
+        audiopath = self.audiopaths[index]
+        filename = os.path.join(self.path, audiopath)
+        audio, sampling_rate = load_wav_to_torch(filename)
+        if sampling_rate != self.sampling_rate:
+            raise ValueError(f"Input sampling rate {sampling_rate} does not match specified rate {self.sampling_rate}")
+        audio_norm = audio / self.max_wav_value
+        audio_norm = torch.autograd.Variable(audio_norm, requires_grad=False)
+        return audio_norm, audiopath
+
+    def __getitem__(self, index):
+        clip1, clip2 = None, None
+
+        while clip1 is None and clip2 is None:
+            # Split audio_norm into two tensors of equal size.
+            audio_norm, filename = self.get_audio_for_index(index)
+            if audio_norm.shape[0] < self.window * 2:
+                # Try the next index. This adds a bit of bias and ideally we'd filter the dataset rather than do this.
+                index = (index + 1) % len(self)
+                continue
+            j = random.randint(0, audio_norm.shape[0] - self.window)
+            clip1 = audio_norm[j:j+self.window]
+            j = random.randint(0, audio_norm.shape[0] - self.window)
+            clip2 = audio_norm[j:j+self.window]
+
+        return {
+            'clip1': clip1.unsqueeze(0),
+            'clip2': clip2.unsqueeze(0),
+            'path': filename,
+        }
+
+    def __len__(self):
+        return len(self.audiopaths)
+
+
+if __name__ == '__main__':
+    params = {
+        'mode': 'wavfile_clips',
+        'path': 'E:\\audio\\LibriTTS\\train-other-500',
+        'phase': 'train',
+        'n_workers': 0,
+        'batch_size': 16,
+    }
+    from data import create_dataset, create_dataloader
+
+    ds, c = create_dataset(params, return_collate=True)
+    dl = create_dataloader(ds, params, collate_fn=c)
+    i = 0
+    for b in tqdm(dl):
+        # Sanity-check that the two clips in each pair come out with matching
+        # shapes and a channel dimension: (batch, 1, window).
+        assert b['clip1'].shape == b['clip2'].shape
+        assert b['clip1'].shape[1] == 1
+        i += 1
+    print(f"Iterated {i} batches.")
diff --git a/codes/data/util.py b/codes/data/util.py
index 17b1690c..23b072bb 100644
--- a/codes/data/util.py
+++ b/codes/data/util.py
@@ -39,14 +39,17 @@ def cv2torch(cv, batchify=True):
 def is_image_file(filename):
     return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
 
+def is_wav_file(filename):
+    return filename.endswith('.wav')
 
-def _get_paths_from_images(path):
+
+def _get_paths_from_images(path, qualifier=is_image_file):
     """get image path list from image folder"""
     assert os.path.isdir(path), '{:s} is not a valid directory'.format(path)
     images = []
     for dirpath, _, fnames in sorted(os.walk(path)):
         for fname in sorted(fnames):
-            if is_image_file(fname) and 'ref.jpg' not in fname:
+            if qualifier(fname) and 'ref.jpg' not in fname:
                 img_path = os.path.join(dirpath, fname)
                 images.append(img_path)
     if not images:
@@ -64,7 +67,7 @@ def _get_paths_from_lmdb(dataroot):
     return paths, sizes
 
 
-def get_image_paths(data_type, dataroot, weights=[]):
+def get_image_paths(data_type, dataroot, weights=[], qualifier=is_image_file):
     """get image path list
     support lmdb or image files"""
     paths, sizes = None, None
@@ -82,11 +85,11 @@ def get_image_paths(data_type, dataroot, weights=[]):
             if weights:
                 extends = weights[i]
             for j in range(extends):
-                paths.extend(_get_paths_from_images(r))
+                paths.extend(_get_paths_from_images(r, qualifier))
             paths = sorted(paths)
             sizes = len(paths)
         else:
-            paths = sorted(_get_paths_from_images(dataroot))
+            paths = sorted(_get_paths_from_images(dataroot, qualifier))
             sizes = len(paths)
     else:
         raise NotImplementedError('data_type [{:s}] is not recognized.'.format(data_type))
@@ -117,9 +120,9 @@ def read_img(env, path, size=None, rgb=False):
         stream = open(path, "rb")
         bytes = bytearray(stream.read())
         img = cv2.imdecode(np.asarray(bytes, dtype=np.uint8), cv2.IMREAD_UNCHANGED)
-    elif env is 'lmdb':
+    elif env == 'lmdb':
         img = _read_img_lmdb(env, path, size)
-    elif env is 'buffer':
+    elif env == 'buffer':
         img = cv2.imdecode(path, cv2.IMREAD_UNCHANGED)
     else:
         raise NotImplementedError("Unsupported env: %s" % (env,))
diff --git a/codes/models/audio_resnet.py b/codes/models/audio_resnet.py
new file mode 100644
index 00000000..46694f21
--- /dev/null
+++ b/codes/models/audio_resnet.py
@@ -0,0 +1,387 @@
+import torch
+from torch import Tensor
+import torch.nn as nn
+
+from trainer.networks import register_model
+from utils.util import opt_get
+from typing import Type, Any, Callable, Union, List, Optional
+
+
+__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
+           'resnet152', 'resnext50_32x4d', 'resnext101_32x8d',
+           'wide_resnet50_2', 'wide_resnet101_2']
+
+
+def conv3x3(in_planes: int, out_planes: int, stride: int = 1,
+            groups: int = 1, dilation: int = 1) -> nn.Conv1d:
+    """Width-5 1D convolution with padding (kept under torchvision's conv3x3 name)."""
+    return nn.Conv1d(in_planes, out_planes, kernel_size=5, stride=stride,
+                     padding=2, groups=groups, bias=False, dilation=dilation)
+
+
+def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv1d:
+    """1x1 convolution"""
+    return nn.Conv1d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
+
+
+class BasicBlock(nn.Module):
+    expansion: int = 1
+
+    def __init__(
+        self,
+        inplanes: int,
+        planes: int,
+        stride: int = 1,
+        downsample: Optional[nn.Module] = None,
+        groups: int = 1,
+        base_width: int = 64,
+        dilation: int = 1,
+        norm_layer: Optional[Callable[..., nn.Module]] = None
+    ) -> None:
+        super(BasicBlock, self).__init__()
+        if norm_layer is None:
+            norm_layer = nn.BatchNorm1d
+        if groups != 1 or base_width != 64:
+            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
+        if dilation > 1:
+            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
+        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
+        self.conv1 = conv3x3(inplanes, planes, stride)
+        self.bn1 = norm_layer(planes)
+        self.relu = nn.ReLU(inplace=True)
+        self.conv2 = conv3x3(planes, planes)
+        self.bn2 = norm_layer(planes)
+        self.downsample = downsample
+        self.stride = stride
+
+    def forward(self, x: Tensor) -> Tensor:
+        identity = x
+
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = self.relu(out)
+
+        out = self.conv2(out)
+        out = self.bn2(out)
+
+        if self.downsample is not None:
+            identity = self.downsample(x)
+
+        out += identity
+        out = self.relu(out)
+
+        return out
+
+
+class Bottleneck(nn.Module):
+    # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2)
+    # while original implementation places the stride at the first 1x1 convolution(self.conv1)
+    # according to "Deep residual learning for image recognition" https://arxiv.org/abs/1512.03385.
+    # This variant is also known as ResNet V1.5 and improves accuracy according to
+    # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.
+
+    expansion: int = 4
+
+    def __init__(
+        self,
+        inplanes: int,
+        planes: int,
+        stride: int = 1,
+        downsample: Optional[nn.Module] = None,
+        groups: int = 1,
+        base_width: int = 64,
+        dilation: int = 1,
+        norm_layer: Optional[Callable[..., nn.Module]] = None
+    ) -> None:
+        super(Bottleneck, self).__init__()
+        if norm_layer is None:
+            norm_layer = nn.BatchNorm1d
+        width = int(planes * (base_width / 64.)) * groups
+        # Both self.conv2 and self.downsample layers downsample the input when stride != 1
+        self.conv1 = conv1x1(inplanes, width)
+        self.bn1 = norm_layer(width)
+        self.conv2 = conv3x3(width, width, stride, groups, dilation)
+        self.bn2 = norm_layer(width)
+        self.conv3 = conv1x1(width, planes * self.expansion)
+        self.bn3 = norm_layer(planes * self.expansion)
+        self.relu = nn.ReLU(inplace=True)
+        self.downsample = downsample
+        self.stride = stride
+
+    def forward(self, x: Tensor) -> Tensor:
+        identity = x
+
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = self.relu(out)
+
+        out = self.conv2(out)
+        out = self.bn2(out)
+        out = self.relu(out)
+
+        out = self.conv3(out)
+        out = self.bn3(out)
+
+        if self.downsample is not None:
+            identity = self.downsample(x)
+
+        out += identity
+        out = self.relu(out)
+
+        return out
+
+
+class ResNet(nn.Module):
+
+    def __init__(
+        self,
+        block: Type[Union[BasicBlock, Bottleneck]],
+        layers: List[int],
+        num_classes: int = 1000,
+        zero_init_residual: bool = False,
+        groups: int = 1,
+        width_per_group: int = 64,
+        replace_stride_with_dilation: Optional[List[bool]] = None,
+        norm_layer: Optional[Callable[..., nn.Module]] = None
+    ) -> None:
+        super(ResNet, self).__init__()
+        if norm_layer is None:
+            norm_layer = nn.BatchNorm1d
+        self._norm_layer = norm_layer
+
+        self.inplanes = 64
+        self.dilation = 1
+        if replace_stride_with_dilation is None:
+            # each element in the tuple indicates if we should replace
+            # the 2x2 stride with a dilated convolution instead
+            replace_stride_with_dilation = [False, False, False]
+        if len(replace_stride_with_dilation) != 3:
+            raise ValueError("replace_stride_with_dilation should be None "
+                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
+        self.groups = groups
+        self.base_width = width_per_group
+        self.conv1 = nn.Conv1d(1, self.inplanes, kernel_size=7, stride=4, padding=3,
+                               bias=False)
+        self.bn1 = norm_layer(self.inplanes)
+        self.relu = nn.ReLU(inplace=True)
+        self.maxpool = nn.MaxPool1d(kernel_size=5, stride=4, padding=2)
+        self.layer1 = self._make_layer(block, 64, layers[0])
+        self.layer2 = self._make_layer(block, 128, layers[1], stride=4,
+                                       dilate=replace_stride_with_dilation[0])
+        self.layer3 = self._make_layer(block, 256, layers[2], stride=4,
+                                       dilate=replace_stride_with_dilation[1])
+        self.layer4 = self._make_layer(block, 512, layers[3], stride=4,
+                                       dilate=replace_stride_with_dilation[2])
+        self.avgpool = nn.AdaptiveAvgPool1d(1)
+        self.fc = nn.Linear(512 * block.expansion, num_classes)
+
+        for m in self.modules():
+            if isinstance(m, nn.Conv1d):
+                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+            elif isinstance(m, (nn.BatchNorm1d, nn.GroupNorm)):
+                nn.init.constant_(m.weight, 1)
+                nn.init.constant_(m.bias, 0)
+
+        # Zero-initialize the last BN in each residual branch,
+        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
+        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
+        if zero_init_residual:
+            for m in self.modules():
+                if isinstance(m, Bottleneck):
+                    nn.init.constant_(m.bn3.weight, 0)  # type: ignore[arg-type]
+                elif isinstance(m, BasicBlock):
+                    nn.init.constant_(m.bn2.weight, 0)  # type: ignore[arg-type]
+
+    def _make_layer(self, block: Type[Union[BasicBlock, Bottleneck]], planes: int, blocks: int,
+                    stride: int = 1, dilate: bool = False) -> nn.Sequential:
+        norm_layer = self._norm_layer
+        downsample = None
+        previous_dilation = self.dilation
+        if dilate:
+            self.dilation *= stride
+            stride = 1
+        if stride != 1 or self.inplanes != planes * block.expansion:
+            downsample = nn.Sequential(
+                conv1x1(self.inplanes, planes * block.expansion, stride),
+                norm_layer(planes * block.expansion),
+            )
+
+        layers = []
+        layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
+                            self.base_width, previous_dilation, norm_layer))
+        self.inplanes = planes * block.expansion
+        for _ in range(1, blocks):
+            layers.append(block(self.inplanes, planes, groups=self.groups,
+                                base_width=self.base_width, dilation=self.dilation,
+                                norm_layer=norm_layer))
+
+        return nn.Sequential(*layers)
+
+    def _forward_impl(self, x: Tensor) -> Tensor:
+        # See note [TorchScript super()]
+        x = self.conv1(x)
+        x = self.bn1(x)
+        x = self.relu(x)
+        x = self.maxpool(x)
+
+        x = self.layer1(x)
+        x = self.layer2(x)
+        x = self.layer3(x)
+        x = self.layer4(x)
+
+        x = self.avgpool(x)
+        x = torch.flatten(x, 1)
+        x = self.fc(x)
+
+        return x
+
+    def forward(self, x: Tensor) -> Tensor:
+        return self._forward_impl(x)
+
+
+def _resnet(
+    arch: str,
+    block: Type[Union[BasicBlock, Bottleneck]],
+    layers: List[int],
+    pretrained: bool,
+    progress: bool,
+    **kwargs: Any
+) -> ResNet:
+    # 'pretrained' and 'progress' are ignored: no weights exist for this 1D variant.
+    return ResNet(block, layers, **kwargs)
+
+
+def resnet18(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
+    r"""ResNet-18 model from
+    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_.
+
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+        progress (bool): If True, displays a progress bar of the download to stderr
+    """
+    return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress,
+                   **kwargs)
+
+
+def resnet34(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
+    r"""ResNet-34 model from
+    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_.
+
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+        progress (bool): If True, displays a progress bar of the download to stderr
+    """
+    return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress,
+                   **kwargs)
+
+
+def resnet50(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
+    r"""ResNet-50 model from
+    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_.
+
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+        progress (bool): If True, displays a progress bar of the download to stderr
+    """
+    return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress,
+                   **kwargs)
+
+
+def resnet101(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
+    r"""ResNet-101 model from
+    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_.
+
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+        progress (bool): If True, displays a progress bar of the download to stderr
+    """
+    return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress,
+                   **kwargs)
+
+
+def resnet152(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
+    r"""ResNet-152 model from
+    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_.
+
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+        progress (bool): If True, displays a progress bar of the download to stderr
+    """
+    return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress,
+                   **kwargs)
+
+
+def resnext50_32x4d(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
+    r"""ResNeXt-50 32x4d model from
+    `"Aggregated Residual Transformations for Deep Neural Networks" <https://arxiv.org/abs/1611.05431>`_.
+
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+        progress (bool): If True, displays a progress bar of the download to stderr
+    """
+    kwargs['groups'] = 32
+    kwargs['width_per_group'] = 4
+    return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3],
+                   pretrained, progress, **kwargs)
+
+
+def resnext101_32x8d(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
+    r"""ResNeXt-101 32x8d model from
+    `"Aggregated Residual Transformations for Deep Neural Networks" <https://arxiv.org/abs/1611.05431>`_.
+
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+        progress (bool): If True, displays a progress bar of the download to stderr
+    """
+    kwargs['groups'] = 32
+    kwargs['width_per_group'] = 8
+    return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3],
+                   pretrained, progress, **kwargs)
+
+
+def wide_resnet50_2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
+    r"""Wide ResNet-50-2 model from
+    `"Wide Residual Networks" <https://arxiv.org/abs/1605.07146>`_.
+
+    The model is the same as ResNet except for the bottleneck number of channels
+    which is twice as large in every block. The number of channels in outer 1x1
+    convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
+    channels, and in Wide ResNet-50-2 has 2048-1024-2048.
+
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+        progress (bool): If True, displays a progress bar of the download to stderr
+    """
+    kwargs['width_per_group'] = 64 * 2
+    return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3],
+                   pretrained, progress, **kwargs)
+
+
+def wide_resnet101_2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
+    r"""Wide ResNet-101-2 model from
+    `"Wide Residual Networks" <https://arxiv.org/abs/1605.07146>`_.
+
+    The model is the same as ResNet except for the bottleneck number of channels
+    which is twice as large in every block. The number of channels in outer 1x1
+    convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
+    channels, and in Wide ResNet-50-2 has 2048-1024-2048.
+
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+        progress (bool): If True, displays a progress bar of the download to stderr
+    """
+    kwargs['width_per_group'] = 64 * 2
+    return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3],
+                   pretrained, progress, **kwargs)
+
+
+@register_model
+def register_audio_resnet(opt_net, opt):
+    type = opt_net['type']
+    fn = globals()[type]
+    return fn(**opt_get(opt_net, ['kwargs'], {}))
+
+
+if __name__ == '__main__':
+    m = resnet34()
+    o = m(torch.randn((1,1,48000)))
+    print(o.shape)
\ No newline at end of file
diff --git a/codes/train.py b/codes/train.py
index da87349b..03513ad2 100644
--- a/codes/train.py
+++ b/codes/train.py
@@ -300,7 +300,7 @@ class Trainer:
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_gpt_tts_lj.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_byol_audio_clips.yml')
    parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
     parser.add_argument('--local_rank', type=int, default=0)
     args = parser.parse_args()
diff --git a/codes/trainer/lr_scheduler.py b/codes/trainer/lr_scheduler.py
index c58e294f..53051a2c 100644
--- a/codes/trainer/lr_scheduler.py
+++ b/codes/trainer/lr_scheduler.py
@@ -4,6 +4,8 @@ from collections import defaultdict
 import torch
 from torch.optim.lr_scheduler import _LRScheduler
 
+from utils.util import opt_get
+
 
 def get_scheduler_for_name(name, optimizers, scheduler_opt):
     schedulers = []
@@ -19,7 +21,7 @@ def get_scheduler_for_name(name, optimizers, scheduler_opt):
                                   gamma=scheduler_opt['lr_gamma'],
                                   clear_state=scheduler_opt['clear_state'],
                                   force_lr=scheduler_opt['force_lr'],
-                                  warmup_steps=scheduler_opt['warmup_steps'])
+                                  warmup_steps=opt_get(scheduler_opt, ['warmup_steps'], 0))
     elif name == 'ProgressiveMultiStepLR':
         sched = ProgressiveMultiStepLR(o, scheduler_opt['gen_lr_steps'], scheduler_opt['progressive_starts'],
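
For reference, a minimal sketch of how the new wavfile_clips dataset can be driven directly, mirroring the __main__ block in wavfile_dataset.py. The corpus path below is a placeholder; the dataset assumes 24 kHz wav files and cuts two random 2-second (48000-sample) windows from each file:

    # Sketch only: pull one contrastive pair straight from the dataset.
    from data.audio.wavfile_dataset import WavfileDataset

    opt = {'path': '/data/audio/my_wav_corpus'}  # hypothetical corpus location
    ds = WavfileDataset(opt)
    item = ds[0]
    clip1, clip2 = item['clip1'], item['clip2']
    # Each clip carries a channel dim: (1, 48000) = 2 seconds at 24 kHz.
    assert clip1.shape == (1, 48000) and clip2.shape == (1, 48000)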
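The encoder side registers through register_audio_resnet, which simply looks a constructor up by the 'type' key in this module's globals and forwards opt_net['kwargs']. A sketch of feeding a batch of clips through it; the 512-dim embedding size is an arbitrary choice for illustration, not something this patch fixes:

    # Sketch only: instantiate a 1D resnet the same way the registry would.
    import torch
    from models import audio_resnet

    opt_net = {'type': 'resnet50', 'kwargs': {'num_classes': 512}}
    model = getattr(audio_resnet, opt_net['type'])(**opt_net['kwargs'])
    x = torch.randn(16, 1, 48000)  # a batch of 2-second clips at 24 kHz
    emb = model(x)                 # (16, 512); conv1, maxpool and layers 2-4
                                   # each stride by 4, so total stride is 4**5 = 1024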
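Note that the patch supplies only the paired clips and the encoder; the training objective lives in the (not included) train_byol_audio_clips.yml options file and, going by the filename, is presumably BYOL. Purely as an illustration of how such pairs feed a contrastive objective, here is a generic NT-Xent/InfoNCE step; none of this code exists in the repository:

    # Sketch only: a generic InfoNCE/NT-Xent loss over paired clip embeddings.
    import torch
    import torch.nn.functional as F

    def info_nce(z1, z2, temperature=0.1):
        z1, z2 = F.normalize(z1, dim=1), F.normalize(z2, dim=1)
        logits = z1 @ z2.t() / temperature   # (B, B) cosine similarities
        labels = torch.arange(z1.shape[0])   # positives sit on the diagonal
        return F.cross_entropy(logits, labels)

    # With the loader and model from the sketches above:
    # loss = info_nce(model(batch['clip1']), model(batch['clip2']))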