Throw out the idea of conditioning on discrete codes. Oh well :(
parent 62c8c5d93e
commit 0edc98f6c4
@@ -190,8 +190,10 @@ class DiscreteVAE(nn.Module):
         arrange = 'b (h w) d -> b d h w'
         kwargs = {'h': h, 'w': w}
         image_embeds = rearrange(image_embeds, arrange, **kwargs)
-        images = self.decoder(image_embeds)
-        return images
+        images = [image_embeds]
+        for layer in self.decoder:
+            images.append(layer(images[-1]))
+        return images[-1], images[-2]
 
     def infer(self, img):
         img = self.norm(img)
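For context on the new decode path: instead of calling the decoder once, it is walked layer by layer so that both the final image and the penultimate feature map are returned. A minimal sketch of that pattern, assuming the DiscreteVAE decoder is an nn.Sequential whose children can be iterated (the stand-in decoder and all sizes below are hypothetical):

import torch
import torch.nn as nn

# Hypothetical stand-in for the DiscreteVAE decoder (assumed to be an
# nn.Sequential of upsampling blocks that can be iterated one at a time).
decoder = nn.Sequential(
    nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),
    nn.ConvTranspose2d(32, 3, kernel_size=4, stride=2, padding=1),
)

image_embeds = torch.randn(2, 64, 8, 8)   # b d h w, after the rearrange above
outputs = [image_embeds]
for layer in decoder:
    outputs.append(layer(outputs[-1]))    # keep every intermediate activation

final, penultimate = outputs[-1], outputs[-2]
print(final.shape, penultimate.shape)     # (2, 3, 32, 32) and (2, 32, 16, 16)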
@@ -11,9 +11,9 @@ from utils.util import get_mask_from_lengths
 
 
 class DiscreteSpectrogramConditioningBlock(nn.Module):
-    def __init__(self, discrete_codes, channels):
+    def __init__(self, dvae_channels, channels):
         super().__init__()
-        self.emb = nn.Embedding(discrete_codes, channels)
+        self.emb = nn.Conv1d(dvae_channels, channels, kernel_size=1)
         self.norm = normalization(channels)
         self.act = nn.SiLU()
         self.intg = nn.Sequential(nn.Conv1d(channels*2, channels*2, kernel_size=1),
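The swap from nn.Embedding to a kernel-size-1 Conv1d means the block now consumes continuous DVAE latents (float tensors of shape b x q x N) rather than integer code indices; a 1x1 Conv1d is just a per-timestep linear projection across channels. A minimal shape check with hypothetical sizes:

import torch
import torch.nn as nn

dvae_channels, channels = 512, 32                   # hypothetical sizes
proj = nn.Conv1d(dvae_channels, channels, kernel_size=1)

dvae_latent = torch.randn(2, dvae_channels, 160)    # b x q x N continuous latent
emb = proj(dvae_latent)
print(emb.shape)                                    # torch.Size([2, 32, 160])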
@@ -30,11 +30,10 @@ class DiscreteSpectrogramConditioningBlock(nn.Module):
     :param x: bxcxS waveform latent
     :param codes: bxN discrete codes, N <= S
     """
-    def forward(self, x, codes):
-        _, c, S = x.shape
-        b, N = codes.shape
-        assert N <= S
-        emb = self.emb(codes).permute(0,2,1)
+    def forward(self, x, dvae_in):
+        b, c, S = x.shape
+        _, q, N = dvae_in.shape
+        emb = self.emb(dvae_in)
         emb = nn.functional.interpolate(emb, size=(S,), mode='nearest')
         together = torch.cat([self.act(self.norm(x)), emb], dim=1)
         together = self.intg(together)
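Taken together with the __init__ change above, the forward pass now projects the continuous latent down to the block's channel count, stretches it to the waveform-latent length with nearest-neighbour interpolation, and concatenates it with the normalized input. A self-contained sketch under stated assumptions (nn.GroupNorm stands in for the repo's normalization() helper, the intg stack is truncated in this hunk so a single 1x1 Conv1d is used as a placeholder, and all sizes are hypothetical):

import torch
import torch.nn as nn

class CondBlockSketch(nn.Module):
    # Sketch of DiscreteSpectrogramConditioningBlock after this change.
    def __init__(self, dvae_channels, channels):
        super().__init__()
        self.emb = nn.Conv1d(dvae_channels, channels, kernel_size=1)
        self.norm = nn.GroupNorm(8, channels)     # stand-in for normalization(channels)
        self.act = nn.SiLU()
        self.intg = nn.Conv1d(channels * 2, channels * 2, kernel_size=1)  # placeholder for the full intg stack

    def forward(self, x, dvae_in):
        b, c, S = x.shape                         # waveform latent
        _, q, N = dvae_in.shape                   # continuous DVAE latent, N <= S
        emb = self.emb(dvae_in)                   # project q -> c channels
        emb = nn.functional.interpolate(emb, size=(S,), mode='nearest')  # stretch N -> S
        together = torch.cat([self.act(self.norm(x)), emb], dim=1)
        return self.intg(together)

block = CondBlockSketch(dvae_channels=512, channels=32)
out = block(torch.randn(2, 32, 640), torch.randn(2, 512, 160))
print(out.shape)                                  # torch.Size([2, 64, 640])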
@@ -77,7 +76,7 @@ class DiffusionVocoderWithRef(nn.Module):
             model_channels,
             in_channels=1,
             out_channels=2,  # mean and variance
-            discrete_codes=8192,
+            discrete_codes=512,
             dropout=0,
             # res 1, 2, 4, 8,16,32,64,128,256,512, 1K, 2K
             channel_mult= (1,1.5,2, 3, 4, 6, 8, 12, 16, 24, 32, 48),
@@ -339,7 +338,8 @@ def register_unet_diffusion_vocoder_with_ref(opt_net, opt):
 # Test for ~4 second audio clip at 22050Hz
 if __name__ == '__main__':
     clip = torch.randn(2, 1, 40960)
-    spec = torch.randint(8192, (2, 40,))
+    #spec = torch.randint(8192, (2, 40,))
+    spec = torch.randn(8,512,160)
     cond = torch.randn(2, 3, 80, 173)
     ts = torch.LongTensor([555, 556])
     model = DiffusionVocoderWithRef(32, conditioning_inputs_provided=False)
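For a feel of the shapes the updated smoke test exercises, a minimal sketch (the model call itself is not shown in this hunk; a batch of 2 is used for the latent here for consistency with the other tensors, whereas the diff's spec line uses 8):

import torch

clip = torch.randn(2, 1, 40960)      # ~4 s of audio at 22050 Hz, b x 1 x T
spec = torch.randn(2, 512, 160)      # continuous DVAE latent, b x q x N
cond = torch.randn(2, 3, 80, 173)    # reference conditioning spectrograms
ts = torch.LongTensor([555, 556])    # diffusion timesteps

# Inside the conditioning block, the N=160 latent frames get stretched to
# whatever length S the waveform latent has at a given U-Net resolution:
S = 640
stretched = torch.nn.functional.interpolate(spec, size=(S,), mode='nearest')
print(stretched.shape)               # torch.Size([2, 512, 640])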