Update SR model

This commit is contained in:
James Betker 2022-02-03 21:42:53 -07:00
parent de1a1d501a
commit 7f4fc55344
3 changed files with 7 additions and 14 deletions

View File

@ -182,8 +182,7 @@ class DiffusionTts(nn.Module):
mid_transformer_depth=8,
nil_guidance_fwd_proportion=.3,
super_sampling=False,
max_positions=-1,
fully_disable_tokens_percent=0, # When specified, this percent of the time tokens are entirely ignored.
super_sampling_max_noising_factor=.1,
):
super().__init__()
@ -207,8 +206,7 @@ class DiffusionTts(nn.Module):
self.nil_guidance_fwd_proportion = nil_guidance_fwd_proportion
self.mask_token_id = num_tokens
self.super_sampling_enabled = super_sampling
self.max_positions = max_positions
self.fully_disable_tokens_percent = fully_disable_tokens_percent
self.super_sampling_max_noising_factor = super_sampling_max_noising_factor
padding = 1 if kernel_size == 3 else 2
time_embed_dim = model_channels * time_embed_dim_multiplier
@ -398,17 +396,12 @@ class DiffusionTts(nn.Module):
assert conditioning_input is not None
if self.super_sampling_enabled:
assert lr_input is not None
if self.super_sampling_max_noising_factor > 0:
noising_factor = random.uniform(0,self.super_sampling_max_noising_factor)
lr_input = torch.rand_like(lr_input) * noising_factor + lr_input
lr_input = F.interpolate(lr_input, size=(x.shape[-1],), mode='nearest')
x = torch.cat([x, lr_input], dim=1)
if tokens is not None and self.fully_disable_tokens_percent > random.random():
tokens = None
if tokens is not None and self.max_positions > 0 and x.shape[-1] > self.max_positions:
proportion_x_removed = self.max_positions/x.shape[-1]
x = x[:,:,:self.max_positions] # TODO: extract random subsets of x (favored towards the front). This should help diversity in training.
tokens = tokens[:,:int(proportion_x_removed*tokens.shape[-1])]
with autocast(x.device.type):
orig_x_shape = x.shape[-1]
cm = ceil_multiple(x.shape[-1], 2048)

View File

@ -7,7 +7,7 @@ from data.audio.unsupervised_audio_dataset import load_audio
from data.util import find_files_of_type, is_audio_file
from models.diffusion.gaussian_diffusion import get_named_beta_schedule
from models.diffusion.respace import SpacedDiffusion, space_timesteps
from trainer.injectors.base_injectors import TorchMelSpectrogramInjector
from trainer.injectors.audio_injectors import TorchMelSpectrogramInjector
from utils.audio import plot_spectrogram

View File

@ -299,7 +299,7 @@ class Trainer:
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_encoder_build_ctc_alignments.yml')
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_diffusion_tts6_upsample.yml')
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
parser.add_argument('--local_rank', type=int, default=0)
args = parser.parse_args()