forked from mrq/DL-Art-School
Add voice2voice clip model
This commit is contained in:
parent
a9ee5b624f
commit
07c2b9907c
|
@ -69,6 +69,7 @@ class UnsupervisedAudioDataset(torch.utils.data.Dataset):
|
|||
if self.pad_to is not None:
|
||||
self.pad_to *= self.sampling_rate
|
||||
self.pad_to = opt_get(opt, ['pad_to_samples'], self.pad_to)
|
||||
self.min_length = opt_get(opt, ['min_length'], 0)
|
||||
|
||||
# "Resampled clip" is audio data pulled from the basis of "clip" but with randomly different bounds. There are no
|
||||
# guarantees that "clip_resampled" is different from "clip": in fact, if "clip" is less than pad_to_seconds/samples,
|
||||
|
@ -79,9 +80,12 @@ class UnsupervisedAudioDataset(torch.utils.data.Dataset):
|
|||
self.extra_sample_len = opt_get(opt, ['extra_sample_length'], 2)
|
||||
self.extra_sample_len *= self.sampling_rate
|
||||
|
||||
self.debug_loading_failures = opt_get(opt, ['debug_loading_failures'], True)
|
||||
|
||||
def get_audio_for_index(self, index):
|
||||
audiopath = self.audiopaths[index]
|
||||
audio = load_audio(audiopath, self.sampling_rate)
|
||||
assert audio.shape[1] > self.min_length
|
||||
return audio, audiopath
|
||||
|
||||
def get_related_audio_for_index(self, index):
|
||||
|
@ -121,7 +125,8 @@ class UnsupervisedAudioDataset(torch.utils.data.Dataset):
|
|||
audio_norm, filename = self.get_audio_for_index(index)
|
||||
alt_files, actual_samples = self.get_related_audio_for_index(index)
|
||||
except:
|
||||
print(f"Error loading audio for file {self.audiopaths[index]} {sys.exc_info()}")
|
||||
if self.debug_loading_failures:
|
||||
print(f"Error loading audio for file {self.audiopaths[index]} {sys.exc_info()}")
|
||||
return self[index+1]
|
||||
|
||||
# When generating resampled clips, skew is a bias that tries to spread them out from each other, reducing their
|
||||
|
|
83
codes/models/gpt_voice/voice_voice_clip.py
Normal file
83
codes/models/gpt_voice/voice_voice_clip.py
Normal file
|
@ -0,0 +1,83 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange
|
||||
from torch import einsum
|
||||
|
||||
from models.gpt_voice.mini_encoder import AudioMiniEncoder
|
||||
from models.lucidrains.dalle.transformer import Transformer
|
||||
from trainer.networks import register_model
|
||||
from utils.util import opt_get
|
||||
|
||||
|
||||
def exists(val):
|
||||
return val is not None
|
||||
|
||||
|
||||
def masked_mean(t, mask, dim=1):
|
||||
t = t.masked_fill(~mask[:, :, None], 0.)
|
||||
return t.sum(dim = 1) / mask.sum(dim = 1)[..., None]
|
||||
|
||||
|
||||
class VoiceCLIP(nn.Module):
|
||||
"""
|
||||
CLIP model modified to produce similarity scores from different views of the same audio clip.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
encoder_output=512,
|
||||
dim_latent=512,
|
||||
speech_max_seq_len=250,
|
||||
mel_compression_ratio=256,
|
||||
pretrained_encoder_dict_path=None
|
||||
):
|
||||
super().__init__()
|
||||
self.encoder = AudioMiniEncoder(80, encoder_output)
|
||||
if pretrained_encoder_dict_path is not None:
|
||||
self.encoder.load_state_dict(torch.load(pretrained_encoder_dict_path))
|
||||
self.to_latent = nn.Linear(encoder_output, dim_latent, bias=False)
|
||||
self.temperature = nn.Parameter(torch.tensor(1.))
|
||||
self.mel_compression_ratio = mel_compression_ratio
|
||||
|
||||
def forward(
|
||||
self,
|
||||
speech_mels,
|
||||
speech_lengths,
|
||||
return_loss=True
|
||||
):
|
||||
half_length = min(speech_mels.shape[-1], torch.min(speech_lengths).item() // self.mel_compression_ratio) // 2
|
||||
half_length = (half_length // 4) * 4 # Must be a multiple of 4.
|
||||
|
||||
first_half = speech_mels[:, :, :half_length]
|
||||
second_half = speech_mels[:, :, half_length:half_length*2]
|
||||
|
||||
first_emb = self.encoder(first_half)
|
||||
first_latents = self.to_latent(first_emb)
|
||||
second_emb = self.encoder(second_half)
|
||||
second_latents = self.to_latent(second_emb)
|
||||
|
||||
first_latents, second_latents = map(lambda t: F.normalize(t, p=2, dim=-1), (first_latents, second_latents))
|
||||
|
||||
temp = self.temperature.exp()
|
||||
|
||||
if not return_loss:
|
||||
sim = einsum('n d, n d -> n', first_latents, second_latents) * temp
|
||||
return sim
|
||||
|
||||
sim = einsum('i d, j d -> i j', first_latents, second_latents) * temp
|
||||
labels = torch.arange(first_latents.shape[0], device=first_latents.device)
|
||||
loss = (F.cross_entropy(sim, labels) + F.cross_entropy(sim.t(), labels)) / 2
|
||||
return loss
|
||||
|
||||
|
||||
@register_model
|
||||
def register_voice_to_voice_clip(opt_net, opt):
|
||||
return VoiceCLIP(**opt_get(opt_net, ['kwargs'], {}))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
clip = VoiceCLIP()
|
||||
clip(torch.randn((2,80,200)),
|
||||
torch.randint(0,200*1024,(2,)),
|
||||
return_loss=True)
|
|
@ -257,14 +257,14 @@ class Trainer:
|
|||
import wandb
|
||||
wandb.log(eval_dict)
|
||||
|
||||
|
||||
def do_training(self):
|
||||
self.logger.info('Start training from epoch: {:d}, iter: {:d}'.format(self.start_epoch, self.current_step))
|
||||
for epoch in range(self.start_epoch, self.total_epochs + 1):
|
||||
self.epoch = epoch
|
||||
if opt['dist']:
|
||||
self.train_sampler.set_epoch(epoch)
|
||||
tq_ldr = tqdm(self.train_loader) if self.rank == 0 else self.train_loader
|
||||
|
||||
tq_ldr = tqdm(self.train_loader) if self.rank <= 0 else self.train_loader
|
||||
|
||||
_t = time()
|
||||
for train_data in tq_ldr:
|
||||
|
@ -286,7 +286,7 @@ class Trainer:
|
|||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_gpt_asr_mass_hf2.yml')
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_voice_voice_clip.yml')
|
||||
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
|
||||
parser.add_argument('--local_rank', type=int, default=0)
|
||||
args = parser.parse_args()
|
||||
|
|
Loading…
Reference in New Issue
Block a user