forked from mrq/DL-Art-School
More mods
This commit is contained in:
parent
691ed196da
commit
c5ea2bee52
|
@ -150,8 +150,6 @@ class DiffusionWaveformGen(nn.Module):
|
||||||
:param dropout: the dropout probability.
|
:param dropout: the dropout probability.
|
||||||
:param channel_mult: channel multiplier for each level of the UNet.
|
:param channel_mult: channel multiplier for each level of the UNet.
|
||||||
:param dims: determines if the signal is 1D, 2D, or 3D.
|
:param dims: determines if the signal is 1D, 2D, or 3D.
|
||||||
:param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
|
|
||||||
:param resblock_updown: use residual blocks for up/downsampling.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
|
@ -166,7 +164,6 @@ class DiffusionWaveformGen(nn.Module):
|
||||||
num_res_blocks=(1,1,0),
|
num_res_blocks=(1,1,0),
|
||||||
token_conditioning_resolutions=(1,4),
|
token_conditioning_resolutions=(1,4),
|
||||||
mid_resnet_depth=10,
|
mid_resnet_depth=10,
|
||||||
dims=1,
|
|
||||||
use_fp16=False,
|
use_fp16=False,
|
||||||
time_embed_dim_multiplier=1,
|
time_embed_dim_multiplier=1,
|
||||||
# Parameters for regularization.
|
# Parameters for regularization.
|
||||||
|
@ -179,7 +176,6 @@ class DiffusionWaveformGen(nn.Module):
|
||||||
self.out_channels = out_channels
|
self.out_channels = out_channels
|
||||||
self.dropout = dropout
|
self.dropout = dropout
|
||||||
self.channel_mult = channel_mult
|
self.channel_mult = channel_mult
|
||||||
self.dims = dims
|
|
||||||
self.unconditioned_percentage = unconditioned_percentage
|
self.unconditioned_percentage = unconditioned_percentage
|
||||||
self.enable_fp16 = use_fp16
|
self.enable_fp16 = use_fp16
|
||||||
self.alignment_size = 2 ** (len(channel_mult)+1)
|
self.alignment_size = 2 ** (len(channel_mult)+1)
|
||||||
|
@ -203,7 +199,7 @@ class DiffusionWaveformGen(nn.Module):
|
||||||
self.input_blocks = nn.ModuleList(
|
self.input_blocks = nn.ModuleList(
|
||||||
[
|
[
|
||||||
TimestepEmbedSequential(
|
TimestepEmbedSequential(
|
||||||
conv_nd(dims, in_channels, model_channels, 3, padding=1)
|
conv_nd(1, in_channels, model_channels, 3, padding=1)
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
@ -227,7 +223,7 @@ class DiffusionWaveformGen(nn.Module):
|
||||||
time_embed_dim,
|
time_embed_dim,
|
||||||
dropout,
|
dropout,
|
||||||
out_channels=int(mult * model_channels),
|
out_channels=int(mult * model_channels),
|
||||||
dims=dims,
|
dims=1,
|
||||||
kernel_size=3,
|
kernel_size=3,
|
||||||
use_scale_shift_norm=True,
|
use_scale_shift_norm=True,
|
||||||
)
|
)
|
||||||
|
@ -241,7 +237,7 @@ class DiffusionWaveformGen(nn.Module):
|
||||||
self.input_blocks.append(
|
self.input_blocks.append(
|
||||||
TimestepEmbedSequential(
|
TimestepEmbedSequential(
|
||||||
Downsample(
|
Downsample(
|
||||||
ch, True, dims=dims, out_channels=out_ch, factor=2, ksize=3, pad=1
|
ch, True, dims=1, out_channels=out_ch, factor=2, ksize=3, pad=1
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
@ -264,7 +260,7 @@ class DiffusionWaveformGen(nn.Module):
|
||||||
time_embed_dim,
|
time_embed_dim,
|
||||||
dropout,
|
dropout,
|
||||||
out_channels=int(model_channels * mult),
|
out_channels=int(model_channels * mult),
|
||||||
dims=dims,
|
dims=1,
|
||||||
kernel_size=3,
|
kernel_size=3,
|
||||||
use_scale_shift_norm=True,
|
use_scale_shift_norm=True,
|
||||||
)
|
)
|
||||||
|
@ -273,7 +269,7 @@ class DiffusionWaveformGen(nn.Module):
|
||||||
if level and i == num_blocks:
|
if level and i == num_blocks:
|
||||||
out_ch = ch
|
out_ch = ch
|
||||||
layers.append(
|
layers.append(
|
||||||
Upsample(ch, True, dims=dims, out_channels=out_ch, factor=2)
|
Upsample(ch, True, dims=1, out_channels=out_ch, factor=2)
|
||||||
)
|
)
|
||||||
ds //= 2
|
ds //= 2
|
||||||
self.output_blocks.append(TimestepEmbedSequential(*layers))
|
self.output_blocks.append(TimestepEmbedSequential(*layers))
|
||||||
|
@ -282,7 +278,7 @@ class DiffusionWaveformGen(nn.Module):
|
||||||
self.out = nn.Sequential(
|
self.out = nn.Sequential(
|
||||||
normalization(ch),
|
normalization(ch),
|
||||||
nn.SiLU(),
|
nn.SiLU(),
|
||||||
zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
|
zero_module(conv_nd(1, model_channels, out_channels, 3, padding=1)),
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_grad_norm_parameter_groups(self):
|
def get_grad_norm_parameter_groups(self):
|
||||||
|
@ -355,13 +351,6 @@ class DiffusionWaveformGen(nn.Module):
|
||||||
|
|
||||||
return out[:, :, :orig_x_shape]
|
return out[:, :, :orig_x_shape]
|
||||||
|
|
||||||
def before_step(self, step):
|
|
||||||
# The middle block traditionally gets really small gradients; scale them up by an order of magnitude.
|
|
||||||
scaled_grad_parameters = self.middle_block.parameters()
|
|
||||||
for p in scaled_grad_parameters:
|
|
||||||
if hasattr(p, 'grad') and p.grad is not None:
|
|
||||||
p.grad *= 10
|
|
||||||
|
|
||||||
|
|
||||||
@register_model
|
@register_model
|
||||||
def register_unet_diffusion_waveform_gen3(opt_net, opt):
|
def register_unet_diffusion_waveform_gen3(opt_net, opt):
|
||||||
|
|
Loading…
Reference in New Issue
Block a user