forked from mrq/DL-Art-School
Update SR model
This commit is contained in:
parent
de1a1d501a
commit
7f4fc55344
|
@ -182,8 +182,7 @@ class DiffusionTts(nn.Module):
|
||||||
mid_transformer_depth=8,
|
mid_transformer_depth=8,
|
||||||
nil_guidance_fwd_proportion=.3,
|
nil_guidance_fwd_proportion=.3,
|
||||||
super_sampling=False,
|
super_sampling=False,
|
||||||
max_positions=-1,
|
super_sampling_max_noising_factor=.1,
|
||||||
fully_disable_tokens_percent=0, # When specified, this percent of the time tokens are entirely ignored.
|
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
|
@ -207,8 +206,7 @@ class DiffusionTts(nn.Module):
|
||||||
self.nil_guidance_fwd_proportion = nil_guidance_fwd_proportion
|
self.nil_guidance_fwd_proportion = nil_guidance_fwd_proportion
|
||||||
self.mask_token_id = num_tokens
|
self.mask_token_id = num_tokens
|
||||||
self.super_sampling_enabled = super_sampling
|
self.super_sampling_enabled = super_sampling
|
||||||
self.max_positions = max_positions
|
self.super_sampling_max_noising_factor = super_sampling_max_noising_factor
|
||||||
self.fully_disable_tokens_percent = fully_disable_tokens_percent
|
|
||||||
padding = 1 if kernel_size == 3 else 2
|
padding = 1 if kernel_size == 3 else 2
|
||||||
|
|
||||||
time_embed_dim = model_channels * time_embed_dim_multiplier
|
time_embed_dim = model_channels * time_embed_dim_multiplier
|
||||||
|
@ -398,17 +396,12 @@ class DiffusionTts(nn.Module):
|
||||||
assert conditioning_input is not None
|
assert conditioning_input is not None
|
||||||
if self.super_sampling_enabled:
|
if self.super_sampling_enabled:
|
||||||
assert lr_input is not None
|
assert lr_input is not None
|
||||||
|
if self.super_sampling_max_noising_factor > 0:
|
||||||
|
noising_factor = random.uniform(0,self.super_sampling_max_noising_factor)
|
||||||
|
lr_input = torch.rand_like(lr_input) * noising_factor + lr_input
|
||||||
lr_input = F.interpolate(lr_input, size=(x.shape[-1],), mode='nearest')
|
lr_input = F.interpolate(lr_input, size=(x.shape[-1],), mode='nearest')
|
||||||
x = torch.cat([x, lr_input], dim=1)
|
x = torch.cat([x, lr_input], dim=1)
|
||||||
|
|
||||||
if tokens is not None and self.fully_disable_tokens_percent > random.random():
|
|
||||||
tokens = None
|
|
||||||
|
|
||||||
if tokens is not None and self.max_positions > 0 and x.shape[-1] > self.max_positions:
|
|
||||||
proportion_x_removed = self.max_positions/x.shape[-1]
|
|
||||||
x = x[:,:,:self.max_positions] # TODO: extract random subsets of x (favored towards the front). This should help diversity in training.
|
|
||||||
tokens = tokens[:,:int(proportion_x_removed*tokens.shape[-1])]
|
|
||||||
|
|
||||||
with autocast(x.device.type):
|
with autocast(x.device.type):
|
||||||
orig_x_shape = x.shape[-1]
|
orig_x_shape = x.shape[-1]
|
||||||
cm = ceil_multiple(x.shape[-1], 2048)
|
cm = ceil_multiple(x.shape[-1], 2048)
|
||||||
|
|
|
@ -7,7 +7,7 @@ from data.audio.unsupervised_audio_dataset import load_audio
|
||||||
from data.util import find_files_of_type, is_audio_file
|
from data.util import find_files_of_type, is_audio_file
|
||||||
from models.diffusion.gaussian_diffusion import get_named_beta_schedule
|
from models.diffusion.gaussian_diffusion import get_named_beta_schedule
|
||||||
from models.diffusion.respace import SpacedDiffusion, space_timesteps
|
from models.diffusion.respace import SpacedDiffusion, space_timesteps
|
||||||
from trainer.injectors.base_injectors import TorchMelSpectrogramInjector
|
from trainer.injectors.audio_injectors import TorchMelSpectrogramInjector
|
||||||
from utils.audio import plot_spectrogram
|
from utils.audio import plot_spectrogram
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -299,7 +299,7 @@ class Trainer:
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_encoder_build_ctc_alignments.yml')
|
parser.add_argument('-opt', type=str, help='Path to option YAML file.', default='../options/train_diffusion_tts6_upsample.yml')
|
||||||
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
|
parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', help='job launcher')
|
||||||
parser.add_argument('--local_rank', type=int, default=0)
|
parser.add_argument('--local_rank', type=int, default=0)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
Loading…
Reference in New Issue
Block a user