Allow num_resblocks to be specified per-level

This commit is contained in:
James Betker 2021-10-14 11:26:04 -06:00
parent 83798887a8
commit 3b19581f9a
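
The core of the change: num_res_blocks moves from a single integer to a per-level tuple that is zipped against channel_mult, so each resolution level of the UNet carries its own residual-block count. A minimal, self-contained sketch of that pairing (not the repository code; ResBlock construction is elided, and the printed "plan" is purely illustrative):

channel_mult   = (1, 1, 2, 2, 4, 6, 8, 12, 16, 24, 32, 48, 64)
num_res_blocks = (1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2)
model_channels = 32

# Encoder side: each level gets its own block count instead of a global one.
encoder_plan = []
for level, (mult, num_blocks) in enumerate(zip(channel_mult, num_res_blocks)):
    ch = mult * model_channels
    encoder_plan.append((level, ch, num_blocks))  # num_blocks ResBlocks at this width

# Decoder side walks the same levels in reverse, with one extra block per level
# (matching the `range(num_blocks + 1)` in the commit).
decoder_plan = []
for level, (mult, num_blocks) in list(enumerate(zip(channel_mult, num_res_blocks)))[::-1]:
    ch = mult * model_channels
    decoder_plan.append((level, ch, num_blocks + 1))

print(encoder_plan[:3], decoder_plan[:3])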


@@ -64,15 +64,15 @@ class DiffusionVocoderWithRef(nn.Module):
     def __init__(
             self,
             model_channels,
-            num_res_blocks,
             in_channels=1,
             out_channels=2,  # mean and variance
             discrete_codes=8192,
             dropout=0,
             # 38400 -> 19200 -> 9600 -> 4800 -> 2400 -> 1200 -> 600 -> 300 -> 150 for ~2secs@22050Hz
-            channel_mult=(1, 1, 2, 2, 4, 8, 16, 32, 64),
-            spectrogram_conditioning_resolutions=(4,8,16,32),
-            attention_resolutions=(64,128,256),
+            channel_mult=  (1, 1, 2, 2, 4, 6, 8, 12, 16, 24, 32, 48, 64),
+            num_res_blocks=(1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2),
+            spectrogram_conditioning_resolutions=(512,),
+            attention_resolutions=(512,1024,2048,4096),
             conv_resample=True,
             dims=1,
             use_fp16=False,
@@ -95,7 +95,6 @@ class DiffusionVocoderWithRef(nn.Module):
         self.in_channels = in_channels
         self.model_channels = model_channels
         self.out_channels = out_channels
-        self.num_res_blocks = num_res_blocks
         self.attention_resolutions = attention_resolutions
         self.dropout = dropout
         self.channel_mult = channel_mult
@@ -134,12 +133,12 @@ class DiffusionVocoderWithRef(nn.Module):
         ch = model_channels
         ds = 1
-        for level, mult in enumerate(channel_mult):
+        for level, (mult, num_blocks) in enumerate(zip(channel_mult, num_res_blocks)):
             if ds in spectrogram_conditioning_resolutions:
                 self.input_blocks.append(DiscreteSpectrogramConditioningBlock(discrete_codes, ch))
                 ch *= 2
-            for _ in range(num_res_blocks):
+            for _ in range(num_blocks):
                 layers = [
                     ResBlock(
                         ch,
@@ -216,8 +215,8 @@ class DiffusionVocoderWithRef(nn.Module):
         self._feature_size += ch
 
         self.output_blocks = nn.ModuleList([])
-        for level, mult in list(enumerate(channel_mult))[::-1]:
-            for i in range(num_res_blocks + 1):
+        for level, (mult, num_blocks) in list(enumerate(zip(channel_mult, num_res_blocks)))[::-1]:
+            for i in range(num_blocks + 1):
                 ich = input_block_chans.pop()
                 layers = [
                     ResBlock(
@@ -240,7 +239,7 @@ class DiffusionVocoderWithRef(nn.Module):
                             use_new_attention_order=use_new_attention_order,
                         )
                     )
-                if level and i == num_res_blocks:
+                if level and i == num_blocks:
                     out_ch = ch
                     layers.append(
                         ResBlock(
@@ -328,8 +327,8 @@ def register_unet_diffusion_vocoder_with_ref(opt_net, opt):
 # Test for ~4 second audio clip at 22050Hz
 if __name__ == '__main__':
     clip = torch.randn(2, 1, 81920)
-    spec = torch.randint(8192, (2, 500,))
+    spec = torch.randint(8192, (2, 160,))
     cond = torch.randn(2, 4, 80, 600)
     ts = torch.LongTensor([555, 556])
-    model = DiffusionVocoderWithRef(32, 2)
+    model = DiffusionVocoderWithRef(32, conditioning_inputs_provided=False)
     print(model(clip, ts, spec, cond, 4).shape)
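
Since zip() silently truncates to the shorter sequence, a caller supplying per-level counts presumably wants the two tuples to be the same length; a hypothetical guard (not part of this commit) could look like:

channel_mult   = (1, 1, 2, 2, 4, 6, 8, 12, 16, 24, 32, 48, 64)
num_res_blocks = (1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2)
# Hypothetical sanity check: one residual-block count per channel_mult level.
assert len(num_res_blocks) == len(channel_mult), \
    "expected one num_res_blocks entry per channel_mult level"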