Allow attention_dim in channel attention to be specified, add converter
This commit is contained in:
parent
9c0d7288ea
commit
f2977d360c
|
@ -9,6 +9,7 @@ from einops import rearrange
|
|||
from torch import einsum
|
||||
|
||||
from models.diffusion.unet_diffusion import AttentionBlock
|
||||
from models.gpt_voice.lucidrains_dvae import DiscreteVAE
|
||||
from models.stylegan.stylegan2_rosinality import EqualLinear
|
||||
from models.vqvae.vqvae import Quantize
|
||||
from trainer.networks import register_model
|
||||
|
@ -142,7 +143,7 @@ class UpsampledConv(nn.Module):
|
|||
return self.conv(up)
|
||||
|
||||
|
||||
class DiscreteVAE(nn.Module):
|
||||
class ChannelAttentionDVAE(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
positional_dims=2,
|
||||
|
@ -151,6 +152,7 @@ class DiscreteVAE(nn.Module):
|
|||
num_layers = 3,
|
||||
num_resnet_blocks = 0,
|
||||
hidden_dim = 64,
|
||||
channel_attention_dim = 64,
|
||||
channels = 3,
|
||||
stride = 2,
|
||||
kernel_size = 4,
|
||||
|
@ -218,7 +220,7 @@ class DiscreteVAE(nn.Module):
|
|||
dec_layers.insert(0, conv(codebook_dim, dec_chans[1], 1))
|
||||
|
||||
enc_layers.append(conv(enc_chans[-1], codebook_dim, 1))
|
||||
dec_layers.append(ChannelAttentionModule(dec_chans[-1], channels, hidden_dim, layers=3, num_heads=1))
|
||||
dec_layers.append(ChannelAttentionModule(dec_chans[-1], channels, channel_attention_dim, layers=3, num_heads=1))
|
||||
|
||||
self.encoder = nn.Sequential(*enc_layers)
|
||||
self.decoder = nn.Sequential(*dec_layers)
|
||||
|
@ -320,18 +322,42 @@ class DiscreteVAE(nn.Module):
|
|||
return recon_loss, commitment_loss, out
|
||||
|
||||
|
||||
|
||||
def convert_from_dvae(dvae_state_dict_file):
|
||||
params = {
|
||||
'channels': 80,
|
||||
'positional_dims': 1,
|
||||
'num_tokens': 8192,
|
||||
'codebook_dim': 2048,
|
||||
'hidden_dim': 512,
|
||||
'stride': 2,
|
||||
'num_resnet_blocks': 3,
|
||||
'num_layers': 2,
|
||||
'record_codes': True,
|
||||
}
|
||||
dvae = DiscreteVAE(**params)
|
||||
dvae.load_state_dict(torch.load(dvae_state_dict_file), strict=True)
|
||||
cdvae = ChannelAttentionDVAE(channel_attention_dim=256, **params)
|
||||
mk, uk = cdvae.load_state_dict(dvae.state_dict(), strict=False)
|
||||
for k in mk:
|
||||
assert 'decoder.6' in k
|
||||
for k in uk:
|
||||
assert 'decoder.6' in k
|
||||
cdvae.decoder[-1].bypass.load_state_dict(dvae.decoder[-1].state_dict())
|
||||
torch.save(cdvae.state_dict(), 'converted_cdvae.pth')
|
||||
|
||||
|
||||
@register_model
|
||||
def register_dvae_channel_attention(opt_net, opt):
|
||||
return DiscreteVAE(**opt_get(opt_net, ['kwargs'], {}))
|
||||
return ChannelAttentionDVAE(**opt_get(opt_net, ['kwargs'], {}))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
#v = DiscreteVAE()
|
||||
#o=v(torch.randn(1,3,256,256))
|
||||
#print(o.shape)
|
||||
v = DiscreteVAE(channels=80, normalization=None, positional_dims=1, num_tokens=4096, codebook_dim=4096,
|
||||
hidden_dim=256, stride=2, num_resnet_blocks=2, kernel_size=3, num_layers=2, use_transposed_convs=False)
|
||||
#v.eval()
|
||||
convert_from_dvae('D:\\dlas\\experiments\\train_dvae_clips\\models\\20000_generator.pth')
|
||||
'''
|
||||
v = ChannelAttentionDVAE(channels=80, normalization=None, positional_dims=1, num_tokens=4096, codebook_dim=4096,
|
||||
hidden_dim=256, stride=2, num_resnet_blocks=2, kernel_size=3, num_layers=2, use_transposed_convs=False)
|
||||
o=v(torch.randn(1,80,256))
|
||||
print(v.get_debug_values(0, 0))
|
||||
print(o[-1].shape)
|
||||
'''
|
||||
|
|
Loading…
Reference in New Issue
Block a user