Allow attention_dim in channel attention to be specified, add converter

James Betker 2021-10-05 17:29:38 -06:00
parent 9c0d7288ea
commit f2977d360c


@@ -9,6 +9,7 @@ from einops import rearrange
 from torch import einsum
 from models.diffusion.unet_diffusion import AttentionBlock
+from models.gpt_voice.lucidrains_dvae import DiscreteVAE
 from models.stylegan.stylegan2_rosinality import EqualLinear
 from models.vqvae.vqvae import Quantize
 from trainer.networks import register_model
@@ -142,7 +143,7 @@ class UpsampledConv(nn.Module):
         return self.conv(up)
 
 
-class DiscreteVAE(nn.Module):
+class ChannelAttentionDVAE(nn.Module):
     def __init__(
         self,
         positional_dims=2,
@@ -151,6 +152,7 @@ class DiscreteVAE(nn.Module):
         num_layers = 3,
         num_resnet_blocks = 0,
         hidden_dim = 64,
+        channel_attention_dim = 64,
         channels = 3,
         stride = 2,
         kernel_size = 4,
@@ -218,7 +220,7 @@ class DiscreteVAE(nn.Module):
             dec_layers.insert(0, conv(codebook_dim, dec_chans[1], 1))
 
         enc_layers.append(conv(enc_chans[-1], codebook_dim, 1))
-        dec_layers.append(ChannelAttentionModule(dec_chans[-1], channels, hidden_dim, layers=3, num_heads=1))
+        dec_layers.append(ChannelAttentionModule(dec_chans[-1], channels, channel_attention_dim, layers=3, num_heads=1))
 
         self.encoder = nn.Sequential(*enc_layers)
         self.decoder = nn.Sequential(*dec_layers)
@@ -320,18 +322,42 @@ class DiscreteVAE(nn.Module):
         return recon_loss, commitment_loss, out
 
+def convert_from_dvae(dvae_state_dict_file):
+    params = {
+        'channels': 80,
+        'positional_dims': 1,
+        'num_tokens': 8192,
+        'codebook_dim': 2048,
+        'hidden_dim': 512,
+        'stride': 2,
+        'num_resnet_blocks': 3,
+        'num_layers': 2,
+        'record_codes': True,
+    }
+    dvae = DiscreteVAE(**params)
+    dvae.load_state_dict(torch.load(dvae_state_dict_file), strict=True)
+    cdvae = ChannelAttentionDVAE(channel_attention_dim=256, **params)
+    mk, uk = cdvae.load_state_dict(dvae.state_dict(), strict=False)
+    for k in mk:
+        assert 'decoder.6' in k
+    for k in uk:
+        assert 'decoder.6' in k
+    cdvae.decoder[-1].bypass.load_state_dict(dvae.decoder[-1].state_dict())
+    torch.save(cdvae.state_dict(), 'converted_cdvae.pth')
 
 @register_model
 def register_dvae_channel_attention(opt_net, opt):
-    return DiscreteVAE(**opt_get(opt_net, ['kwargs'], {}))
+    return ChannelAttentionDVAE(**opt_get(opt_net, ['kwargs'], {}))
 
 
 if __name__ == '__main__':
     #v = DiscreteVAE()
     #o=v(torch.randn(1,3,256,256))
     #print(o.shape)
-    v = DiscreteVAE(channels=80, normalization=None, positional_dims=1, num_tokens=4096, codebook_dim=4096,
-                    hidden_dim=256, stride=2, num_resnet_blocks=2, kernel_size=3, num_layers=2, use_transposed_convs=False)
-    #v.eval()
+    convert_from_dvae('D:\\dlas\\experiments\\train_dvae_clips\\models\\20000_generator.pth')
+    '''
+    v = ChannelAttentionDVAE(channels=80, normalization=None, positional_dims=1, num_tokens=4096, codebook_dim=4096,
+                             hidden_dim=256, stride=2, num_resnet_blocks=2, kernel_size=3, num_layers=2, use_transposed_convs=False)
     o=v(torch.randn(1,80,256))
     print(v.get_debug_values(0, 0))
     print(o[-1].shape)
+    '''
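
Note on the converter: convert_from_dvae rebuilds the original DiscreteVAE from a checkpoint, copies every shared weight into a ChannelAttentionDVAE via a strict=False load, asserts that the only missing/unexpected keys belong to the new decoder.6 attention module, seeds that module's bypass branch from the old final decoder layer, and writes converted_cdvae.pth. A minimal usage sketch follows; the import path and checkpoint location are assumptions, not part of this commit:

# Sketch only: the module path below is hypothetical; adjust to wherever this file lives in DLAS.
import torch
from models.gpt_voice.channel_attention_dvae import ChannelAttentionDVAE, convert_from_dvae  # hypothetical path

# Convert an existing DVAE checkpoint; writes 'converted_cdvae.pth' to the working directory.
convert_from_dvae('experiments/train_dvae_clips/models/20000_generator.pth')  # example path

# The converted weights then load strictly into a ChannelAttentionDVAE built with the same
# hyperparameters the converter hard-codes, plus channel_attention_dim=256.
cdvae = ChannelAttentionDVAE(channels=80, positional_dims=1, num_tokens=8192, codebook_dim=2048,
                             hidden_dim=512, stride=2, num_resnet_blocks=3, num_layers=2,
                             record_codes=True, channel_attention_dim=256)
cdvae.load_state_dict(torch.load('converted_cdvae.pth'), strict=True)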
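
Since register_dvae_channel_attention now returns the renamed class, existing configs keep working so long as their kwargs match the new signature; channel_attention_dim defaults to 64 when omitted. A sketch of how the registry hook resolves its kwargs (the opt dicts here are illustrative; real DLAS configs supply them from option files):

# Illustrative only: in DLAS these dicts come from the trainer's option file.
opt_net = {
    'kwargs': {
        'channels': 80,
        'positional_dims': 1,
        'num_tokens': 8192,
        'codebook_dim': 2048,
        'hidden_dim': 512,
        'channel_attention_dim': 256,  # the new knob this commit exposes; defaults to 64
        'stride': 2,
        'num_resnet_blocks': 3,
        'num_layers': 2,
    }
}
model = register_dvae_channel_attention(opt_net, opt={})  # opt is unused by this hook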