From 1d79b44aeff56ad3cd6d081aaa8b163dc89e8a8d Mon Sep 17 00:00:00 2001 From: James Betker Date: Tue, 26 Apr 2022 10:24:03 -0600 Subject: [PATCH] is this from tortoise? --- api.py | 21 ++++++++++++++++++++- is_this_from_tortoise.py | 14 ++++++++++++++ models/classifier.py | 23 ++++++++++++++--------- 3 files changed, 48 insertions(+), 10 deletions(-) create mode 100644 is_this_from_tortoise.py diff --git a/api.py b/api.py index 321e154..557d557 100644 --- a/api.py +++ b/api.py @@ -8,6 +8,7 @@ import torch.nn.functional as F import progressbar import torchaudio +from models.classifier import AudioMiniEncoderWithClassifierHead from models.cvvp import CVVP from models.diffusion_decoder import DiffusionTts from models.autoregressive import UnifiedVoice @@ -24,7 +25,7 @@ from utils.tokenizer import VoiceBpeTokenizer, lev_distance pbar = None -def download_models(): +def download_models(specific_models=None): """ Call to download all the models that Tortoise uses. """ @@ -49,6 +50,8 @@ def download_models(): pbar.finish() pbar = None for model_name, url in MODELS.items(): + if specific_models is not None and model_name not in specific_models: + continue if os.path.exists(f'.models/{model_name}'): continue print(f'Downloading {model_name} from {url}...') @@ -144,6 +147,22 @@ def do_spectrogram_diffusion(diffusion_model, diffuser, latents, conditioning_sa return denormalize_tacotron_mel(mel)[:,:,:output_seq_len] +def classify_audio_clip(clip): + """ + Returns whether or not Tortoises' classifier thinks the given clip came from Tortoise. + :param clip: torch tensor containing audio waveform data (get it from load_audio) + :return: True if the clip was classified as coming from Tortoise and false if it was classified as real. + """ + download_models(['classifier']) + classifier = AudioMiniEncoderWithClassifierHead(2, spec_dim=1, embedding_dim=512, depth=5, downsample_factor=4, + resnet_blocks=2, attn_blocks=4, num_attn_heads=4, base_channels=32, + dropout=0, kernel_size=5, distribute_zero_label=False) + classifier.load_state_dict(torch.load('.models/classifier.pth', map_location=torch.device('cpu'))) + clip = clip.cpu().unsqueeze(0) + results = F.softmax(classifier(clip), dim=-1) + return results[0][0] + + class TextToSpeech: """ Main entry point into Tortoise. diff --git a/is_this_from_tortoise.py b/is_this_from_tortoise.py new file mode 100644 index 0000000..550b33e --- /dev/null +++ b/is_this_from_tortoise.py @@ -0,0 +1,14 @@ +import argparse + +from api import classify_audio_clip +from utils.audio import load_audio + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--clip', type=str, help='Path to an audio clip to classify.', default="results/favorite_riding_hood.mp3") + args = parser.parse_args() + + clip = load_audio(args.clip, 24000) + clip = clip[:, :220000] + prob = classify_audio_clip(clip) + print(f"This classifier thinks there is a {prob*100}% chance that this clip was generated from Tortoise.") \ No newline at end of file diff --git a/models/classifier.py b/models/classifier.py index ff39daa..c899773 100644 --- a/models/classifier.py +++ b/models/classifier.py @@ -1,4 +1,9 @@ import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.utils.checkpoint import checkpoint + +from models.arch_util import Upsample, Downsample, normalization, zero_module, AttentionBlock class ResBlock(nn.Module): @@ -27,7 +32,7 @@ class ResBlock(nn.Module): self.in_layers = nn.Sequential( normalization(channels), nn.SiLU(), - conv_nd(dims, channels, self.out_channels, kernel_size, padding=padding), + nn.Conv1d(channels, self.out_channels, kernel_size, padding=padding), ) self.updown = up or down @@ -46,18 +51,18 @@ class ResBlock(nn.Module): nn.SiLU(), nn.Dropout(p=dropout), zero_module( - conv_nd(dims, self.out_channels, self.out_channels, kernel_size, padding=padding) + nn.Conv1d(self.out_channels, self.out_channels, kernel_size, padding=padding) ), ) if self.out_channels == channels: self.skip_connection = nn.Identity() elif use_conv: - self.skip_connection = conv_nd( + self.skip_connection = nn.Conv1d( dims, channels, self.out_channels, kernel_size, padding=padding ) else: - self.skip_connection = conv_nd(dims, channels, self.out_channels, 1) + self.skip_connection = nn.Conv1d(dims, channels, self.out_channels, 1) def forward(self, x): if self.do_checkpoint: @@ -94,21 +99,21 @@ class AudioMiniEncoder(nn.Module): kernel_size=3): super().__init__() self.init = nn.Sequential( - conv_nd(1, spec_dim, base_channels, 3, padding=1) + nn.Conv1d(spec_dim, base_channels, 3, padding=1) ) ch = base_channels res = [] self.layers = depth for l in range(depth): for r in range(resnet_blocks): - res.append(ResBlock(ch, dropout, dims=1, do_checkpoint=False, kernel_size=kernel_size)) - res.append(Downsample(ch, use_conv=True, dims=1, out_channels=ch*2, factor=downsample_factor)) + res.append(ResBlock(ch, dropout, do_checkpoint=False, kernel_size=kernel_size)) + res.append(Downsample(ch, use_conv=True, out_channels=ch*2, factor=downsample_factor)) ch *= 2 self.res = nn.Sequential(*res) self.final = nn.Sequential( normalization(ch), nn.SiLU(), - conv_nd(1, ch, embedding_dim, 1) + nn.Conv1d(ch, embedding_dim, 1) ) attn = [] for a in range(attn_blocks): @@ -118,7 +123,7 @@ class AudioMiniEncoder(nn.Module): def forward(self, x): h = self.init(x) - h = sequential_checkpoint(self.res, self.layers, h) + h = self.res(h) h = self.final(h) for blk in self.attn: h = checkpoint(blk, h)