Add voice2voice clip model

James Betker 2021-12-28 16:18:12 -07:00
parent a9ee5b624f
commit 07c2b9907c
4 changed files with 92 additions and 4 deletions

View File

@@ -69,6 +69,7 @@ class UnsupervisedAudioDataset(torch.utils.data.Dataset):
        if self.pad_to is not None:
            self.pad_to *= self.sampling_rate
        self.pad_to = opt_get(opt, ['pad_to_samples'], self.pad_to)
+        self.min_length = opt_get(opt, ['min_length'], 0)

        # "Resampled clip" is audio data pulled from the basis of "clip" but with randomly different bounds. There are no
        # guarantees that "clip_resampled" is different from "clip": in fact, if "clip" is less than pad_to_seconds/samples,
@@ -79,9 +80,12 @@ class UnsupervisedAudioDataset(torch.utils.data.Dataset):
        self.extra_sample_len = opt_get(opt, ['extra_sample_length'], 2)
        self.extra_sample_len *= self.sampling_rate
+        self.debug_loading_failures = opt_get(opt, ['debug_loading_failures'], True)

    def get_audio_for_index(self, index):
        audiopath = self.audiopaths[index]
        audio = load_audio(audiopath, self.sampling_rate)
+        assert audio.shape[1] > self.min_length
        return audio, audiopath

    def get_related_audio_for_index(self, index):
@@ -121,7 +125,8 @@ class UnsupervisedAudioDataset(torch.utils.data.Dataset):
            audio_norm, filename = self.get_audio_for_index(index)
            alt_files, actual_samples = self.get_related_audio_for_index(index)
        except:
-            print(f"Error loading audio for file {self.audiopaths[index]} {sys.exc_info()}")
+            if self.debug_loading_failures:
+                print(f"Error loading audio for file {self.audiopaths[index]} {sys.exc_info()}")
            return self[index+1]

        # When generating resampled clips, skew is a bias that tries to spread them out from each other, reducing their
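Both dataset additions are driven by options read through opt_get. Below is a minimal sketch of an options dict exercising the new keys; the values and the pad_to_samples entry are illustrative assumptions, not taken from this commit.

# Hypothetical configuration for UnsupervisedAudioDataset; only the keys
# relevant to this commit are shown, with invented values.
opt = {
    'pad_to_samples': 88200,          # pre-existing option, shown above for context
    'min_length': 22050,              # new: get_audio_for_index() asserts audio.shape[1] > min_length (in samples)
    'debug_loading_failures': False,  # new: set False to silence per-file load-error prints
}

Note that a failed load still falls through to self[index+1] either way; the new flag only controls the logging.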

View File

@@ -0,0 +1,83 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from torch import einsum

from models.gpt_voice.mini_encoder import AudioMiniEncoder
from models.lucidrains.dalle.transformer import Transformer
from trainer.networks import register_model
from utils.util import opt_get


def exists(val):
    return val is not None


def masked_mean(t, mask, dim=1):
    t = t.masked_fill(~mask[:, :, None], 0.)
    return t.sum(dim=1) / mask.sum(dim=1)[..., None]


class VoiceCLIP(nn.Module):
    """
    CLIP model modified to produce similarity scores from different views of the same audio clip.
    """
    def __init__(
            self,
            encoder_output=512,
            dim_latent=512,
            speech_max_seq_len=250,
            mel_compression_ratio=256,
            pretrained_encoder_dict_path=None
    ):
        super().__init__()
        self.encoder = AudioMiniEncoder(80, encoder_output)
        if pretrained_encoder_dict_path is not None:
            self.encoder.load_state_dict(torch.load(pretrained_encoder_dict_path))
        self.to_latent = nn.Linear(encoder_output, dim_latent, bias=False)
        self.temperature = nn.Parameter(torch.tensor(1.))
        self.mel_compression_ratio = mel_compression_ratio

    def forward(
        self,
        speech_mels,
        speech_lengths,
        return_loss=True
    ):
        # Split each clip at the midpoint of the shortest valid length in the
        # batch; the two halves become two "views" of the same voice.
        half_length = min(speech_mels.shape[-1], torch.min(speech_lengths).item() // self.mel_compression_ratio) // 2
        half_length = (half_length // 4) * 4  # Must be a multiple of 4.
        first_half = speech_mels[:, :, :half_length]
        second_half = speech_mels[:, :, half_length:half_length*2]

        first_emb = self.encoder(first_half)
        first_latents = self.to_latent(first_emb)
        second_emb = self.encoder(second_half)
        second_latents = self.to_latent(second_emb)

        first_latents, second_latents = map(lambda t: F.normalize(t, p=2, dim=-1), (first_latents, second_latents))

        temp = self.temperature.exp()

        if not return_loss:
            sim = einsum('n d, n d -> n', first_latents, second_latents) * temp
            return sim

        # Symmetric contrastive loss: each first-half latent should be most
        # similar to its own second-half latent (the diagonal of sim).
        sim = einsum('i d, j d -> i j', first_latents, second_latents) * temp
        labels = torch.arange(first_latents.shape[0], device=first_latents.device)
        loss = (F.cross_entropy(sim, labels) + F.cross_entropy(sim.t(), labels)) / 2
        return loss


@register_model
def register_voice_to_voice_clip(opt_net, opt):
    return VoiceCLIP(**opt_get(opt_net, ['kwargs'], {}))


if __name__ == '__main__':
    clip = VoiceCLIP()
    clip(torch.randn((2, 80, 200)),
         torch.randint(0, 200*1024, (2,)),
         return_loss=True)
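The loss above is the standard symmetric CLIP objective: the batch-by-batch similarity matrix between first-half and second-half latents is pushed toward having its largest entries on the diagonal. A tiny self-contained illustration with invented numbers:

import torch
import torch.nn.functional as F

# sim[i, j] scores clip i's first half against clip j's second half.
# These values are made up for illustration.
sim = torch.tensor([[4.0, 0.1],
                    [0.2, 3.5]])
labels = torch.arange(2)  # target: each clip matches itself
loss = (F.cross_entropy(sim, labels) + F.cross_entropy(sim.t(), labels)) / 2
print(loss.item())  # near zero here; it rises as off-diagonal scores grow

The learned temperature is exponentiated before scaling the similarities, which keeps the effective scale positive throughout training.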

View File

@@ -257,14 +257,14 @@ class Trainer:
                import wandb
                wandb.log(eval_dict)

    def do_training(self):
        self.logger.info('Start training from epoch: {:d}, iter: {:d}'.format(self.start_epoch, self.current_step))
        for epoch in range(self.start_epoch, self.total_epochs + 1):
            self.epoch = epoch
            if opt['dist']:
                self.train_sampler.set_epoch(epoch)
-            tq_ldr = tqdm(self.train_loader) if self.rank == 0 else self.train_loader
+            tq_ldr = tqdm(self.train_loader) if self.rank <= 0 else self.train_loader

            _t = time()
            for train_data in tq_ldr:
@@ -286,7 +286,7 @@ class Trainer:
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_gpt_asr_mass_hf2.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_voice_voice_clip.yml')
    parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
    parser.add_argument('--local_rank', type=int, default=0)
    args = parser.parse_args()
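With the new default, launching the trainer with no -opt flag now picks up ../options/train_voice_voice_clip.yml (assuming the script is invoked directly; the YAML recipe itself is not part of this diff). The rank <= 0 change presumably keeps the tqdm progress bar enabled for non-distributed runs, where trainers in this lineage commonly report a rank of -1 rather than 0.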