forked from mrq/tortoise-tts
is this from tortoise?
parent 1f8eef2807
commit 1d79b44aef
api.py: 21 lines changed
@@ -8,6 +8,7 @@ import torch.nn.functional as F
 import progressbar
 import torchaudio
 
+from models.classifier import AudioMiniEncoderWithClassifierHead
 from models.cvvp import CVVP
 from models.diffusion_decoder import DiffusionTts
 from models.autoregressive import UnifiedVoice
@@ -24,7 +25,7 @@ from utils.tokenizer import VoiceBpeTokenizer, lev_distance
 pbar = None
 
 
-def download_models():
+def download_models(specific_models=None):
     """
     Call to download all the models that Tortoise uses.
     """
@@ -49,6 +50,8 @@ def download_models():
             pbar.finish()
             pbar = None
     for model_name, url in MODELS.items():
+        if specific_models is not None and model_name not in specific_models:
+            continue
         if os.path.exists(f'.models/{model_name}'):
             continue
         print(f'Downloading {model_name} from {url}...')
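
With the new specific_models filter, a caller can fetch a single checkpoint instead of the whole model set. A minimal usage sketch, assuming the keys of the MODELS dict are checkpoint filenames like 'classifier.pth' (which is what the `.models/{model_name}` existence check above implies):

    from api import download_models

    # Previous behavior: specific_models defaults to None, so everything downloads.
    download_models()

    # Fetch only the classifier checkpoint; every other MODELS entry is skipped.
    download_models(specific_models=['classifier.pth'])
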
@@ -144,6 +147,22 @@ def do_spectrogram_diffusion(diffusion_model, diffuser, latents, conditioning_sa
     return denormalize_tacotron_mel(mel)[:,:,:output_seq_len]
 
 
+def classify_audio_clip(clip):
+    """
+    Returns whether or not Tortoise's classifier thinks the given clip came from Tortoise.
+    :param clip: torch tensor containing audio waveform data (get it from load_audio)
+    :return: probability, as a scalar tensor in [0, 1], that the clip was generated by Tortoise; values near 0 mean it was classified as real.
+    """
+    download_models(['classifier.pth'])
+    classifier = AudioMiniEncoderWithClassifierHead(2, spec_dim=1, embedding_dim=512, depth=5, downsample_factor=4,
+                                                    resnet_blocks=2, attn_blocks=4, num_attn_heads=4, base_channels=32,
+                                                    dropout=0, kernel_size=5, distribute_zero_label=False)
+    classifier.load_state_dict(torch.load('.models/classifier.pth', map_location=torch.device('cpu')))
+    clip = clip.cpu().unsqueeze(0)
+    results = F.softmax(classifier(clip), dim=-1)
+    return results[0][0]
+
+
 class TextToSpeech:
     """
     Main entry point into Tortoise.
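
Note that classify_audio_clip returns a softmax probability (a scalar tensor), not a boolean, so a caller who needs a hard decision has to pick a cutoff. A minimal sketch, assuming a 0.5 threshold and a hypothetical helper name (neither is part of this commit):

    from api import classify_audio_clip
    from utils.audio import load_audio

    def looks_generated(path, threshold=0.5):
        clip = load_audio(path, 24000)    # the classifier expects 24 kHz audio
        prob = classify_audio_clip(clip)  # scalar tensor in [0, 1]
        return prob.item() > threshold
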
is_this_from_tortoise.py: 14 lines (new file)
@@ -0,0 +1,14 @@
+import argparse
+
+from api import classify_audio_clip
+from utils.audio import load_audio
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--clip', type=str, help='Path to an audio clip to classify.', default="results/favorite_riding_hood.mp3")
+    args = parser.parse_args()
+
+    clip = load_audio(args.clip, 24000)
+    clip = clip[:, :220000]
+    prob = classify_audio_clip(clip)
+    print(f"This classifier thinks there is a {prob*100}% chance that this clip was generated from Tortoise.")
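
Invoking the new script is a one-liner; the clip path below is illustrative:

    python is_this_from_tortoise.py --clip path/to/suspect_clip.wav

The script truncates the clip to its first 220,000 samples, which is roughly 9 seconds at the 24 kHz rate it loads audio at, before handing it to the classifier.
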
models/classifier.py

@@ -1,4 +1,9 @@
 import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.utils.checkpoint import checkpoint
+
+from models.arch_util import Upsample, Downsample, normalization, zero_module, AttentionBlock
 
 
 class ResBlock(nn.Module):
@@ -27,7 +32,7 @@ class ResBlock(nn.Module):
         self.in_layers = nn.Sequential(
             normalization(channels),
             nn.SiLU(),
-            conv_nd(dims, channels, self.out_channels, kernel_size, padding=padding),
+            nn.Conv1d(channels, self.out_channels, kernel_size, padding=padding),
         )
 
         self.updown = up or down
@@ -46,18 +51,18 @@ class ResBlock(nn.Module):
             nn.SiLU(),
             nn.Dropout(p=dropout),
             zero_module(
-                conv_nd(dims, self.out_channels, self.out_channels, kernel_size, padding=padding)
+                nn.Conv1d(self.out_channels, self.out_channels, kernel_size, padding=padding)
             ),
         )
 
         if self.out_channels == channels:
             self.skip_connection = nn.Identity()
         elif use_conv:
-            self.skip_connection = conv_nd(
-                dims, channels, self.out_channels, kernel_size, padding=padding
-            )
+            self.skip_connection = nn.Conv1d(
+                channels, self.out_channels, kernel_size, padding=padding
+            )
         else:
-            self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
+            self.skip_connection = nn.Conv1d(channels, self.out_channels, 1)
 
     def forward(self, x):
         if self.do_checkpoint:
@@ -94,21 +99,21 @@ class AudioMiniEncoder(nn.Module):
                  kernel_size=3):
         super().__init__()
         self.init = nn.Sequential(
-            conv_nd(1, spec_dim, base_channels, 3, padding=1)
+            nn.Conv1d(spec_dim, base_channels, 3, padding=1)
         )
         ch = base_channels
         res = []
         self.layers = depth
         for l in range(depth):
             for r in range(resnet_blocks):
-                res.append(ResBlock(ch, dropout, dims=1, do_checkpoint=False, kernel_size=kernel_size))
-            res.append(Downsample(ch, use_conv=True, dims=1, out_channels=ch*2, factor=downsample_factor))
+                res.append(ResBlock(ch, dropout, do_checkpoint=False, kernel_size=kernel_size))
+            res.append(Downsample(ch, use_conv=True, out_channels=ch*2, factor=downsample_factor))
             ch *= 2
         self.res = nn.Sequential(*res)
         self.final = nn.Sequential(
             normalization(ch),
             nn.SiLU(),
-            conv_nd(1, ch, embedding_dim, 1)
+            nn.Conv1d(ch, embedding_dim, 1)
         )
         attn = []
         for a in range(attn_blocks):
@@ -118,7 +123,7 @@ class AudioMiniEncoder(nn.Module):
 
     def forward(self, x):
         h = self.init(x)
-        h = sequential_checkpoint(self.res, self.layers, h)
+        h = self.res(h)
         h = self.final(h)
         for blk in self.attn:
             h = checkpoint(blk, h)
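
The classifier.py hunks hard-code the 1-D case: the conv_nd(dims, ...) factory (a guided-diffusion-style helper that returns an nn.Conv1d when dims == 1) is replaced by nn.Conv1d directly, and the sequential_checkpoint wrapper in forward becomes a plain call, which simplifies inference at the cost of more activation memory if the module is ever trained again. A sanity sketch of the conv equivalence, with arbitrary channel and length numbers:

    import torch
    import torch.nn as nn

    def conv_nd(dims, *args, **kwargs):
        # What the removed helper boils down to in the 1-D (audio) case.
        assert dims == 1
        return nn.Conv1d(*args, **kwargs)

    old = conv_nd(1, 80, 32, kernel_size=5, padding=2)
    new = nn.Conv1d(80, 32, kernel_size=5, padding=2)
    new.load_state_dict(old.state_dict())  # identical shapes, weights transfer 1:1

    x = torch.randn(2, 80, 400)            # (batch, spec_dim, time)
    assert torch.allclose(old(x), new(x))  # the two layers compute the same thing
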