forked from mrq/DL-Art-School

more cleanup

commit a5d2123daa (parent fef1066687)
@@ -146,8 +146,6 @@ class DiffusionWaveformGen(nn.Module):
     :param num_res_blocks: number of residual blocks per downsample.
     :param dropout: the dropout probability.
     :param channel_mult: channel multiplier for each level of the UNet.
-    :param conv_resample: if True, use learned convolutions for upsampling and
-        downsampling.
     :param dims: determines if the signal is 1D, 2D, or 3D.
     :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
     :param resblock_updown: use residual blocks for up/downsampling.
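Note: the `use_scale_shift_norm` docstring entry above refers to FiLM-style conditioning, where the timestep embedding is projected to a per-channel scale and shift that modulate normalized activations. A minimal sketch of the idea (the class and names below are illustrative, not this repo's actual ResBlock internals):

import torch.nn as nn

class ScaleShiftNorm(nn.Module):
    """Sketch of FiLM-style 'scale-shift norm' conditioning.

    Assumes `channels` is divisible by 32 (for GroupNorm) and that `emb`
    is a per-batch timestep embedding of size `emb_dim`.
    """
    def __init__(self, channels, emb_dim):
        super().__init__()
        self.norm = nn.GroupNorm(32, channels)
        self.emb_proj = nn.Linear(emb_dim, 2 * channels)

    def forward(self, x, emb):
        # x: (batch, channels, length); emb: (batch, emb_dim)
        scale, shift = self.emb_proj(emb).chunk(2, dim=1)
        # Broadcast the (batch, channels) modulation over the length axis.
        return self.norm(x) * (1 + scale.unsqueeze(-1)) + shift.unsqueeze(-1)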
@@ -165,39 +163,24 @@ class DiffusionWaveformGen(nn.Module):
             num_res_blocks=(1,1,0),
             token_conditioning_resolutions=(1,4),
             mid_resnet_depth=10,
-            conv_resample=True,
             dims=1,
             use_fp16=False,
-            kernel_size=3,
-            scale_factor=2,
             time_embed_dim_multiplier=1,
-            freeze_main_net=False,
-            use_scale_shift_norm=True,
             # Parameters for regularization.
             unconditioned_percentage=.1, # This implements a mechanism similar to what is used in classifier-free training.
-            # Parameters for super-sampling.
-            super_sampling=False,
-            super_sampling_max_noising_factor=.1,
     ):
         super().__init__()
 
-        if super_sampling:
-            in_channels *= 2 # In super-sampling mode, the LR input is concatenated directly onto the input.
         self.in_channels = in_channels
         self.model_channels = model_channels
         self.out_channels = out_channels
         self.dropout = dropout
         self.channel_mult = channel_mult
-        self.conv_resample = conv_resample
         self.dims = dims
-        self.super_sampling_enabled = super_sampling
-        self.super_sampling_max_noising_factor = super_sampling_max_noising_factor
         self.unconditioned_percentage = unconditioned_percentage
         self.enable_fp16 = use_fp16
         self.alignment_size = 2 ** (len(channel_mult)+1)
-        self.freeze_main_net = freeze_main_net
         self.in_mel_channels = in_mel_channels
-        padding = 1 if kernel_size == 3 else 2
 
         time_embed_dim = model_channels * time_embed_dim_multiplier
         self.time_embed = nn.Sequential(
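Note: the surviving `unconditioned_percentage=.1` kwarg implements conditioning dropout for classifier-free-guidance-style training, as its comment says. A minimal sketch of the mechanism (illustrative only; the actual model keeps a learned `unconditioned_embedding` parameter and applies this inside its forward pass):

import torch

def drop_conditioning(cond, uncond_embedding, p=0.1):
    """With probability p, replace a batch element's conditioning with the
    learned 'unconditioned' embedding, so the model also learns an
    unconditional score usable for classifier-free guidance at sampling time.

    cond: (batch, channels, length); uncond_embedding: (1, channels, 1)
    """
    mask = torch.rand(cond.shape[0], 1, 1, device=cond.device) < p
    return torch.where(mask, uncond_embedding.expand_as(cond), cond)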
@@ -217,7 +200,7 @@ class DiffusionWaveformGen(nn.Module):
         self.input_blocks = nn.ModuleList(
             [
                 TimestepEmbedSequential(
-                    conv_nd(dims, in_channels, model_channels, kernel_size, padding=padding)
+                    conv_nd(dims, in_channels, model_channels, 3, padding=1)
                 )
             ]
         )
@@ -242,8 +225,8 @@ class DiffusionWaveformGen(nn.Module):
                         dropout,
                         out_channels=int(mult * model_channels),
                         dims=dims,
-                        kernel_size=kernel_size,
-                        use_scale_shift_norm=use_scale_shift_norm,
+                        kernel_size=3,
+                        use_scale_shift_norm=True,
                     )
                 ]
                 ch = int(mult * model_channels)
@@ -255,7 +238,7 @@ class DiffusionWaveformGen(nn.Module):
                 self.input_blocks.append(
                     TimestepEmbedSequential(
                         Downsample(
-                            ch, conv_resample, dims=dims, out_channels=out_ch, factor=scale_factor, ksize=3, pad=1
+                            ch, True, dims=dims, out_channels=out_ch, factor=2, ksize=3, pad=1
                        )
                    )
                )
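Note: with the `conv_resample` and `scale_factor` knobs removed, downsampling is pinned to a learned convolution with factor 2. In 1D, a Downsample of this shape typically reduces to a strided convolution; a rough sketch (not the repo's Downsample class):

import torch.nn as nn

def downsample_1d(in_channels, out_channels, factor=2, ksize=3, pad=1):
    # Strided conv that halves (for factor=2) the temporal resolution
    # while optionally changing the channel count.
    return nn.Conv1d(in_channels, out_channels, kernel_size=ksize,
                     stride=factor, padding=pad)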
@@ -279,15 +262,15 @@ class DiffusionWaveformGen(nn.Module):
                         dropout,
                         out_channels=int(model_channels * mult),
                         dims=dims,
-                        kernel_size=kernel_size,
-                        use_scale_shift_norm=use_scale_shift_norm,
+                        kernel_size=3,
+                        use_scale_shift_norm=True,
                     )
                 ]
                 ch = int(model_channels * mult)
                 if level and i == num_blocks:
                     out_ch = ch
                     layers.append(
-                        Upsample(ch, conv_resample, dims=dims, out_channels=out_ch, factor=scale_factor)
+                        Upsample(ch, True, dims=dims, out_channels=out_ch, factor=2)
                     )
                     ds //= 2
                 self.output_blocks.append(TimestepEmbedSequential(*layers))
@@ -296,20 +279,10 @@ class DiffusionWaveformGen(nn.Module):
         self.out = nn.Sequential(
             normalization(ch),
             nn.SiLU(),
-            zero_module(conv_nd(dims, model_channels, out_channels, kernel_size, padding=padding)),
+            zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
         )
 
-        if self.freeze_main_net:
-            mains = [self.time_embed, self.contextual_embedder, self.unconditioned_embedding, self.conditioning_timestep_integrator,
-                     self.input_blocks, self.middle_block, self.output_blocks, self.out]
-            for m in mains:
-                for p in m.parameters():
-                    p.requires_grad = False
-                    p.DO_NOT_TRAIN = True
-
     def get_grad_norm_parameter_groups(self):
-        if self.freeze_main_net:
-            return {}
         groups = {
             'input_blocks': list(self.input_blocks.parameters()),
             'output_blocks': list(self.output_blocks.parameters()),
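Note: `get_grad_norm_parameter_groups` exposes named parameter groups so the trainer can track gradient norms per sub-network. A hypothetical consumer (a sketch; the actual DL-Art-School trainer may differ) would look like:

import torch

def log_group_grad_norms(model):
    """After backward(), compute an L2 gradient norm per named group.
    Useful for spotting which part of the UNet dominates or vanishes."""
    norms = {}
    for name, params in model.get_grad_norm_parameter_groups().items():
        grads = [p.grad.detach().flatten() for p in params if p.grad is not None]
        if grads:
            norms[name] = torch.cat(grads).norm().item()
    return norms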
@@ -61,6 +61,7 @@ class MusicDiffusionFid(evaluator.Evaluator):
 
         if mode == 'spec_decode':
             self.diffusion_fn = self.perform_diffusion_spec_decode
+            self.squeeze_ratio = opt_eval['squeeze_ratio']
         elif 'from_codes' == mode:
             self.diffusion_fn = self.perform_diffusion_from_codes
             self.local_modules['codegen'] = get_music_codegen()
@@ -81,11 +82,11 @@ class MusicDiffusionFid(evaluator.Evaluator):
     def perform_diffusion_spec_decode(self, audio, sample_rate=22050):
         real_resampled = audio
         audio = audio.unsqueeze(0)
-        output_shape = (1, 256, audio.shape[-1] // 256)
+        output_shape = (1, self.squeeze_ratio, audio.shape[-1] // self.squeeze_ratio)
         mel = self.spec_fn({'in': audio})['out']
         gen = self.diffuser.p_sample_loop(self.model, output_shape,
                                           model_kwargs={'codes': mel})
-        gen = pixel_shuffle_1d(gen, 256)
+        gen = pixel_shuffle_1d(gen, self.squeeze_ratio)
 
         return gen, real_resampled, normalize_mel(self.spec_fn({'in': gen})['out']), normalize_mel(mel), sample_rate
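Note: this hunk replaces the hardcoded 256 with a configurable `self.squeeze_ratio`. The diffusion model emits `squeeze_ratio` channels per timestep, and `pixel_shuffle_1d` folds them back into a single-channel waveform `squeeze_ratio` times longer. A minimal sketch of what a 1D pixel shuffle does (assumed to match the repo's `pixel_shuffle_1d`; this is the standard 2D PixelShuffle generalized to 1D):

import torch

def pixel_shuffle_1d_sketch(x, ratio):
    """Turn (batch, channels*ratio, length) into (batch, channels, length*ratio),
    interleaving the ratio channel values as consecutive output samples."""
    b, c, l = x.shape
    x = x.view(b, c // ratio, ratio, l)   # split channels into groups of `ratio`
    x = x.permute(0, 1, 3, 2)             # move the ratio axis next to length
    return x.reshape(b, c // ratio, l * ratio)

Applied to the (1, squeeze_ratio, N) output above, this yields a (1, 1, N * squeeze_ratio) waveform.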