forked from mrq/DL-Art-School

Add voice2voice clip model

parent a9ee5b624f
commit 07c2b9907c
@@ -69,6 +69,7 @@ class UnsupervisedAudioDataset(torch.utils.data.Dataset):
         if self.pad_to is not None:
             self.pad_to *= self.sampling_rate
         self.pad_to = opt_get(opt, ['pad_to_samples'], self.pad_to)
+        self.min_length = opt_get(opt, ['min_length'], 0)
 
         # "Resampled clip" is audio data pulled from the basis of "clip" but with randomly different bounds. There are no
         # guarantees that "clip_resampled" is different from "clip": in fact, if "clip" is less than pad_to_seconds/samples,
@@ -79,9 +80,12 @@ class UnsupervisedAudioDataset(torch.utils.data.Dataset):
         self.extra_sample_len = opt_get(opt, ['extra_sample_length'], 2)
         self.extra_sample_len *= self.sampling_rate
 
+        self.debug_loading_failures = opt_get(opt, ['debug_loading_failures'], True)
+
     def get_audio_for_index(self, index):
         audiopath = self.audiopaths[index]
         audio = load_audio(audiopath, self.sampling_rate)
+        assert audio.shape[1] > self.min_length
         return audio, audiopath
 
     def get_related_audio_for_index(self, index):
@@ -121,7 +125,8 @@ class UnsupervisedAudioDataset(torch.utils.data.Dataset):
             audio_norm, filename = self.get_audio_for_index(index)
             alt_files, actual_samples = self.get_related_audio_for_index(index)
         except:
-            print(f"Error loading audio for file {self.audiopaths[index]} {sys.exc_info()}")
+            if self.debug_loading_failures:
+                print(f"Error loading audio for file {self.audiopaths[index]} {sys.exc_info()}")
             return self[index+1]
 
         # When generating resampled clips, skew is a bias that tries to spread them out from each other, reducing their
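For reference, both new dataset behaviors are driven by plain option keys read through opt_get. A minimal sketch of the two keys this commit introduces (other dataset options are unchanged and omitted; the values shown are hypothetical):

dataset_opt = {
    # ... existing UnsupervisedAudioDataset options ...
    'min_length': 2048,                # in samples; get_audio_for_index now asserts audio.shape[1] > this (defaults to 0)
    'debug_loading_failures': False,   # set False to silence the per-file error print on load failures (defaults to True)
}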
codes/models/gpt_voice/voice_voice_clip.py (new file, 83 lines)
@@ -0,0 +1,83 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from einops import rearrange
+from torch import einsum
+
+from models.gpt_voice.mini_encoder import AudioMiniEncoder
+from models.lucidrains.dalle.transformer import Transformer
+from trainer.networks import register_model
+from utils.util import opt_get
+
+
+def exists(val):
+    return val is not None
+
+
+def masked_mean(t, mask, dim=1):
+    t = t.masked_fill(~mask[:, :, None], 0.)
+    return t.sum(dim = 1) / mask.sum(dim = 1)[..., None]
+
+
+class VoiceCLIP(nn.Module):
+    """
+    CLIP model modified to produce similarity scores from different views of the same audio clip.
+    """
+
+    def __init__(
+            self,
+            encoder_output=512,
+            dim_latent=512,
+            speech_max_seq_len=250,
+            mel_compression_ratio=256,
+            pretrained_encoder_dict_path=None
+    ):
+        super().__init__()
+        self.encoder = AudioMiniEncoder(80, encoder_output)
+        if pretrained_encoder_dict_path is not None:
+            self.encoder.load_state_dict(torch.load(pretrained_encoder_dict_path))
+        self.to_latent = nn.Linear(encoder_output, dim_latent, bias=False)
+        self.temperature = nn.Parameter(torch.tensor(1.))
+        self.mel_compression_ratio = mel_compression_ratio
+
+    def forward(
+        self,
+        speech_mels,
+        speech_lengths,
+        return_loss=True
+    ):
+        half_length = min(speech_mels.shape[-1], torch.min(speech_lengths).item() // self.mel_compression_ratio) // 2
+        half_length = (half_length // 4) * 4  # Must be a multiple of 4.
+
+        first_half = speech_mels[:, :, :half_length]
+        second_half = speech_mels[:, :, half_length:half_length*2]
+
+        first_emb = self.encoder(first_half)
+        first_latents = self.to_latent(first_emb)
+        second_emb = self.encoder(second_half)
+        second_latents = self.to_latent(second_emb)
+
+        first_latents, second_latents = map(lambda t: F.normalize(t, p=2, dim=-1), (first_latents, second_latents))
+
+        temp = self.temperature.exp()
+
+        if not return_loss:
+            sim = einsum('n d, n d -> n', first_latents, second_latents) * temp
+            return sim
+
+        sim = einsum('i d, j d -> i j', first_latents, second_latents) * temp
+        labels = torch.arange(first_latents.shape[0], device=first_latents.device)
+        loss = (F.cross_entropy(sim, labels) + F.cross_entropy(sim.t(), labels)) / 2
+        return loss
+
+
+@register_model
+def register_voice_to_voice_clip(opt_net, opt):
+    return VoiceCLIP(**opt_get(opt_net, ['kwargs'], {}))
+
+
+if __name__ == '__main__':
+    clip = VoiceCLIP()
+    clip(torch.randn((2,80,200)),
+         torch.randint(0,200*1024,(2,)),
+         return_loss=True)
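Worth noting: the objective above is the standard symmetric CLIP/InfoNCE loss, with the two halves of each clip forming a positive pair and the other clips in the batch serving as negatives. A self-contained sketch of the same computation on random latents (batch size, dimension, and temperature are arbitrary examples, independent of the classes in this repo):

import torch
import torch.nn.functional as F
from torch import einsum

b, d = 4, 512
first = F.normalize(torch.randn(b, d), p=2, dim=-1)    # latents from the first half of each clip
second = F.normalize(torch.randn(b, d), p=2, dim=-1)   # latents from the second half
temp = torch.tensor(1.).exp()                           # learned temperature in VoiceCLIP; fixed here

sim = einsum('i d, j d -> i j', first, second) * temp   # (b, b) similarity matrix
labels = torch.arange(b)                                 # positive pairs sit on the diagonal
loss = (F.cross_entropy(sim, labels) + F.cross_entropy(sim.t(), labels)) / 2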
@@ -257,14 +257,14 @@ class Trainer:
                 import wandb
                 wandb.log(eval_dict)
 
     def do_training(self):
         self.logger.info('Start training from epoch: {:d}, iter: {:d}'.format(self.start_epoch, self.current_step))
         for epoch in range(self.start_epoch, self.total_epochs + 1):
             self.epoch = epoch
             if opt['dist']:
                 self.train_sampler.set_epoch(epoch)
-            tq_ldr = tqdm(self.train_loader) if self.rank == 0 else self.train_loader
+            tq_ldr = tqdm(self.train_loader) if self.rank <= 0 else self.train_loader
 
             _t = time()
             for train_data in tq_ldr:
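The rank check was presumably loosened because, in trainers of this lineage, self.rank is commonly -1 for non-distributed runs and only takes values 0..world_size-1 under DDP, so the old '== 0' test hid the progress bar during single-process training. A minimal sketch of that assumed convention (not shown in this diff):

import torch.distributed as dist

# Assumed convention: rank is -1 in single-process runs, the process index under DDP.
rank = dist.get_rank() if dist.is_available() and dist.is_initialized() else -1
show_progress_bar = rank <= 0   # True for single-process runs and for the DDP primary process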
@@ -286,7 +286,7 @@ class Trainer:
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_gpt_asr_mass_hf2.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_voice_voice_clip.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
     parser.add_argument('--local_rank', type=int, default=0)
     args = parser.parse_args()
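With this new default, training the voice-to-voice CLIP model would presumably be launched as: python train.py -opt ../options/train_voice_voice_clip.yml — assuming the script is run from the directory the relative default path implies, and that the referenced options YAML exists in the repository (it is not part of this diff).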