Allow num_res_blocks to be specified per-level
This commit is contained in:
parent 83798887a8
commit 3b19581f9a
@@ -64,15 +64,15 @@ class DiffusionVocoderWithRef(nn.Module):
     def __init__(
             self,
             model_channels,
-            num_res_blocks,
             in_channels=1,
             out_channels=2,  # mean and variance
             discrete_codes=8192,
             dropout=0,
             # 38400 -> 19200 -> 9600 -> 4800 -> 2400 -> 1200 -> 600 -> 300 -> 150 for ~2secs@22050Hz
-            channel_mult=(1, 1, 2, 2, 4, 8, 16, 32, 64),
-            spectrogram_conditioning_resolutions=(4,8,16,32),
-            attention_resolutions=(64,128,256),
+            channel_mult= (1, 1, 2, 2, 4, 6, 8, 12, 16, 24, 32, 48, 64),
+            num_res_blocks=(1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2),
+            spectrogram_conditioning_resolutions=(512,),
+            attention_resolutions=(512,1024,2048,4096),
             conv_resample=True,
             dims=1,
             use_fp16=False,
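With this change, num_res_blocks moves from a single required integer to a per-level tuple paired index-for-index with channel_mult, so the two tuples are expected to be the same length. A minimal sketch of how the new defaults pair up, in plain Python (illustration only, not part of the commit):

    channel_mult   = (1, 1, 2, 2, 4, 6, 8, 12, 16, 24, 32, 48, 64)
    num_res_blocks = (1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2)
    # mirror the constructor's pairing of one block count per resolution level
    for level, (mult, num_blocks) in enumerate(zip(channel_mult, num_res_blocks)):
        print(f"level {level:2d}: {num_blocks} res block(s) at {mult}x model_channels")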
|
@@ -95,7 +95,6 @@ class DiffusionVocoderWithRef(nn.Module):
         self.in_channels = in_channels
         self.model_channels = model_channels
         self.out_channels = out_channels
-        self.num_res_blocks = num_res_blocks
         self.attention_resolutions = attention_resolutions
         self.dropout = dropout
         self.channel_mult = channel_mult
@@ -134,12 +133,12 @@ class DiffusionVocoderWithRef(nn.Module):
         ch = model_channels
         ds = 1

-        for level, mult in enumerate(channel_mult):
+        for level, (mult, num_blocks) in enumerate(zip(channel_mult, num_res_blocks)):
             if ds in spectrogram_conditioning_resolutions:
                 self.input_blocks.append(DiscreteSpectrogramConditioningBlock(discrete_codes, ch))
                 ch *= 2

-            for _ in range(num_res_blocks):
+            for _ in range(num_blocks):
                 layers = [
                     ResBlock(
                         ch,
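Because ds doubles once per level on the way down, the new 13-entry channel_mult reaches a downsampling factor of 4096, which is why the conditioning and attention resolutions in the signature jump to 512 and beyond. A rough sketch of which levels those resolutions land on (assuming, per the usual UNet convention this loop follows, that ds doubles after every level except the last):

    ds = 1
    for level in range(13):
        if ds in (512,):                   # spectrogram_conditioning_resolutions
            print(f"spectrogram conditioning injected at level {level} (ds={ds})")
        if ds in (512, 1024, 2048, 4096):  # attention_resolutions
            print(f"attention enabled at level {level} (ds={ds})")
        if level != 12:
            ds *= 2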
@@ -216,8 +215,8 @@ class DiffusionVocoderWithRef(nn.Module):
             self._feature_size += ch

         self.output_blocks = nn.ModuleList([])
-        for level, mult in list(enumerate(channel_mult))[::-1]:
-            for i in range(num_res_blocks + 1):
+        for level, (mult, num_blocks) in list(enumerate(zip(channel_mult, num_res_blocks)))[::-1]:
+            for i in range(num_blocks + 1):
                 ich = input_block_chans.pop()
                 layers = [
                     ResBlock(
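The output path still runs num_blocks + 1 iterations per level so that input_block_chans.pop() drains both the level's res-block skips and the skip pushed by each downsample. A toy balance check under the new per-level counts (assuming the usual guided-diffusion bookkeeping of one skip from the stem conv, one per input res block, and one per downsample; those pushes sit outside the hunks shown here):

    num_res_blocks = (1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2)
    levels = len(num_res_blocks)
    pushed = 1 + sum(num_res_blocks) + (levels - 1)  # stem + res blocks + downsamples
    popped = sum(n + 1 for n in num_res_blocks)      # one pop per output-loop iteration
    assert pushed == popped, (pushed, popped)        # both are 33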
@@ -240,7 +239,7 @@ class DiffusionVocoderWithRef(nn.Module):
                         use_new_attention_order=use_new_attention_order,
                     )
                 )
-                if level and i == num_res_blocks:
+                if level and i == num_blocks:
                     out_ch = ch
                     layers.append(
                         ResBlock(
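The renamed condition still appends an upsample only on the extra (last) iteration of every level except the outermost one. A quick illustration of which (level, i) pairs fire it under the new default block counts:

    num_res_blocks = (1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2)
    for level, num_blocks in reversed(list(enumerate(num_res_blocks))):
        for i in range(num_blocks + 1):
            if level and i == num_blocks:  # never fires at level 0
                print(f"upsample appended at level {level}, iteration {i}")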
@@ -328,8 +327,8 @@ def register_unet_diffusion_vocoder_with_ref(opt_net, opt):
 # Test for ~4 second audio clip at 22050Hz
 if __name__ == '__main__':
     clip = torch.randn(2, 1, 81920)
-    spec = torch.randint(8192, (2, 500,))
+    spec = torch.randint(8192, (2, 160,))
     cond = torch.randn(2, 4, 80, 600)
     ts = torch.LongTensor([555, 556])
-    model = DiffusionVocoderWithRef(32, 2)
+    model = DiffusionVocoderWithRef(32, conditioning_inputs_provided=False)
     print(model(clip, ts, spec, cond, 4).shape)
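One caveat for callers of the new signature: zip truncates to the shorter sequence, so a length mismatch between channel_mult and num_res_blocks would silently drop levels rather than raise. A defensive construction sketch (the assert is illustrative, not part of the commit, and it assumes DiffusionVocoderWithRef is imported from the patched module):

    channel_mult = (1, 1, 2, 2, 4, 6, 8, 12, 16, 24, 32, 48, 64)
    num_res_blocks = (1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2)
    assert len(num_res_blocks) == len(channel_mult), "one block count per level"
    model = DiffusionVocoderWithRef(32, conditioning_inputs_provided=False,
                                    channel_mult=channel_mult,
                                    num_res_blocks=num_res_blocks)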