Allow num_resblocks to specified per-level
This commit is contained in:
parent
83798887a8
commit
3b19581f9a
|
@ -64,15 +64,15 @@ class DiffusionVocoderWithRef(nn.Module):
|
|||
def __init__(
|
||||
self,
|
||||
model_channels,
|
||||
num_res_blocks,
|
||||
in_channels=1,
|
||||
out_channels=2, # mean and variance
|
||||
discrete_codes=8192,
|
||||
dropout=0,
|
||||
# 38400 -> 19200 -> 9600 -> 4800 -> 2400 -> 1200 -> 600 -> 300 -> 150 for ~2secs@22050Hz
|
||||
channel_mult=(1, 1, 2, 2, 4, 8, 16, 32, 64),
|
||||
spectrogram_conditioning_resolutions=(4,8,16,32),
|
||||
attention_resolutions=(64,128,256),
|
||||
channel_mult= (1, 1, 2, 2, 4, 6, 8, 12, 16, 24, 32, 48, 64),
|
||||
num_res_blocks=(1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2),
|
||||
spectrogram_conditioning_resolutions=(512,),
|
||||
attention_resolutions=(512,1024,2048,4096),
|
||||
conv_resample=True,
|
||||
dims=1,
|
||||
use_fp16=False,
|
||||
|
@ -95,7 +95,6 @@ class DiffusionVocoderWithRef(nn.Module):
|
|||
self.in_channels = in_channels
|
||||
self.model_channels = model_channels
|
||||
self.out_channels = out_channels
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.attention_resolutions = attention_resolutions
|
||||
self.dropout = dropout
|
||||
self.channel_mult = channel_mult
|
||||
|
@ -134,12 +133,12 @@ class DiffusionVocoderWithRef(nn.Module):
|
|||
ch = model_channels
|
||||
ds = 1
|
||||
|
||||
for level, mult in enumerate(channel_mult):
|
||||
for level, (mult, num_blocks) in enumerate(zip(channel_mult, num_res_blocks)):
|
||||
if ds in spectrogram_conditioning_resolutions:
|
||||
self.input_blocks.append(DiscreteSpectrogramConditioningBlock(discrete_codes, ch))
|
||||
ch *= 2
|
||||
|
||||
for _ in range(num_res_blocks):
|
||||
for _ in range(num_blocks):
|
||||
layers = [
|
||||
ResBlock(
|
||||
ch,
|
||||
|
@ -216,8 +215,8 @@ class DiffusionVocoderWithRef(nn.Module):
|
|||
self._feature_size += ch
|
||||
|
||||
self.output_blocks = nn.ModuleList([])
|
||||
for level, mult in list(enumerate(channel_mult))[::-1]:
|
||||
for i in range(num_res_blocks + 1):
|
||||
for level, (mult, num_blocks) in list(enumerate(zip(channel_mult, num_res_blocks)))[::-1]:
|
||||
for i in range(num_blocks + 1):
|
||||
ich = input_block_chans.pop()
|
||||
layers = [
|
||||
ResBlock(
|
||||
|
@ -240,7 +239,7 @@ class DiffusionVocoderWithRef(nn.Module):
|
|||
use_new_attention_order=use_new_attention_order,
|
||||
)
|
||||
)
|
||||
if level and i == num_res_blocks:
|
||||
if level and i == num_blocks:
|
||||
out_ch = ch
|
||||
layers.append(
|
||||
ResBlock(
|
||||
|
@ -328,8 +327,8 @@ def register_unet_diffusion_vocoder_with_ref(opt_net, opt):
|
|||
# Test for ~4 second audio clip at 22050Hz
|
||||
if __name__ == '__main__':
|
||||
clip = torch.randn(2, 1, 81920)
|
||||
spec = torch.randint(8192, (2, 500,))
|
||||
spec = torch.randint(8192, (2, 160,))
|
||||
cond = torch.randn(2, 4, 80, 600)
|
||||
ts = torch.LongTensor([555, 556])
|
||||
model = DiffusionVocoderWithRef(32, 2)
|
||||
model = DiffusionVocoderWithRef(32, conditioning_inputs_provided=False)
|
||||
print(model(clip, ts, spec, cond, 4).shape)
|
||||
|
|
Loading…
Reference in New Issue
Block a user