Allow num_resblocks to specified per-level

This commit is contained in:
James Betker 2021-10-14 11:26:04 -06:00
parent 83798887a8
commit 3b19581f9a

View File

@ -64,15 +64,15 @@ class DiffusionVocoderWithRef(nn.Module):
def __init__(
self,
model_channels,
num_res_blocks,
in_channels=1,
out_channels=2, # mean and variance
discrete_codes=8192,
dropout=0,
# 38400 -> 19200 -> 9600 -> 4800 -> 2400 -> 1200 -> 600 -> 300 -> 150 for ~2secs@22050Hz
channel_mult=(1, 1, 2, 2, 4, 8, 16, 32, 64),
spectrogram_conditioning_resolutions=(4,8,16,32),
attention_resolutions=(64,128,256),
channel_mult= (1, 1, 2, 2, 4, 6, 8, 12, 16, 24, 32, 48, 64),
num_res_blocks=(1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2),
spectrogram_conditioning_resolutions=(512,),
attention_resolutions=(512,1024,2048,4096),
conv_resample=True,
dims=1,
use_fp16=False,
@ -95,7 +95,6 @@ class DiffusionVocoderWithRef(nn.Module):
self.in_channels = in_channels
self.model_channels = model_channels
self.out_channels = out_channels
self.num_res_blocks = num_res_blocks
self.attention_resolutions = attention_resolutions
self.dropout = dropout
self.channel_mult = channel_mult
@ -134,12 +133,12 @@ class DiffusionVocoderWithRef(nn.Module):
ch = model_channels
ds = 1
for level, mult in enumerate(channel_mult):
for level, (mult, num_blocks) in enumerate(zip(channel_mult, num_res_blocks)):
if ds in spectrogram_conditioning_resolutions:
self.input_blocks.append(DiscreteSpectrogramConditioningBlock(discrete_codes, ch))
ch *= 2
for _ in range(num_res_blocks):
for _ in range(num_blocks):
layers = [
ResBlock(
ch,
@ -216,8 +215,8 @@ class DiffusionVocoderWithRef(nn.Module):
self._feature_size += ch
self.output_blocks = nn.ModuleList([])
for level, mult in list(enumerate(channel_mult))[::-1]:
for i in range(num_res_blocks + 1):
for level, (mult, num_blocks) in list(enumerate(zip(channel_mult, num_res_blocks)))[::-1]:
for i in range(num_blocks + 1):
ich = input_block_chans.pop()
layers = [
ResBlock(
@ -240,7 +239,7 @@ class DiffusionVocoderWithRef(nn.Module):
use_new_attention_order=use_new_attention_order,
)
)
if level and i == num_res_blocks:
if level and i == num_blocks:
out_ch = ch
layers.append(
ResBlock(
@ -328,8 +327,8 @@ def register_unet_diffusion_vocoder_with_ref(opt_net, opt):
# Test for ~4 second audio clip at 22050Hz
if __name__ == '__main__':
clip = torch.randn(2, 1, 81920)
spec = torch.randint(8192, (2, 500,))
spec = torch.randint(8192, (2, 160,))
cond = torch.randn(2, 4, 80, 600)
ts = torch.LongTensor([555, 556])
model = DiffusionVocoderWithRef(32, 2)
model = DiffusionVocoderWithRef(32, conditioning_inputs_provided=False)
print(model(clip, ts, spec, cond, 4).shape)