Update SR model
This commit is contained in:
parent
de1a1d501a
commit
7f4fc55344
|
@ -182,8 +182,7 @@ class DiffusionTts(nn.Module):
|
|||
mid_transformer_depth=8,
|
||||
nil_guidance_fwd_proportion=.3,
|
||||
super_sampling=False,
|
||||
max_positions=-1,
|
||||
fully_disable_tokens_percent=0, # When specified, this percent of the time tokens are entirely ignored.
|
||||
super_sampling_max_noising_factor=.1,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
|
@ -207,8 +206,7 @@ class DiffusionTts(nn.Module):
|
|||
self.nil_guidance_fwd_proportion = nil_guidance_fwd_proportion
|
||||
self.mask_token_id = num_tokens
|
||||
self.super_sampling_enabled = super_sampling
|
||||
self.max_positions = max_positions
|
||||
self.fully_disable_tokens_percent = fully_disable_tokens_percent
|
||||
self.super_sampling_max_noising_factor = super_sampling_max_noising_factor
|
||||
padding = 1 if kernel_size == 3 else 2
|
||||
|
||||
time_embed_dim = model_channels * time_embed_dim_multiplier
|
||||
|
@ -398,17 +396,12 @@ class DiffusionTts(nn.Module):
|
|||
assert conditioning_input is not None
|
||||
if self.super_sampling_enabled:
|
||||
assert lr_input is not None
|
||||
if self.super_sampling_max_noising_factor > 0:
|
||||
noising_factor = random.uniform(0,self.super_sampling_max_noising_factor)
|
||||
lr_input = torch.rand_like(lr_input) * noising_factor + lr_input
|
||||
lr_input = F.interpolate(lr_input, size=(x.shape[-1],), mode='nearest')
|
||||
x = torch.cat([x, lr_input], dim=1)
|
||||
|
||||
if tokens is not None and self.fully_disable_tokens_percent > random.random():
|
||||
tokens = None
|
||||
|
||||
if tokens is not None and self.max_positions > 0 and x.shape[-1] > self.max_positions:
|
||||
proportion_x_removed = self.max_positions/x.shape[-1]
|
||||
x = x[:,:,:self.max_positions] # TODO: extract random subsets of x (favored towards the front). This should help diversity in training.
|
||||
tokens = tokens[:,:int(proportion_x_removed*tokens.shape[-1])]
|
||||
|
||||
with autocast(x.device.type):
|
||||
orig_x_shape = x.shape[-1]
|
||||
cm = ceil_multiple(x.shape[-1], 2048)
|
||||
|
|
|
@ -7,7 +7,7 @@ from data.audio.unsupervised_audio_dataset import load_audio
|
|||
from data.util import find_files_of_type, is_audio_file
|
||||
from models.diffusion.gaussian_diffusion import get_named_beta_schedule
|
||||
from models.diffusion.respace import SpacedDiffusion, space_timesteps
|
||||
from trainer.injectors.base_injectors import TorchMelSpectrogramInjector
|
||||
from trainer.injectors.audio_injectors import TorchMelSpectrogramInjector
|
||||
from utils.audio import plot_spectrogram
|
||||
|
||||
|
||||
|
|
|
@ -299,7 +299,7 @@ class Trainer:
|
|||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_encoder_build_ctc_alignments.yml')
|
||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_diffusion_tts6_upsample.yml')
|
||||
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
|
||||
parser.add_argument('--local_rank', type=int, default=0)
|
||||
args = parser.parse_args()
|
||||
|
|
Loading…
Reference in New Issue
Block a user