is this from tortoise?

This commit is contained in:
James Betker 2022-04-26 10:24:03 -06:00
parent 1f8eef2807
commit 1d79b44aef
3 changed files with 48 additions and 10 deletions

21
api.py
View File

@ -8,6 +8,7 @@ import torch.nn.functional as F
import progressbar import progressbar
import torchaudio import torchaudio
from models.classifier import AudioMiniEncoderWithClassifierHead
from models.cvvp import CVVP from models.cvvp import CVVP
from models.diffusion_decoder import DiffusionTts from models.diffusion_decoder import DiffusionTts
from models.autoregressive import UnifiedVoice from models.autoregressive import UnifiedVoice
@ -24,7 +25,7 @@ from utils.tokenizer import VoiceBpeTokenizer, lev_distance
pbar = None pbar = None
def download_models(): def download_models(specific_models=None):
""" """
Call to download all the models that Tortoise uses. Call to download all the models that Tortoise uses.
""" """
@ -49,6 +50,8 @@ def download_models():
pbar.finish() pbar.finish()
pbar = None pbar = None
for model_name, url in MODELS.items(): for model_name, url in MODELS.items():
if specific_models is not None and model_name not in specific_models:
continue
if os.path.exists(f'.models/{model_name}'): if os.path.exists(f'.models/{model_name}'):
continue continue
print(f'Downloading {model_name} from {url}...') print(f'Downloading {model_name} from {url}...')
@ -144,6 +147,22 @@ def do_spectrogram_diffusion(diffusion_model, diffuser, latents, conditioning_sa
return denormalize_tacotron_mel(mel)[:,:,:output_seq_len] return denormalize_tacotron_mel(mel)[:,:,:output_seq_len]
def classify_audio_clip(clip):
"""
Returns whether or not Tortoises' classifier thinks the given clip came from Tortoise.
:param clip: torch tensor containing audio waveform data (get it from load_audio)
:return: True if the clip was classified as coming from Tortoise and false if it was classified as real.
"""
download_models(['classifier'])
classifier = AudioMiniEncoderWithClassifierHead(2, spec_dim=1, embedding_dim=512, depth=5, downsample_factor=4,
resnet_blocks=2, attn_blocks=4, num_attn_heads=4, base_channels=32,
dropout=0, kernel_size=5, distribute_zero_label=False)
classifier.load_state_dict(torch.load('.models/classifier.pth', map_location=torch.device('cpu')))
clip = clip.cpu().unsqueeze(0)
results = F.softmax(classifier(clip), dim=-1)
return results[0][0]
class TextToSpeech: class TextToSpeech:
""" """
Main entry point into Tortoise. Main entry point into Tortoise.

14
is_this_from_tortoise.py Normal file
View File

@ -0,0 +1,14 @@
import argparse
from api import classify_audio_clip
from utils.audio import load_audio
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--clip', type=str, help='Path to an audio clip to classify.', default="results/favorite_riding_hood.mp3")
args = parser.parse_args()
clip = load_audio(args.clip, 24000)
clip = clip[:, :220000]
prob = classify_audio_clip(clip)
print(f"This classifier thinks there is a {prob*100}% chance that this clip was generated from Tortoise.")

View File

@ -1,4 +1,9 @@
import torch import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint
from models.arch_util import Upsample, Downsample, normalization, zero_module, AttentionBlock
class ResBlock(nn.Module): class ResBlock(nn.Module):
@ -27,7 +32,7 @@ class ResBlock(nn.Module):
self.in_layers = nn.Sequential( self.in_layers = nn.Sequential(
normalization(channels), normalization(channels),
nn.SiLU(), nn.SiLU(),
conv_nd(dims, channels, self.out_channels, kernel_size, padding=padding), nn.Conv1d(channels, self.out_channels, kernel_size, padding=padding),
) )
self.updown = up or down self.updown = up or down
@ -46,18 +51,18 @@ class ResBlock(nn.Module):
nn.SiLU(), nn.SiLU(),
nn.Dropout(p=dropout), nn.Dropout(p=dropout),
zero_module( zero_module(
conv_nd(dims, self.out_channels, self.out_channels, kernel_size, padding=padding) nn.Conv1d(self.out_channels, self.out_channels, kernel_size, padding=padding)
), ),
) )
if self.out_channels == channels: if self.out_channels == channels:
self.skip_connection = nn.Identity() self.skip_connection = nn.Identity()
elif use_conv: elif use_conv:
self.skip_connection = conv_nd( self.skip_connection = nn.Conv1d(
dims, channels, self.out_channels, kernel_size, padding=padding dims, channels, self.out_channels, kernel_size, padding=padding
) )
else: else:
self.skip_connection = conv_nd(dims, channels, self.out_channels, 1) self.skip_connection = nn.Conv1d(dims, channels, self.out_channels, 1)
def forward(self, x): def forward(self, x):
if self.do_checkpoint: if self.do_checkpoint:
@ -94,21 +99,21 @@ class AudioMiniEncoder(nn.Module):
kernel_size=3): kernel_size=3):
super().__init__() super().__init__()
self.init = nn.Sequential( self.init = nn.Sequential(
conv_nd(1, spec_dim, base_channels, 3, padding=1) nn.Conv1d(spec_dim, base_channels, 3, padding=1)
) )
ch = base_channels ch = base_channels
res = [] res = []
self.layers = depth self.layers = depth
for l in range(depth): for l in range(depth):
for r in range(resnet_blocks): for r in range(resnet_blocks):
res.append(ResBlock(ch, dropout, dims=1, do_checkpoint=False, kernel_size=kernel_size)) res.append(ResBlock(ch, dropout, do_checkpoint=False, kernel_size=kernel_size))
res.append(Downsample(ch, use_conv=True, dims=1, out_channels=ch*2, factor=downsample_factor)) res.append(Downsample(ch, use_conv=True, out_channels=ch*2, factor=downsample_factor))
ch *= 2 ch *= 2
self.res = nn.Sequential(*res) self.res = nn.Sequential(*res)
self.final = nn.Sequential( self.final = nn.Sequential(
normalization(ch), normalization(ch),
nn.SiLU(), nn.SiLU(),
conv_nd(1, ch, embedding_dim, 1) nn.Conv1d(ch, embedding_dim, 1)
) )
attn = [] attn = []
for a in range(attn_blocks): for a in range(attn_blocks):
@ -118,7 +123,7 @@ class AudioMiniEncoder(nn.Module):
def forward(self, x): def forward(self, x):
h = self.init(x) h = self.init(x)
h = sequential_checkpoint(self.res, self.layers, h) h = self.res(h)
h = self.final(h) h = self.final(h)
for blk in self.attn: for blk in self.attn:
h = checkpoint(blk, h) h = checkpoint(blk, h)