diff --git a/codes/data/audio/unsupervised_audio_dataset.py b/codes/data/audio/unsupervised_audio_dataset.py
index 5e74dfef..468afb94 100644
--- a/codes/data/audio/unsupervised_audio_dataset.py
+++ b/codes/data/audio/unsupervised_audio_dataset.py
@@ -69,6 +69,7 @@ class UnsupervisedAudioDataset(torch.utils.data.Dataset):
         if self.pad_to is not None:
             self.pad_to *= self.sampling_rate
         self.pad_to = opt_get(opt, ['pad_to_samples'], self.pad_to)
+        self.min_length = opt_get(opt, ['min_length'], 0)
 
         # "Resampled clip" is audio data pulled from the basis of "clip" but with randomly different bounds. There are no
         # guarantees that "clip_resampled" is different from "clip": in fact, if "clip" is less than pad_to_seconds/samples,
@@ -79,9 +80,12 @@ class UnsupervisedAudioDataset(torch.utils.data.Dataset):
         self.extra_sample_len = opt_get(opt, ['extra_sample_length'], 2)
         self.extra_sample_len *= self.sampling_rate
 
+        self.debug_loading_failures = opt_get(opt, ['debug_loading_failures'], True)
+
     def get_audio_for_index(self, index):
         audiopath = self.audiopaths[index]
         audio = load_audio(audiopath, self.sampling_rate)
+        assert audio.shape[1] > self.min_length
        return audio, audiopath
 
     def get_related_audio_for_index(self, index):
@@ -121,7 +125,8 @@ class UnsupervisedAudioDataset(torch.utils.data.Dataset):
             audio_norm, filename = self.get_audio_for_index(index)
             alt_files, actual_samples = self.get_related_audio_for_index(index)
         except:
-            print(f"Error loading audio for file {self.audiopaths[index]} {sys.exc_info()}")
+            if self.debug_loading_failures:
+                print(f"Error loading audio for file {self.audiopaths[index]} {sys.exc_info()}")
             return self[index+1]
 
         # When generating resampled clips, skew is a bias that tries to spread them out from each other, reducing their
diff --git a/codes/models/gpt_voice/voice_clip.py b/codes/models/gpt_voice/text_voice_clip.py
similarity index 100%
rename from codes/models/gpt_voice/voice_clip.py
rename to codes/models/gpt_voice/text_voice_clip.py
diff --git a/codes/models/gpt_voice/voice_voice_clip.py b/codes/models/gpt_voice/voice_voice_clip.py
new file mode 100644
index 00000000..83094271
--- /dev/null
+++ b/codes/models/gpt_voice/voice_voice_clip.py
@@ -0,0 +1,83 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from einops import rearrange
+from torch import einsum
+
+from models.gpt_voice.mini_encoder import AudioMiniEncoder
+from models.lucidrains.dalle.transformer import Transformer
+from trainer.networks import register_model
+from utils.util import opt_get
+
+
+def exists(val):
+    return val is not None
+
+
+def masked_mean(t, mask, dim=1):
+    t = t.masked_fill(~mask[:, :, None], 0.)
+    return t.sum(dim = 1) / mask.sum(dim = 1)[..., None]
+
+
+class VoiceCLIP(nn.Module):
+    """
+    CLIP model modified to produce similarity scores from different views of the same audio clip.
+    """
+
+    def __init__(
+            self,
+            encoder_output=512,
+            dim_latent=512,
+            speech_max_seq_len=250,
+            mel_compression_ratio=256,
+            pretrained_encoder_dict_path=None
+    ):
+        super().__init__()
+        self.encoder = AudioMiniEncoder(80, encoder_output)
+        if pretrained_encoder_dict_path is not None:
+            self.encoder.load_state_dict(torch.load(pretrained_encoder_dict_path))
+        self.to_latent = nn.Linear(encoder_output, dim_latent, bias=False)
+        self.temperature = nn.Parameter(torch.tensor(1.))
+        self.mel_compression_ratio = mel_compression_ratio
+
+    def forward(
+        self,
+        speech_mels,
+        speech_lengths,
+        return_loss=True
+    ):
+        half_length = min(speech_mels.shape[-1], torch.min(speech_lengths).item() // self.mel_compression_ratio) // 2
+        half_length = (half_length // 4) * 4  # Must be a multiple of 4.
+
+        first_half = speech_mels[:, :, :half_length]
+        second_half = speech_mels[:, :, half_length:half_length*2]
+
+        first_emb = self.encoder(first_half)
+        first_latents = self.to_latent(first_emb)
+        second_emb = self.encoder(second_half)
+        second_latents = self.to_latent(second_emb)
+
+        first_latents, second_latents = map(lambda t: F.normalize(t, p=2, dim=-1), (first_latents, second_latents))
+
+        temp = self.temperature.exp()
+
+        if not return_loss:
+            sim = einsum('n d, n d -> n', first_latents, second_latents) * temp
+            return sim
+
+        sim = einsum('i d, j d -> i j', first_latents, second_latents) * temp
+        labels = torch.arange(first_latents.shape[0], device=first_latents.device)
+        loss = (F.cross_entropy(sim, labels) + F.cross_entropy(sim.t(), labels)) / 2
+        return loss
+
+
+@register_model
+def register_voice_to_voice_clip(opt_net, opt):
+    return VoiceCLIP(**opt_get(opt_net, ['kwargs'], {}))
+
+
+if __name__ == '__main__':
+    clip = VoiceCLIP()
+    clip(torch.randn((2,80,200)),
+         torch.randint(0,200*1024,(2,)),
+         return_loss=True)
\ No newline at end of file
diff --git a/codes/train.py b/codes/train.py
index ec29b996..03db0ee8 100644
--- a/codes/train.py
+++ b/codes/train.py
@@ -257,14 +257,14 @@ class Trainer:
             import wandb
             wandb.log(eval_dict)
 
-
     def do_training(self):
         self.logger.info('Start training from epoch: {:d}, iter: {:d}'.format(self.start_epoch, self.current_step))
         for epoch in range(self.start_epoch, self.total_epochs + 1):
             self.epoch = epoch
             if opt['dist']:
                 self.train_sampler.set_epoch(epoch)
-            tq_ldr = tqdm(self.train_loader) if self.rank == 0 else self.train_loader
+
+            tq_ldr = tqdm(self.train_loader) if self.rank <= 0 else self.train_loader
 
             _t = time()
             for train_data in tq_ldr:
@@ -286,7 +286,7 @@
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_gpt_asr_mass_hf2.yml')
+    parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_voice_voice_clip.yml')
     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
     parser.add_argument('--local_rank', type=int, default=0)
     args = parser.parse_args()