forked from mrq/DL-Art-School
Refactor audio-style models into the audio folder
This commit is contained in:
parent f95d3d2b82
commit 7929fd89de
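Most of this commit is a mechanical relocation: model code that previously lived under models/tacotron2, models/gpt_voice and models/asr moves under models/audio/ (tts and asr subpackages), and every import site in the hunks below is rewritten to match. The sketch below illustrates the prefix remapping; the table is an editor's assumption inferred from the hunks, not an authoritative list, and several symbols are additionally re-exported from the new models.audio.tts.tacotron2 package root, so some imports collapse further than a pure prefix swap.

    # Illustrative old -> new module prefixes (assumed from the diff below, not authoritative).
    PATH_MAP = {
        "models.tacotron2": "models.audio.tts.tacotron2",
        "models.gpt_voice": "models.audio.tts",
        "models.asr": "models.audio.asr",
    }

    def rewrite_import(line: str) -> str:
        """Rewrite a 'from X import Y' line to the new layout (longest prefix first)."""
        for old, new in sorted(PATH_MAP.items(), key=lambda kv: -len(kv[0])):
            if line.startswith(f"from {old} ") or line.startswith(f"from {old}."):
                return line.replace(old, new, 1)
        return line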
@@ -1,5 +1,4 @@
"""create dataset and dataloader"""
import logging
import torch
import torch.utils.data
from munch import munchify

@@ -64,7 +63,7 @@ def create_dataset(dataset_opt, return_collate=False):
    elif mode == 'nv_tacotron':
        from data.audio.nv_tacotron_dataset import TextWavLoader as D
        from data.audio.nv_tacotron_dataset import TextMelCollate as C
        from models.tacotron2.hparams import create_hparams
        from models.audio.tts.tacotron2 import create_hparams
        default_params = create_hparams()
        default_params.update(dataset_opt)
        dataset_opt = munchify(default_params)

@@ -72,13 +71,13 @@ def create_dataset(dataset_opt, return_collate=False):
        collate = C()
    elif mode == 'paired_voice_audio':
        from data.audio.paired_voice_audio_dataset import TextWavLoader as D
        from models.tacotron2.hparams import create_hparams
        from models.audio.tts.tacotron2 import create_hparams
        default_params = create_hparams()
        default_params.update(dataset_opt)
        dataset_opt = munchify(default_params)
    elif mode == 'fast_paired_voice_audio':
        from data.audio.fast_paired_dataset import FastPairedVoiceDataset as D
        from models.tacotron2.hparams import create_hparams
        from models.audio.tts.tacotron2 import create_hparams
        default_params = create_hparams()
        default_params.update(dataset_opt)
        dataset_opt = munchify(default_params)
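(Editor's note: each branch above merges the Tacotron2 create_hparams() defaults with the caller's options and re-wraps them via munchify before building the dataset. A minimal usage sketch, mirroring the __main__ block that appears later in this diff; keys other than 'mode' are illustrative assumptions, not taken from a real config.)

    from data import create_dataset, create_dataloader

    params = {
        'mode': 'nv_tacotron',            # selects the branch shown above
        'path': 'mydata/metadata.csv',    # assumed key; placeholder path
        'phase': 'train',                 # assumed key
        'n_workers': 0,
        'batch_size': 4,
    }
    ds = create_dataset(params)           # hparam defaults are merged inside create_dataset
    dl = create_dataloader(ds, params)
    batch = next(iter(dl))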
@@ -6,9 +6,9 @@ import torch.utils.data
from torch import LongTensor
from tqdm import tqdm

from models.tacotron2.taco_utils import load_filepaths_and_text
from models.tacotron2.text import symbols
from models.tacotron2.text import text_to_sequence
from models.audio.tts.tacotron2 import load_filepaths_and_text
from models.audio.tts.tacotron2 import symbols
from models.audio.tts.tacotron2 import text_to_sequence


class GptTtsDataset(torch.utils.data.Dataset):
@@ -1,7 +1,4 @@
import os
import os
import random
import shutil

import torch
import torch.nn.functional as F

@@ -9,19 +6,15 @@ import torch.utils.data
import torchaudio
from munch import munchify
from tqdm import tqdm
from transformers import GPT2TokenizerFast

from data.audio.unsupervised_audio_dataset import load_audio, UnsupervisedAudioDataset
from data.audio.unsupervised_audio_dataset import UnsupervisedAudioDataset
from data.text.hf_datasets_wrapper import HfDataset
from data.util import find_files_of_type, is_audio_file
from models.tacotron2.taco_utils import load_filepaths_and_text
from models.tacotron2.text import text_to_sequence
from utils.util import opt_get


def build_paired_voice_dataset(args):
    from data.audio.paired_voice_audio_dataset import TextWavLoader as D
    from models.tacotron2.hparams import create_hparams
    from models.audio.tts.tacotron2 import create_hparams
    default_params = create_hparams()
    default_params.update(args)
    dataset_opt = munchify(default_params)
@@ -1,5 +1,4 @@
import os
import os
import random

import torch

@@ -10,8 +9,8 @@ from tqdm import tqdm

from data.audio.unsupervised_audio_dataset import load_audio
from data.util import find_files_of_type, is_audio_file
from models.tacotron2.taco_utils import load_filepaths_and_text
from models.tacotron2.text import text_to_sequence
from models.audio.tts.tacotron2 import load_filepaths_and_text
from models.audio.tts.tacotron2 import text_to_sequence
from utils.util import opt_get
@@ -1,5 +1,4 @@
import os
import os
import random
import sys

@@ -10,8 +9,8 @@ import torchaudio
from tqdm import tqdm

from data.audio.unsupervised_audio_dataset import load_audio, load_similar_clips
from models.tacotron2.taco_utils import load_filepaths_and_text
from models.tacotron2.text import text_to_sequence, sequence_to_text
from models.audio.tts.tacotron2 import load_filepaths_and_text
from models.audio.tts.tacotron2 import text_to_sequence, sequence_to_text
from utils.util import opt_get
@@ -1,8 +1,6 @@
import os
import pathlib
import random
import sys
from warnings import warn

import torch
import torch.utils.data

@@ -11,9 +9,8 @@ import torchaudio
from audio2numpy import open_audio
from tqdm import tqdm

from data.audio.wav_aug import WavAugmentor
from data.util import find_files_of_type, is_wav_file, is_audio_file, load_paths_from_cache
from models.tacotron2.taco_utils import load_wav_to_torch
from data.util import find_files_of_type, is_audio_file, load_paths_from_cache
from models.audio.tts.tacotron2 import load_wav_to_torch
from utils.util import opt_get

@@ -189,7 +186,7 @@ if __name__ == '__main__':
        'extra_samples': 4,
        'resample_clip': True,
    }
    from data import create_dataset, create_dataloader, util
    from data import create_dataset, create_dataloader

    ds = create_dataset(params)
    dl = create_dataloader(ds, params)
@@ -1,16 +1,14 @@
import re

import datasets
import torch
from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers.pre_tokenizers import Whitespace
from tokenizers.processors import ByteLevel
from tokenizers.trainers import BpeTrainer

from data.audio.paired_voice_audio_dataset import load_mozilla_cv, load_voxpopuli, load_tsv
from models.tacotron2.taco_utils import load_filepaths_and_text
from models.tacotron2.text.cleaners import english_cleaners
from models.audio.tts.tacotron2 import load_filepaths_and_text
from models.audio.tts.tacotron2.text.cleaners import english_cleaners


def remove_extraneous_punctuation(word):
@@ -3,7 +3,7 @@ import random
import torch
import torchaudio.sox_effects

from models.tacotron2.taco_utils import load_wav_to_torch
from models.audio.tts.tacotron2 import load_wav_to_torch


# Returns random double on [l,h] as a string
@@ -3,11 +3,11 @@ import json
import torch
import torch.nn as nn
import torch.nn.functional as F
from x_transformers import Encoder, XTransformer, TransformerWrapper
from x_transformers import Encoder, TransformerWrapper

from models.gpt_voice.unet_diffusion_tts6 import CheckpointedLayer
from models.gpt_voice.unified_voice2 import ConditioningEncoder
from models.tacotron2.text.cleaners import english_cleaners
from models.audio.tts.unet_diffusion_tts6 import CheckpointedLayer
from models.audio.tts.unified_voice2 import ConditioningEncoder
from models.audio.tts.tacotron2.text.cleaners import english_cleaners
from trainer.networks import register_model
from utils.util import opt_get
@@ -4,17 +4,11 @@ import json
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss
from transformers import T5Config, T5Model, T5PreTrainedModel, T5ForConditionalGeneration
from transformers.file_utils import replace_return_docstrings
from transformers.modeling_outputs import Seq2SeqLMOutput, BaseModelOutput
from transformers.utils.model_parallel_utils import get_device_map, assert_device_map
from x_transformers import Encoder, XTransformer
from transformers import T5Config, T5ForConditionalGeneration

from models.gpt_voice.transformer_builders import null_position_embeddings
from models.gpt_voice.unet_diffusion_tts6 import CheckpointedLayer
from models.gpt_voice.unified_voice2 import ConditioningEncoder
from models.tacotron2.text.cleaners import english_cleaners
from models.audio.tts.transformer_builders import null_position_embeddings
from models.audio.tts.unified_voice2 import ConditioningEncoder
from models.audio.tts.tacotron2.text.cleaners import english_cleaners
from trainer.networks import register_model
from utils.util import opt_get

@@ -146,7 +140,6 @@ def inf():
    model.load_state_dict(sd)
    raw_batch = torch.load('raw_batch.pth')
    with torch.no_grad():
        from data.audio.unsupervised_audio_dataset import load_audio
        from scripts.audio.gen.speech_synthesis_utils import wav_to_mel
        ref_mel = torch.cat([wav_to_mel(raw_batch['conditioning'][0])[:, :, :256],
                             wav_to_mel(raw_batch['conditioning'][0])[:, :, :256]], dim=0).unsqueeze(0)
@@ -1,5 +1,5 @@
#import tensorflow as tf
from models.tacotron2.text import symbols
from models.audio.tts.tacotron2.text import symbols


def create_hparams(hparams_string=None, verbose=False):

@@ -1,8 +1,8 @@
import torch
from librosa.filters import mel as librosa_mel_fn
from models.tacotron2.audio_processing import dynamic_range_compression
from models.tacotron2.audio_processing import dynamic_range_decompression
from models.tacotron2.stft import STFT
from models.audio.tts.tacotron2.audio_processing import dynamic_range_compression
from models.audio.tts.tacotron2.audio_processing import dynamic_range_decompression
from models.audio.tts.tacotron2.stft import STFT


class LinearNorm(torch.nn.Module):
@@ -36,7 +36,7 @@ import torch.nn.functional as F
from torch.autograd import Variable
from scipy.signal import get_window
from librosa.util import pad_center, tiny
from models.tacotron2.audio_processing import window_sumsquare
from models.audio.tts.tacotron2.audio_processing import window_sumsquare


class STFT(torch.nn.Module):

@@ -4,11 +4,10 @@ from munch import munchify
from torch.autograd import Variable
from torch import nn
from torch.nn import functional as F
from models.tacotron2.layers import ConvNorm, LinearNorm
from models.tacotron2.hparams import create_hparams
from models.audio.tts.tacotron2.layers import ConvNorm, LinearNorm
from models.audio.tts.tacotron2.hparams import create_hparams
from trainer.networks import register_model
from models.tacotron2.taco_utils import get_mask_from_lengths
from utils.util import opt_get, checkpoint
from models.audio.tts.tacotron2.taco_utils import get_mask_from_lengths


class LocationLayer(nn.Module):
@@ -3,8 +3,8 @@ import re

import torch

from models.tacotron2.text import cleaners
from models.tacotron2.text.symbols import symbols
from models.audio.tts.tacotron2.text import cleaners
from models.audio.tts.tacotron2.text.symbols import symbols


# Mappings from symbol to numeric ID and vice versa:

@@ -4,7 +4,7 @@
Defines the set of symbols used in text input to the model.

The default is a set of ASCII characters that works well for English or text that has been run through Unidecode. For other data, you can modify _characters. See TRAINING_DATA.md for details. '''
from models.tacotron2.text import cmudict
from models.audio.tts.tacotron2.text import cmudict

_pad = '_'
_punctuation = '!\'(),.:;? '
@@ -3,16 +3,16 @@ import torch
from munch import munchify
from torch.autograd import Variable
from torch import nn
from torch.nn import functional as F, Flatten
from torch.nn import functional as F

from models.arch_util import ConvGnSilu
from models.diffusion.unet_diffusion import UNetModel, AttentionPool2d
from models.tacotron2.layers import ConvNorm, LinearNorm
from models.tacotron2.hparams import create_hparams
from models.tacotron2.tacotron2 import Prenet, Attention, Encoder
from models.audio.tts.tacotron2.layers import LinearNorm
from models.audio.tts.tacotron2.hparams import create_hparams
from models.audio.tts.tacotron2.tacotron2 import Attention, Encoder
from trainer.networks import register_model
from models.tacotron2.taco_utils import get_mask_from_lengths
from utils.util import opt_get, checkpoint
from models.audio.tts.tacotron2.taco_utils import get_mask_from_lengths
from utils.util import checkpoint
@@ -5,12 +5,11 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import autocast
from x_transformers.x_transformers import AbsolutePositionalEmbedding, AttentionLayers

from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear
from models.diffusion.unet_diffusion import AttentionBlock, TimestepEmbedSequential, \
    Downsample, Upsample, TimestepBlock
from models.gpt_voice.mini_encoder import AudioMiniEncoder
from models.audio.tts.mini_encoder import AudioMiniEncoder
from scripts.audio.gen.use_diffuse_tts import ceil_multiple
from trainer.networks import register_model
from utils.util import checkpoint

@@ -1,17 +1,15 @@
import functools
import random
from collections import OrderedDict

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import autocast
from x_transformers.x_transformers import AbsolutePositionalEmbedding, AttentionLayers

from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear
from models.diffusion.unet_diffusion import AttentionBlock, TimestepEmbedSequential, \
    Downsample, Upsample, TimestepBlock
from models.gpt_voice.mini_encoder import AudioMiniEncoder
from models.audio.tts.mini_encoder import AudioMiniEncoder
from scripts.audio.gen.use_diffuse_tts import ceil_multiple
from trainer.networks import register_model
from utils.util import checkpoint
@@ -1,17 +1,15 @@
import functools
import random
from collections import OrderedDict

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import autocast
from x_transformers.x_transformers import AbsolutePositionalEmbedding, AttentionLayers, CrossAttender

from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear
from models.diffusion.unet_diffusion import AttentionBlock, TimestepEmbedSequential, \
    Downsample, Upsample, TimestepBlock
from models.gpt_voice.mini_encoder import AudioMiniEncoder
from models.audio.tts.mini_encoder import AudioMiniEncoder
from scripts.audio.gen.use_diffuse_tts import ceil_multiple
from trainer.networks import register_model
from utils.util import checkpoint

@@ -1,20 +1,17 @@
import functools
import random
from collections import OrderedDict

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import autocast
from x_transformers.x_transformers import AbsolutePositionalEmbedding, AttentionLayers, CrossAttender

from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear
from models.diffusion.unet_diffusion import AttentionBlock, TimestepEmbedSequential, \
    Downsample, Upsample, TimestepBlock
from models.gpt_voice.mini_encoder import AudioMiniEncoder
from models.diffusion.unet_diffusion import TimestepEmbedSequential, \
    Downsample, Upsample
from models.audio.tts.mini_encoder import AudioMiniEncoder
from scripts.audio.gen.use_diffuse_tts import ceil_multiple
from trainer.networks import register_model
from utils.util import checkpoint
from x_transformers import Encoder, ContinuousTransformerWrapper
@@ -1,6 +1,4 @@
import functools
import random
from collections import OrderedDict

import torch
import torch.nn as nn

@@ -11,11 +9,11 @@ from x_transformers import Encoder
from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear
from models.diffusion.unet_diffusion import AttentionBlock, TimestepEmbedSequential, \
    Downsample, Upsample, TimestepBlock
from models.gpt_voice.mini_encoder import AudioMiniEncoder
from models.gpt_voice.unet_diffusion_tts7 import CheckpointedXTransformerEncoder
from models.audio.tts.mini_encoder import AudioMiniEncoder
from models.audio.tts.unet_diffusion_tts7 import CheckpointedXTransformerEncoder
from scripts.audio.gen.use_diffuse_tts import ceil_multiple
from trainer.networks import register_model
from utils.util import checkpoint, opt_get
from utils.util import checkpoint


def is_latent(t):
@@ -1,13 +1,11 @@
from models.diffusion.fp16_util import convert_module_to_f32, convert_module_to_f16
from models.diffusion.nn import timestep_embedding, normalization, zero_module, conv_nd, linear
from models.diffusion.unet_diffusion import AttentionPool2d, AttentionBlock, ResBlock, TimestepEmbedSequential, \
from models.diffusion.unet_diffusion import AttentionBlock, ResBlock, TimestepEmbedSequential, \
    Downsample, Upsample
import torch
import torch.nn as nn

from models.gpt_voice.mini_encoder import AudioMiniEncoder, EmbeddingCombiner
from models.audio.tts.mini_encoder import AudioMiniEncoder
from trainer.networks import register_model
from utils.util import get_mask_from_lengths


class DiscreteSpectrogramConditioningBlock(nn.Module):
@@ -1,15 +1,12 @@
import functools

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import GPT2Model, GPT2Config, GPT2PreTrainedModel
from transformers import GPT2Config, GPT2PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
from transformers.utils.model_parallel_utils import get_device_map, assert_device_map

from models.arch_util import AttentionBlock
from models.gpt_voice.transformer_builders import build_hf_gpt_transformer
from models.tacotron2.text import symbols
from models.audio.tts.transformer_builders import build_hf_gpt_transformer
from trainer.networks import register_model
from utils.util import opt_get
@@ -2,11 +2,9 @@ import random
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from torch import einsum

from models.gpt_voice.mini_encoder import AudioMiniEncoder
from models.lucidrains.dalle.transformer import Transformer
from models.audio.tts.mini_encoder import AudioMiniEncoder
from trainer.injectors.spec_augment import spec_augment
from trainer.networks import register_model
from utils.util import opt_get
@@ -5,7 +5,7 @@ import torch.nn as nn
import torch.nn.functional as F
from x_transformers import Encoder, Decoder, ContinuousTransformerWrapper

from models.gpt_voice.mini_encoder import AudioMiniEncoder
from models.audio.tts.mini_encoder import AudioMiniEncoder
from trainer.networks import register_model

@@ -135,6 +135,31 @@ class Wav2VecMatcher(nn.Module):

        return ce_loss, mse_loss

    def find_matching_w2v_logit(self, key, w2v_logit_iterable):
        pass

    def sample(self, text_tokens, conditioning_clip, w2v_logit_iterable, audio_clip_iterable):
        text_emb = self.text_embedding(text_tokens)
        cond_emb = self.conditioning_encoder(conditioning_clip)
        enc_inputs = torch.cat([cond_emb.unsqueeze(1), text_emb], dim=1)
        dec_context = self.encoder(enc_inputs)
        dec_inputs = self.decoder_start_embedding
        count = 0
        while count < 400:
            dec_out = self.decoder(dec_inputs, context=dec_context)

            # Check if that was EOS.
            l2 = F.mse_loss(dec_out[:,-1], self.decoder_stop_embedding)
            if l2 < .1:  # TODO: fix threshold.
                break

            # Find a matching w2v logit from the given iterable.
            matching_logit_index = self.find_matching_w2v_logit(dec_out[:,-1], w2v_logit_iterable)
            matching_logit = w2v_logit_iterable[matching_logit_index]
            dec_inputs = torch.cat([dec_inputs, self.w2v_value_encoder(matching_logit).unsqueeze(1)], dim=1)
            produced_audio = torch.cat([produced_audio, audio_clip_iterable[matching_logit_index]], dim=-1)
        return produced_audio


@register_model
def register_w2v_matcher(opt_net, opt):
@@ -1,6 +1,6 @@
import sys

from models.tacotron2.stft import STFT
from models.audio.tts.tacotron2.stft import STFT

sys.path.append('tacotron2')
import torch
@@ -1,10 +1,9 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from torch import einsum

from models.gpt_voice.unified_voice2 import ConditioningEncoder
from models.audio.tts.unified_voice2 import ConditioningEncoder
from models.lucidrains.dalle.transformer import Transformer
from trainer.networks import register_model
from utils.util import opt_get
@@ -1,137 +0,0 @@
import math

import torch
import torch.nn.functional as F
from torch import nn
from torch import istft

from .unet import UNet
from .util import tf2pytorch


def load_ckpt(model, ckpt):
    state_dict = model.state_dict()
    for k, v in ckpt.items():
        if k in state_dict:
            target_shape = state_dict[k].shape
            assert target_shape == v.shape
            state_dict.update({k: torch.from_numpy(v)})
        else:
            print('Ignore ', k)

    model.load_state_dict(state_dict)
    return model


def pad_and_partition(tensor, T):
    """
    pads zero and partition tensor into segments of length T

    Args:
        tensor(Tensor): BxCxFxL

    Returns:
        tensor of size (B*[L/T] x C x F x T)
    """
    old_size = tensor.size(3)
    new_size = math.ceil(old_size/T) * T
    tensor = F.pad(tensor, [0, new_size - old_size])
    [b, c, t, f] = tensor.shape
    split = new_size // T
    return torch.cat(torch.split(tensor, T, dim=3), dim=0)
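(Editor's note: a standalone sanity check of the pad_and_partition helper above, using made-up sizes consistent with its docstring.)

    import math
    import torch
    import torch.nn.functional as F

    def pad_and_partition(tensor, T):
        # local copy of the helper above, so this snippet runs on its own
        old_size = tensor.size(3)
        new_size = math.ceil(old_size / T) * T
        tensor = F.pad(tensor, [0, new_size - old_size])
        return torch.cat(torch.split(tensor, T, dim=3), dim=0)

    x = torch.randn(1, 2, 1024, 1200)     # B x C x F x L
    y = pad_and_partition(x, 512)         # L=1200 is zero-padded to 1536, then split into 3 segments
    assert y.shape == (3, 2, 1024, 512)   # (B * ceil(L/T)) x C x F x T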
class Estimator(nn.Module):
    def __init__(self, num_instrumments, checkpoint_path):
        super(Estimator, self).__init__()

        # stft config
        self.F = 1024
        self.T = 512
        self.win_length = 4096
        self.hop_length = 1024
        self.win = torch.hann_window(self.win_length)

        ckpts = tf2pytorch(checkpoint_path, num_instrumments)

        # filter
        self.instruments = nn.ModuleList()
        for i in range(num_instrumments):
            print('Loading model for instrumment {}'.format(i))
            net = UNet(2)
            ckpt = ckpts[i]
            net = load_ckpt(net, ckpt)
            net.eval()  # change mode to eval
            self.instruments.append(net)

    def compute_stft(self, wav):
        """
        Computes stft feature from wav

        Args:
            wav (Tensor): B x L
        """
        stft = torch.stft(
            wav, self.win_length, hop_length=self.hop_length, window=self.win.to(wav.device))

        # only keep freqs smaller than self.F
        stft = stft[:, :self.F, :, :]
        real = stft[:, :, :, 0]
        im = stft[:, :, :, 1]
        mag = torch.sqrt(real ** 2 + im ** 2)

        return stft, mag

    def inverse_stft(self, stft):
        """Inverses stft to wave form"""
        pad = self.win_length // 2 + 1 - stft.size(1)
        stft = F.pad(stft, (0, 0, 0, 0, 0, pad))
        wav = istft(stft, self.win_length, hop_length=self.hop_length,
                    window=self.win.to(stft.device))
        return wav.detach()

    def separate(self, wav):
        """
        Separates stereo wav into different tracks corresponding to different instruments

        Args:
            wav (tensor): B x L
        """
        # stft - B X F x L x 2
        # stft_mag - B X F x L
        stft, stft_mag = self.compute_stft(wav)

        L = stft.size(2)

        stft_mag = stft_mag.unsqueeze(1).repeat(1,2,1,1)  # B x 2 x F x T
        stft_mag = pad_and_partition(stft_mag, self.T)  # B x 2 x F x T
        stft_mag = stft_mag.transpose(2, 3)  # B x 2 x T x F

        # compute instruments' mask
        masks = []
        for net in self.instruments:
            mask = net(stft_mag)
            masks.append(mask)

        # compute denominator
        mask_sum = sum([m ** 2 for m in masks])
        mask_sum += 1e-10

        wavs = []
        for mask in masks:
            mask = (mask ** 2 + 1e-10/2)/(mask_sum)
            mask = mask.transpose(2, 3)  # B x 2 X F x T

            mask = torch.cat(
                torch.split(mask, 1, dim=0), dim=3)

            mask = mask[:,0,:,:L].unsqueeze(-1)  # 2 x F x L x 1
            stft_masked = stft * mask
            wav_masked = self.inverse_stft(stft_masked)

            wavs.append(wav_masked)

        return wavs
@@ -1,32 +0,0 @@
import torch
import torch.nn.functional as F

from models.spleeter.estimator import Estimator


class Separator:
    def __init__(self, model_path, input_sr=44100, device='cuda'):
        self.model = Estimator(2, model_path).to(device)
        self.device = device
        self.input_sr = input_sr

    def separate(self, npwav, normalize=False):
        if not isinstance(npwav, torch.Tensor):
            assert len(npwav.shape) == 1
            wav = torch.tensor(npwav, device=self.device)
            wav = wav.view(1,-1)
        else:
            assert len(npwav.shape) == 2  # Input should be BxL
            wav = npwav.to(self.device)

        if normalize:
            wav = wav / (wav.max() + 1e-8)

        # Spleeter expects audio input to be 44.1kHz.
        wav = F.interpolate(wav.unsqueeze(1), mode='nearest', scale_factor=44100/self.input_sr).squeeze(1)
        res = self.model.separate(wav)
        res = [F.interpolate(r.unsqueeze(1), mode='nearest', scale_factor=self.input_sr/44100)[:,0] for r in res]
        return {
            'vocals': res[0].cpu().numpy(),
            'accompaniment': res[1].cpu().numpy()
        }
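(Editor's note: for context on what is being removed, this wrapper was used roughly as follows. The checkpoint path and sample rate are placeholders; the import path is the pre-refactor location this commit deletes.)

    import numpy as np
    from models.spleeter.separator import Separator  # module removed by this commit

    sep = Separator('pretrained/spleeter_2stems', input_sr=22050, device='cuda')  # placeholder checkpoint
    wav = np.random.randn(22050 * 5).astype(np.float32)   # 5 seconds of mono audio at 22.05kHz
    stems = sep.separate(wav, normalize=True)
    vocals = stems['vocals']                 # numpy arrays, resampled back to input_sr
    accompaniment = stems['accompaniment']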
@@ -1,80 +0,0 @@
import torch
from torch import nn


def down_block(in_filters, out_filters):
    return nn.Conv2d(in_filters, out_filters, kernel_size=5,
                     stride=2, padding=2,
                     ), nn.Sequential(
        nn.BatchNorm2d(out_filters, track_running_stats=True, eps=1e-3, momentum=0.01),
        nn.LeakyReLU(0.2)
    )


def up_block(in_filters, out_filters, dropout=False):
    layers = [
        nn.ConvTranspose2d(in_filters, out_filters, kernel_size=5,
                           stride=2, padding=2, output_padding=1
                           ),
        nn.ReLU(),
        nn.BatchNorm2d(out_filters, track_running_stats=True, eps=1e-3, momentum=0.01)
    ]
    if dropout:
        layers.append(nn.Dropout(0.5))

    return nn.Sequential(*layers)


class UNet(nn.Module):
    def __init__(self, in_channels=2):
        super(UNet, self).__init__()
        self.down1_conv, self.down1_act = down_block(in_channels, 16)
        self.down2_conv, self.down2_act = down_block(16, 32)
        self.down3_conv, self.down3_act = down_block(32, 64)
        self.down4_conv, self.down4_act = down_block(64, 128)
        self.down5_conv, self.down5_act = down_block(128, 256)
        self.down6_conv, self.down6_act = down_block(256, 512)

        self.up1 = up_block(512, 256, dropout=True)
        self.up2 = up_block(512, 128, dropout=True)
        self.up3 = up_block(256, 64, dropout=True)
        self.up4 = up_block(128, 32)
        self.up5 = up_block(64, 16)
        self.up6 = up_block(32, 1)
        self.up7 = nn.Sequential(
            nn.Conv2d(1, 2, kernel_size=4, dilation=2, padding=3),
            nn.Sigmoid()
        )

    def forward(self, x):
        d1_conv = self.down1_conv(x)
        d1 = self.down1_act(d1_conv)

        d2_conv = self.down2_conv(d1)
        d2 = self.down2_act(d2_conv)

        d3_conv = self.down3_conv(d2)
        d3 = self.down3_act(d3_conv)

        d4_conv = self.down4_conv(d3)
        d4 = self.down4_act(d4_conv)

        d5_conv = self.down5_conv(d4)
        d5 = self.down5_act(d5_conv)

        d6_conv = self.down6_conv(d5)
        d6 = self.down6_act(d6_conv)

        u1 = self.up1(d6)
        u2 = self.up2(torch.cat([d5_conv, u1], axis=1))
        u3 = self.up3(torch.cat([d4_conv, u2], axis=1))
        u4 = self.up4(torch.cat([d3_conv, u3], axis=1))
        u5 = self.up5(torch.cat([d2_conv, u4], axis=1))
        u6 = self.up6(torch.cat([d1_conv, u5], axis=1))
        u7 = self.up7(u6)
        return u7 * x


if __name__ == '__main__':
    net = UNet(14)
    print(net(torch.rand(1, 14, 20, 48)).shape)
@@ -1,91 +0,0 @@
import numpy as np
import tensorflow as tf

from .unet import UNet


def tf2pytorch(checkpoint_path, num_instrumments):
    tf_vars = {}
    init_vars = tf.train.list_variables(checkpoint_path)
    # print(init_vars)
    for name, shape in init_vars:
        try:
            # print('Loading TF Weight {} with shape {}'.format(name, shape))
            data = tf.train.load_variable(checkpoint_path, name)
            tf_vars[name] = data
        except Exception as e:
            print('Load error')
    conv_idx = 0
    tconv_idx = 0
    bn_idx = 0
    outputs = []
    for i in range(num_instrumments):
        output = {}
        outputs.append(output)

        for j in range(1,7):
            if conv_idx == 0:
                conv_suffix = ""
            else:
                conv_suffix = "_" + str(conv_idx)

            if bn_idx == 0:
                bn_suffix = ""
            else:
                bn_suffix = "_" + str(bn_idx)

            output['down{}_conv.weight'.format(j)] = np.transpose(
                tf_vars["conv2d{}/kernel".format(conv_suffix)], (3, 2, 0, 1))
            # print('conv dtype: ',output['down{}.0.weight'.format(j)].dtype)
            output['down{}_conv.bias'.format(
                j)] = tf_vars["conv2d{}/bias".format(conv_suffix)]

            output['down{}_act.0.weight'.format(
                j)] = tf_vars["batch_normalization{}/gamma".format(bn_suffix)]
            output['down{}_act.0.bias'.format(
                j)] = tf_vars["batch_normalization{}/beta".format(bn_suffix)]
            output['down{}_act.0.running_mean'.format(
                j)] = tf_vars['batch_normalization{}/moving_mean'.format(bn_suffix)]
            output['down{}_act.0.running_var'.format(
                j)] = tf_vars['batch_normalization{}/moving_variance'.format(bn_suffix)]

            conv_idx += 1
            bn_idx += 1

        # up blocks
        for j in range(1, 7):
            if tconv_idx == 0:
                tconv_suffix = ""
            else:
                tconv_suffix = "_" + str(tconv_idx)

            if bn_idx == 0:
                bn_suffix = ""
            else:
                bn_suffix= "_" + str(bn_idx)

            output['up{}.0.weight'.format(j)] = np.transpose(
                tf_vars["conv2d_transpose{}/kernel".format(tconv_suffix)], (3,2,0, 1))
            output['up{}.0.bias'.format(
                j)] = tf_vars["conv2d_transpose{}/bias".format(tconv_suffix)]
            output['up{}.2.weight'.format(
                j)] = tf_vars["batch_normalization{}/gamma".format(bn_suffix)]
            output['up{}.2.bias'.format(
                j)] = tf_vars["batch_normalization{}/beta".format(bn_suffix)]
            output['up{}.2.running_mean'.format(
                j)] = tf_vars['batch_normalization{}/moving_mean'.format(bn_suffix)]
            output['up{}.2.running_var'.format(
                j)] = tf_vars['batch_normalization{}/moving_variance'.format(bn_suffix)]
            tconv_idx += 1
            bn_idx += 1

        if conv_idx == 0:
            suffix = ""
        else:
            suffix = "_" + str(conv_idx)
        output['up7.0.weight'] = np.transpose(
            tf_vars['conv2d{}/kernel'.format(suffix)], (3, 2, 0, 1))
        output['up7.0.bias'] = tf_vars['conv2d{}/bias'.format(suffix)]
        conv_idx += 1

    return outputs
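(Editor's note: the np.transpose(..., (3, 2, 0, 1)) calls above convert TensorFlow's HWIO convolution-kernel layout into PyTorch's OIHW layout before the weights are loaded by load_ckpt. A standalone illustration with arbitrary sizes:)

    import numpy as np
    import torch
    import torch.nn as nn

    tf_kernel = np.random.randn(5, 5, 16, 32).astype(np.float32)          # TF layout: H x W x C_in x C_out
    pt_weight = torch.from_numpy(np.transpose(tf_kernel, (3, 2, 0, 1)))   # -> C_out x C_in x H x W

    conv = nn.Conv2d(16, 32, kernel_size=5, padding=2)
    assert conv.weight.shape == pt_weight.shape   # shape-compatible with the state_dict that load_ckpt builds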
@@ -4,20 +4,17 @@ import random
import argparse

import audio2numpy
import torchvision
from munch import munchify

import utils
import utils.options as option
import utils.util as util
from data.audio.nv_tacotron_dataset import save_mel_buffer_to_file
from models.tacotron2 import hparams
from models.tacotron2.layers import TacotronSTFT
from models.tacotron2.text import sequence_to_text
from models.audio.tts.tacotron2 import hparams
from models.audio.tts.tacotron2 import TacotronSTFT
from models.audio.tts.tacotron2 import sequence_to_text
from scripts.audio.use_vocoder import Vocoder
from trainer.ExtensibleTrainer import ExtensibleTrainer
from data import create_dataset, create_dataloader
from tqdm import tqdm
import torch
import numpy as np
from scipy.io import wavfile
@@ -1,12 +1,11 @@
import os

import torch
import torch.nn as nn
import torch.nn.functional as F

from data.util import is_wav_file, find_files_of_type
from models.audio_resnet import resnet34, resnet50
from models.tacotron2.taco_utils import load_wav_to_torch
from models.audio_resnet import resnet50
from models.audio.tts.tacotron2 import load_wav_to_torch
from scripts.byol.byol_extract_wrapped_model import extract_byol_model_from_state_dict

if __name__ == '__main__':
@@ -3,12 +3,10 @@ import logging
import random
import argparse

import torchvision

import utils
import utils.options as option
import utils.util as util
from models.tacotron2.text import sequence_to_text
from models.audio.tts.tacotron2 import sequence_to_text
from trainer.ExtensibleTrainer import ExtensibleTrainer
from data import create_dataset, create_dataloader
from tqdm import tqdm
@@ -2,11 +2,8 @@ import os
import os.path as osp
import logging
import random
import time
import argparse
from collections import OrderedDict

import numpy
from PIL import Image
from scipy.io import wavfile
from torchvision.transforms import ToTensor

@@ -15,10 +12,7 @@ import utils
import utils.options as option
import utils.util as util
from data.audio.unsupervised_audio_dataset import load_audio
from models.tacotron2.taco_utils import load_wav_to_torch
from trainer.ExtensibleTrainer import ExtensibleTrainer
from data import create_dataset, create_dataloader
from tqdm import tqdm
import torch
import numpy as np
@@ -1,10 +1,3 @@
import json

import torch

from models.asr.w2v_wrapper import Wav2VecWrapper
from models.tacotron2.text import tacotron_symbol_mapping

if __name__ == '__main__':
    """
    Utility script for uploading model weights to the HF hub
@@ -2,12 +2,10 @@ import os
import os.path as osp
import torch
import torchaudio
import torchvision
from pytorch_fid import fid_score
from pytorch_fid.fid_score import calculate_frechet_distance
from torch import distributed
from tqdm import tqdm
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
from transformers import Wav2Vec2ForCTC
import torch.nn.functional as F
import numpy as np

@@ -15,7 +13,7 @@ import trainer.eval.evaluator as evaluator
from data.audio.paired_voice_audio_dataset import load_tsv_aligned_codes
from data.audio.unsupervised_audio_dataset import load_audio
from models.clip.mel_text_clip import MelTextCLIP
from models.tacotron2.text import sequence_to_text, text_to_sequence
from models.audio.tts.tacotron2 import text_to_sequence
from scripts.audio.gen.speech_synthesis_utils import load_discrete_vocoder_diffuser, wav_to_mel, load_speech_dvae, \
    convert_mel_to_codes
from utils.util import ceil_multiple, opt_get

@@ -116,6 +114,27 @@ class AudioDiffusionFid(evaluator.Evaluator):
                                                        'unaligned_input': torch.tensor(text_codes, device=audio.device).unsqueeze(0)})
        return gen, real_resampled, sample_rate

    def perform_diffusion_tts9_from_codes(self, audio, codes, text, sample_rate=5500):
        mel = wav_to_mel(audio)
        mel_codes = convert_mel_to_codes(self.dvae, mel)
        text_codes = text_to_sequence(text)
        real_resampled = torchaudio.functional.resample(audio, 22050, sample_rate).unsqueeze(0)

        output_size = real_resampled.shape[-1]
        aligned_codes_compression_factor = output_size // mel_codes.shape[-1]
        padded_size = ceil_multiple(output_size, 2048)
        padding_added = padded_size - output_size
        padding_needed_for_codes = padding_added // aligned_codes_compression_factor
        if padding_needed_for_codes > 0:
            mel_codes = F.pad(mel_codes, (0, padding_needed_for_codes))
        output_shape = (1, 1, padded_size)
        gen = self.diffuser.p_sample_loop(self.model, output_shape,
                                          model_kwargs={'tokens': mel_codes,
                                                        'conditioning_input': audio.unsqueeze(0),
                                                        'unaligned_input': torch.tensor(text_codes, device=audio.device).unsqueeze(0)})
        return gen, real_resampled, sample_rate
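(Editor's note: the padding arithmetic above keeps the dVAE mel codes aligned with the waveform after it is padded up to a multiple of 2048 samples. A worked example with illustrative sizes; only ceil_multiple and the 2048 constant come from the code above.)

    from utils.util import ceil_multiple  # same helper the evaluator imports

    output_size = 120000                   # resampled waveform length (illustrative)
    mel_code_len = 469                     # dVAE code sequence length (illustrative)

    compression_factor = output_size // mel_code_len                  # 255 waveform samples per code
    padded_size = ceil_multiple(output_size, 2048)                    # 120832
    padding_added = padded_size - output_size                         # 832
    padding_needed_for_codes = padding_added // compression_factor    # 3 extra codes appended to mel_codes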
    def load_projector(self):
        """
        Builds the CLIP model used to project speech into a latent. This model has fixed parameters and a fixed loading

@@ -4,12 +4,12 @@ from datasets import load_metric

import torch
from tqdm import tqdm
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
from transformers import Wav2Vec2Processor

import trainer.eval.evaluator as evaluator
from data import create_dataset, create_dataloader
from models.asr.w2v_wrapper import only_letters, Wav2VecWrapper
from models.tacotron2.text import sequence_to_text, tacotron_symbols
from models.audio.asr.w2v_wrapper import only_letters, Wav2VecWrapper
from models.audio.tts.tacotron2 import sequence_to_text, tacotron_symbols
from pyctcdecode import build_ctcdecoder

# Librispeech:
@@ -4,7 +4,7 @@ import trainer.eval.evaluator as evaluator

from data import create_dataset
from data.audio.nv_tacotron_dataset import TextMelCollate
from models.tacotron2.loss import Tacotron2LossRaw
from models.audio.tts.tacotron2 import Tacotron2LossRaw
from torch.utils.data import DataLoader
from tqdm import tqdm
@@ -1,7 +1,6 @@
import random

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio

@@ -12,7 +11,7 @@ from utils.util import opt_get, load_model_from_config
class MelSpectrogramInjector(Injector):
    def __init__(self, opt, env):
        super().__init__(opt, env)
        from models.tacotron2.layers import TacotronSTFT
        from models.audio.tts.tacotron2 import TacotronSTFT
        # These are the default tacotron values for the MEL spectrogram.
        filter_length = opt_get(opt, ['filter_length'], 1024)
        hop_length = opt_get(opt, ['hop_length'], 256)

@@ -83,7 +83,7 @@ class CombineMelInjector(Injector):
        self.text_lengths = opt['text_lengths_key']
        self.output_audio_key = opt['output_audio_key']
        self.output_text_key = opt['output_text_key']
        from models.tacotron2.text import symbols
        from models.audio.tts.tacotron2 import symbols
        self.text_separator = len(symbols)+1  # Probably need to allow this to be set by user.

    def forward(self, state):

@@ -59,7 +59,7 @@ def create_loss(opt_loss, env):
        from models.switched_conv.mixture_of_experts import SwitchTransformersLoadBalancingLoss
        return SwitchTransformersLoadBalancingLoss(opt_loss, env)
    elif type == 'nv_tacotron2_loss':
        from models.tacotron2.loss import Tacotron2Loss
        from models.audio.tts.tacotron2 import Tacotron2Loss
        return Tacotron2Loss(opt_loss, env)
    else:
        raise NotImplementedError