diff --git a/codes/data/audio/unsupervised_audio_dataset.py b/codes/data/audio/unsupervised_audio_dataset.py
index 8ffde8bb..71adad15 100644
--- a/codes/data/audio/unsupervised_audio_dataset.py
+++ b/codes/data/audio/unsupervised_audio_dataset.py
@@ -98,7 +98,8 @@ class UnsupervisedAudioDataset(torch.utils.data.Dataset):
             for exc in opt['exclusions']:
                 with open(exc, 'r') as f:
                     exclusions.extend(f.read().splitlines())
-        self.audiopaths = load_paths_from_cache(path, cache_path, exclusions)
+        ew = opt_get(opt, ['endswith'])
+        self.audiopaths = load_paths_from_cache(path, cache_path, exclusions, ew)
 
         # Parse options
         self.sampling_rate = opt_get(opt, ['sampling_rate'], 22050)
diff --git a/codes/data/util.py b/codes/data/util.py
index 5b120cbb..03e6ad9c 100644
--- a/codes/data/util.py
+++ b/codes/data/util.py
@@ -578,7 +578,7 @@ def imresize_np(img, scale, antialiasing=True):
     return out_2.numpy()
 
 
-def load_paths_from_cache(paths, cache_path, exclusion_list=[]):
+def load_paths_from_cache(paths, cache_path, exclusion_list=[], endswith=None):
     if not isinstance(paths, list):
         paths = [paths]
     if os.path.exists(cache_path):
@@ -595,6 +595,10 @@ def load_paths_from_cache(paths, cache_path, exclusion_list=[]):
             exclusion_set = set(exclusion_list)
             output = list(master_set - exclusion_set)
             print(f"Excluded {before-len(output)} files.")
+        if endswith is not None:
+            before = len(output)
+            output = list(filter(lambda p: not p.endswith(endswith), output))
+            print(f"Excluded {before-len(output)} files with the endswith mask, for a total of {len(output)} files.")
         print("Done.")
         torch.save(output, cache_path)
     return output
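
Note: the new `endswith` option is an exclusion mask, not an inclusion filter — any path that ends with the given suffix is dropped from the cache. A minimal sketch of the behavior (toy paths, not from the repo):

    paths = ['a/clip.wav', 'b/clip.wav', 'b/clip.mp3']
    kept = list(filter(lambda p: not p.endswith('.mp3'), paths))
    assert kept == ['a/clip.wav', 'b/clip.wav']

diff --git a/codes/models/audio/mel2vec.py b/codes/models/audio/mel2vec.py
index 95d7997e..cd984199 100644
--- a/codes/models/audio/mel2vec.py
+++ b/codes/models/audio/mel2vec.py
@@ -1,3 +1,5 @@
+import copy
+import functools
 import math
 from typing import Optional, Tuple
 
@@ -5,8 +7,12 @@ import numpy as np
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from transformers.models.wav2vec2.modeling_wav2vec2 import _compute_mask_indices, _sample_negative_indices
 from transformers.deepspeed import is_deepspeed_zero3_enabled
 
+from trainer.networks import register_model
+from utils.util import checkpoint
+
 
 class Mel2Vec2FeatureProjection(nn.Module):
     def __init__(self, inner_dim, dropout):
@@ -234,6 +240,29 @@ class Wav2Vec2SamePadLayer(nn.Module):
         return hidden_states
 
 
+from torch.nn.utils.weight_norm import WeightNorm
+def __deepcopy__(self, memo):
+    # save and delete all weightnorm weights on self
+    weights = {}
+    for hook in self._forward_pre_hooks.values():
+        if isinstance(hook, WeightNorm):
+            weights[hook.name] = getattr(self, hook.name)
+            delattr(self, hook.name)
+    # remove this deepcopy method, restoring the object's original one if necessary
+    __deepcopy__ = self.__deepcopy__
+    if self.orig_deepcopy:
+        self.__deepcopy__ = self.orig_deepcopy
+    else:
+        del self.__deepcopy__
+    # actually do the copy
+    result = copy.deepcopy(self)
+    # restore weights and method on self
+    for name, value in weights.items():
+        setattr(self, name, value)
+    self.__deepcopy__ = __deepcopy__
+    return result
+
+
 class Wav2Vec2PositionalConvEmbedding(nn.Module):
     def __init__(self, hidden_size, num_conv_pos_embeddings=128, num_conv_pos_embedding_groups=16):
         super().__init__()
@@ -244,6 +273,9 @@
             padding=num_conv_pos_embeddings // 2,
             groups=num_conv_pos_embedding_groups,
         )
+        # Fix weightnorm deepcopy; see: https://github.com/pytorch/pytorch/issues/28594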
+        self.conv.orig_deepcopy = getattr(Wav2Vec2PositionalConvEmbedding, '__deepcopy__', None)
+        self.conv.__deepcopy__ = __deepcopy__.__get__(self.conv, self.conv.__class__)
 
         if is_deepspeed_zero3_enabled():
             import deepspeed
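
Context for the patch above: deep-copying a weight-normed module fails in stock PyTorch because the computed `weight` attribute is a non-leaf tensor maintained by a forward pre-hook (pytorch#28594). A minimal reproduction, independent of this repo:

    import copy
    import torch.nn as nn
    from torch.nn.utils import weight_norm

    conv = weight_norm(nn.Conv1d(16, 16, 3), name='weight')
    copy.deepcopy(conv)  # raises RuntimeError: non-leaf tensors don't support deepcopy

The bound `__deepcopy__` installed on `self.conv` works around this by detaching the computed `weight` before copying and restoring it afterwards.

@@ -276,7 +308,6 @@
         self.layer_norm = nn.LayerNorm(hidden_size, eps=1e-5)
         self.dropout = nn.Dropout(dropout)
         self.layers = nn.ModuleList([Wav2Vec2EncoderLayer(hidden_size, dropout) for _ in range(num_layers)])
-        self.gradient_checkpointing = False
         self.layerdrop = layerdrop
 
     def forward(
@@ -314,24 +345,8 @@
             skip_the_layer = True if self.training and (dropout_probability < self.layerdrop) else False
             if not skip_the_layer or deepspeed_zero3_is_enabled:
-                # under deepspeed zero3 all gpus must run in sync
-                if self.gradient_checkpointing and self.training:
-                    # create gradient checkpointing function
-                    def create_custom_forward(module):
-                        def custom_forward(*inputs):
-                            return module(*inputs, output_attentions)
-
-                        return custom_forward
-
-                    layer_outputs = torch.utils.checkpoint.checkpoint(
-                        create_custom_forward(layer),
-                        hidden_states,
-                        attention_mask,
-                    )
-                else:
-                    layer_outputs = layer(
-                        hidden_states, attention_mask=attention_mask, output_attentions=output_attentions
-                    )
+                layer_fn = functools.partial(layer, attention_mask=attention_mask)
+                layer_outputs = checkpoint(layer_fn, hidden_states)
                 hidden_states = layer_outputs[0]
 
         return hidden_states
 
@@ -345,17 +360,21 @@ class Mel2Vec(nn.Module):
                  dropout=.1,
                  layerdrop=0,
                  mask_time_prob=.65,
-                 mask_time_length=10):
+                 mask_time_length=10,
+                 ):
+        super().__init__()
         self.input_blocks = nn.Sequential(nn.Conv1d(mel_input_channels, inner_dim//2, kernel_size=5, padding=2, stride=2),
-                                          nn.GroupNorm(num_groups=8, num_channels=inner_dim, affine=True),
+                                          nn.GroupNorm(num_groups=8, num_channels=inner_dim//2, affine=True),
                                           nn.SiLU(),
                                           nn.Conv1d(inner_dim//2, inner_dim, kernel_size=3, padding=1, stride=2),
                                           nn.GroupNorm(num_groups=8, num_channels=inner_dim, affine=True),
                                           nn.SiLU(),
                                           )
-        self.projector = Wav2Vec2FeatureProjection(inner_dim, dropout)
+        self.projector = Mel2Vec2FeatureProjection(inner_dim, dropout)
         self.masked_spec_embed = nn.Parameter(torch.rand(inner_dim,))
         self.encoder = Wav2Vec2Encoder(inner_dim, dropout, layers, layerdrop)
+        self.mask_time_prob = mask_time_prob
+        self.mask_time_length = mask_time_length
         self.apply(self.init)
 
     def init(self, module):
@@ -368,12 +387,12 @@
                 std=2 * math.sqrt(1 / (module.conv.kernel_size[0] * module.conv.in_channels)),
             )
             nn.init.constant_(module.conv.bias, 0)
-        elif isinstance(module, Wav2Vec2FeatureProjection):
+        elif isinstance(module, Mel2Vec2FeatureProjection):
             k = math.sqrt(1 / module.projection.in_features)
             nn.init.uniform_(module.projection.weight, a=-k, b=k)
             nn.init.uniform_(module.projection.bias, a=-k, b=k)
         elif isinstance(module, nn.Linear):
-            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+            module.weight.data.normal_(mean=0.0, std=.02)
             if module.bias is not None:
                 module.bias.data.zero_()
         elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)):
@@ -396,48 +415,35 @@
         [SpecAugment](https://arxiv.org/abs/1904.08779).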
""" - # `config.apply_spec_augment` can set masking to False - if not getattr(self.config, "apply_spec_augment", True): - return hidden_states - # generate indices & apply SpecAugment along time axis batch_size, sequence_length, hidden_size = hidden_states.size() if mask_time_indices is not None: # apply SpecAugment along time axis with given mask_time_indices hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype) - elif self.config.mask_time_prob > 0 and self.training: + elif self.mask_time_prob > 0 and self.training: mask_time_indices = _compute_mask_indices( (batch_size, sequence_length), - mask_prob=self.config.mask_time_prob, - mask_length=self.config.mask_time_length, + mask_prob=self.mask_time_prob, + mask_length=self.mask_time_length, attention_mask=attention_mask, - min_masks=self.config.mask_time_min_masks, + min_masks=self.mask_time_min_masks, ) mask_time_indices = torch.tensor(mask_time_indices, device=hidden_states.device, dtype=torch.bool) hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype) - if self.config.mask_feature_prob > 0 and self.training: - # generate indices & apply SpecAugment along feature axis - mask_feature_indices = _compute_mask_indices( - (batch_size, hidden_size), - mask_prob=self.config.mask_feature_prob, - mask_length=self.config.mask_feature_length, - min_masks=self.config.mask_feature_min_masks, - ) - mask_feature_indices = torch.tensor(mask_feature_indices, device=hidden_states.device, dtype=torch.bool) - mask_feature_indices = mask_feature_indices[:, None].expand(-1, sequence_length, -1) - hidden_states[mask_feature_indices] = 0 - return hidden_states - def forward(self, mel): + def forward(self, mel, mask_time_indices=None, return_projections=False): proj = self.input_blocks(mel).permute(0,2,1) proj, _ = self.projector(proj) # Mask projections h = self.apply_masking(proj, mask_time_indices) h = self.encoder(h) + + if return_projections: + return h, proj return h @@ -452,11 +458,12 @@ class Wav2Vec2GumbelVectorQuantizer(nn.Module): self.codevector_dim = codevector_dim self.num_groups = num_codevector_groups self.num_vars = num_codevectors_per_group + self.num_codevectors = num_codevector_groups * num_codevectors_per_group if codevector_dim % self.num_groups != 0: raise ValueError( - f"`config.codevector_dim {config.codevector_dim} must be divisible " - f"by `config.num_codevector_groups` {self.num_groups} for concatenation" + f"`codevector_dim {codevector_dim} must be divisible " + f"by `num_codevector_groups` {num_codevector_groups} for concatenation" ) # storage for codebook variables (codewords) @@ -527,16 +534,112 @@ class Wav2Vec2GumbelVectorQuantizer(nn.Module): class ContrastiveTrainingWrapper(nn.Module): - def __init__(self, **kwargs): + def __init__(self, inner_dim=1024, dropout=.1, mask_time_prob=.65, mask_time_length=4, num_negatives=100, **kwargs): super().__init__() - self.m2v = Mel2Vec(**kwargs) - self.dropout_features = nn.Dropout(kwargs['dropout']) + self.m2v = Mel2Vec(inner_dim=inner_dim, dropout=dropout, mask_time_prob=mask_time_prob, + mask_time_length=mask_time_length, **kwargs) + self.dropout_features = nn.Dropout(dropout) + self.num_negatives = num_negatives + self.mask_time_prob = mask_time_prob + self.mask_time_length = mask_time_length - self.quantizer = Wav2Vec2GumbelVectorQuantizer(kwargs['inner_dim']) + self.quantizer = Wav2Vec2GumbelVectorQuantizer(inner_dim) # make sure that project_hid & project_q are initialized like normal linear layers - self.project_hid = 
-        self.project_hid = nn.Linear(kwargs['inner_dim'], self.quantizer.codevector_dim)
+        self.project_hid = nn.Linear(inner_dim, self.quantizer.codevector_dim)
         self.project_q = nn.Linear(self.quantizer.codevector_dim, self.quantizer.codevector_dim)
 
+    @staticmethod
+    def compute_contrastive_logits(
+        target_features: torch.FloatTensor,
+        negative_features: torch.FloatTensor,
+        predicted_features: torch.FloatTensor,
+        temperature: float = 0.1,
+    ):
+        """
+        Compute logits for contrastive loss using cosine similarity as the distance measure between
+        `[positive_feature, negative_features]` and `[predicted_features]`. Additionally, temperature can be applied.
+        """
+        target_features = torch.cat([target_features, negative_features], dim=0)
+
+        logits = torch.cosine_similarity(predicted_features.float(), target_features.float(), dim=-1).type_as(
+            target_features
+        )
+
+        # apply temperature
+        logits = logits / temperature
+        return logits
+
     def forward(self, mel):
-        pass
\ No newline at end of file
+        features_shape = (mel.shape[0], mel.shape[-1]//4)
+        mask_time_indices = _compute_mask_indices(features_shape, self.mask_time_prob, self.mask_time_length)
+        sampled_negative_indices = torch.tensor(_sample_negative_indices(features_shape, self.num_negatives, mask_time_indices=mask_time_indices), device=mel.device)
+        mask_time_indices = torch.tensor(mask_time_indices, device=mel.device)
+
+        outputs, proj = self.m2v(mel, mask_time_indices, return_projections=True)
+
+        # 1. project all transformed features (including masked) to final vq dim
+        transformer_features = self.project_hid(outputs)
+
+        # 2. quantize all (unmasked) extracted features and project to final vq dim
+        extract_features = self.dropout_features(proj)
+
+        quantized_features, codevector_perplexity = self.quantizer(
+            extract_features, mask_time_indices=mask_time_indices
+        )
+        quantized_features = self.project_q(quantized_features)
+        batch_size, sequence_length, hidden_size = quantized_features.shape
+
+        # 3. sample K negatives (distractors) quantized states for contrastive loss
+        #    if attention_mask is passed, make sure that padded feature vectors cannot be sampled
+        #    sample negative quantized vectors BTC => (BxT)C
+        negative_quantized_features = quantized_features.view(-1, hidden_size)[
+            sampled_negative_indices.long().view(-1)
+        ]
+        negative_quantized_features = negative_quantized_features.view(
+            batch_size, sequence_length, -1, hidden_size
+        ).permute(2, 0, 1, 3)
+
+        # 4. compute logits, corresponding to `logits = sim(c_t, [q_t, \sim{q}_t]) / \kappa`
+        #    of equation (3) in https://arxiv.org/pdf/2006.11477.pdf
+        logits = self.compute_contrastive_logits(
+            quantized_features[None, :],
+            negative_quantized_features,
+            transformer_features,
+            .1,
+        )
+
+        # 5. if a negative vector is identical to the positive (i.e. when codebook utilization is low),
+        #    its cosine similarity will be masked
+        neg_is_pos = (quantized_features == negative_quantized_features).all(-1)
+
+        if neg_is_pos.any():
+            logits[1:][neg_is_pos] = float("-inf")
+
+        # 6. compute contrastive loss \mathbf{L}_m = cross_entropy(logits) =
+        #    -log(exp(sim(c_t, q_t)/\kappa) / \sum_{\sim{q}} exp(sim(c_t, \sim{q})/\kappa))
+        logits = logits.transpose(0, 2).reshape(-1, logits.size(0))
+        target = ((1 - mask_time_indices.long()) * -100).transpose(0, 1).flatten()
+
+        contrastive_loss = nn.functional.cross_entropy(logits.float(), target, reduction="sum")
+
+        # 7. compute diversity loss: \mathbf{L}_d
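+        num_codevectors = self.quantizer.num_codevectors
+        diversity_loss = ((num_codevectors - codevector_perplexity) / num_codevectors) * mask_time_indices.sum()
+
+        return contrastive_loss, diversity_loss
+
+
+@register_model
+def register_mel2vec_pretraining(opt_net, opt):
+    return ContrastiveTrainingWrapper(**opt_net['kwargs'])
+
+
+@register_model
+def register_mel2vec(opt_net, opt):
+    return Mel2Vec(**opt_net['kwargs'])
+
+
+if __name__ == '__main__':
+    model = ContrastiveTrainingWrapper()
+    mel = torch.randn((2,256,400))
+    print(model(mel))
\ No newline at end of file

Note on the returned pair: the wrapper hands back `(contrastive_loss, diversity_loss)` without combining them. The wav2vec 2.0 paper combines the two as L = L_m + α·L_d with α = 0.1; a hypothetical training step under that assumption (the weighting is not fixed anywhere in this diff):

    model = ContrastiveTrainingWrapper()
    mel = torch.randn((2, 256, 400))
    contrastive_loss, diversity_loss = model(mel)
    loss = contrastive_loss + 0.1 * diversity_loss  # alpha = 0.1 per the paper
    loss.backward()

The `target` construction above relies on cross_entropy's default ignore_index: masked positions get class 0 (the positive is always row 0 of the logits) and unmasked positions get -100, so only masked timesteps contribute to the loss.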
diff --git a/codes/train.py b/codes/train.py
index 14039ec0..721f6a09 100644
--- a/codes/train.py
+++ b/codes/train.py
@@ -327,7 +327,7 @@ class Trainer:
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_music_waveform_gen3.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_music_mel2vec.yml')
    parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
    args = parser.parse_args()
    opt = option.parse(args.opt, is_train=True)
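
Note: with the two `@register_model` hooks, the trainer builds these networks from the options file purely via `opt_net['kwargs']`. The referenced ../options/train_music_mel2vec.yml is not part of this diff; construction is equivalent to the following (illustrative values only):

    opt_net = {'kwargs': {'inner_dim': 1024, 'dropout': 0.1}}
    model = ContrastiveTrainingWrapper(**opt_net['kwargs'])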