forked from mrq/DL-Art-School
Allow attention_dim in channel attention to be specified, add converter
This commit is contained in:
parent
9c0d7288ea
commit
f2977d360c
|
@ -9,6 +9,7 @@ from einops import rearrange
|
||||||
from torch import einsum
|
from torch import einsum
|
||||||
|
|
||||||
from models.diffusion.unet_diffusion import AttentionBlock
|
from models.diffusion.unet_diffusion import AttentionBlock
|
||||||
|
from models.gpt_voice.lucidrains_dvae import DiscreteVAE
|
||||||
from models.stylegan.stylegan2_rosinality import EqualLinear
|
from models.stylegan.stylegan2_rosinality import EqualLinear
|
||||||
from models.vqvae.vqvae import Quantize
|
from models.vqvae.vqvae import Quantize
|
||||||
from trainer.networks import register_model
|
from trainer.networks import register_model
|
||||||
|
@ -142,7 +143,7 @@ class UpsampledConv(nn.Module):
|
||||||
return self.conv(up)
|
return self.conv(up)
|
||||||
|
|
||||||
|
|
||||||
class DiscreteVAE(nn.Module):
|
class ChannelAttentionDVAE(nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
positional_dims=2,
|
positional_dims=2,
|
||||||
|
@ -151,6 +152,7 @@ class DiscreteVAE(nn.Module):
|
||||||
num_layers = 3,
|
num_layers = 3,
|
||||||
num_resnet_blocks = 0,
|
num_resnet_blocks = 0,
|
||||||
hidden_dim = 64,
|
hidden_dim = 64,
|
||||||
|
channel_attention_dim = 64,
|
||||||
channels = 3,
|
channels = 3,
|
||||||
stride = 2,
|
stride = 2,
|
||||||
kernel_size = 4,
|
kernel_size = 4,
|
||||||
|
@ -218,7 +220,7 @@ class DiscreteVAE(nn.Module):
|
||||||
dec_layers.insert(0, conv(codebook_dim, dec_chans[1], 1))
|
dec_layers.insert(0, conv(codebook_dim, dec_chans[1], 1))
|
||||||
|
|
||||||
enc_layers.append(conv(enc_chans[-1], codebook_dim, 1))
|
enc_layers.append(conv(enc_chans[-1], codebook_dim, 1))
|
||||||
dec_layers.append(ChannelAttentionModule(dec_chans[-1], channels, hidden_dim, layers=3, num_heads=1))
|
dec_layers.append(ChannelAttentionModule(dec_chans[-1], channels, channel_attention_dim, layers=3, num_heads=1))
|
||||||
|
|
||||||
self.encoder = nn.Sequential(*enc_layers)
|
self.encoder = nn.Sequential(*enc_layers)
|
||||||
self.decoder = nn.Sequential(*dec_layers)
|
self.decoder = nn.Sequential(*dec_layers)
|
||||||
|
@ -320,18 +322,42 @@ class DiscreteVAE(nn.Module):
|
||||||
return recon_loss, commitment_loss, out
|
return recon_loss, commitment_loss, out
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def convert_from_dvae(dvae_state_dict_file):
|
||||||
|
params = {
|
||||||
|
'channels': 80,
|
||||||
|
'positional_dims': 1,
|
||||||
|
'num_tokens': 8192,
|
||||||
|
'codebook_dim': 2048,
|
||||||
|
'hidden_dim': 512,
|
||||||
|
'stride': 2,
|
||||||
|
'num_resnet_blocks': 3,
|
||||||
|
'num_layers': 2,
|
||||||
|
'record_codes': True,
|
||||||
|
}
|
||||||
|
dvae = DiscreteVAE(**params)
|
||||||
|
dvae.load_state_dict(torch.load(dvae_state_dict_file), strict=True)
|
||||||
|
cdvae = ChannelAttentionDVAE(channel_attention_dim=256, **params)
|
||||||
|
mk, uk = cdvae.load_state_dict(dvae.state_dict(), strict=False)
|
||||||
|
for k in mk:
|
||||||
|
assert 'decoder.6' in k
|
||||||
|
for k in uk:
|
||||||
|
assert 'decoder.6' in k
|
||||||
|
cdvae.decoder[-1].bypass.load_state_dict(dvae.decoder[-1].state_dict())
|
||||||
|
torch.save(cdvae.state_dict(), 'converted_cdvae.pth')
|
||||||
|
|
||||||
|
|
||||||
@register_model
|
@register_model
|
||||||
def register_dvae_channel_attention(opt_net, opt):
|
def register_dvae_channel_attention(opt_net, opt):
|
||||||
return DiscreteVAE(**opt_get(opt_net, ['kwargs'], {}))
|
return ChannelAttentionDVAE(**opt_get(opt_net, ['kwargs'], {}))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
#v = DiscreteVAE()
|
convert_from_dvae('D:\\dlas\\experiments\\train_dvae_clips\\models\\20000_generator.pth')
|
||||||
#o=v(torch.randn(1,3,256,256))
|
'''
|
||||||
#print(o.shape)
|
v = ChannelAttentionDVAE(channels=80, normalization=None, positional_dims=1, num_tokens=4096, codebook_dim=4096,
|
||||||
v = DiscreteVAE(channels=80, normalization=None, positional_dims=1, num_tokens=4096, codebook_dim=4096,
|
hidden_dim=256, stride=2, num_resnet_blocks=2, kernel_size=3, num_layers=2, use_transposed_convs=False)
|
||||||
hidden_dim=256, stride=2, num_resnet_blocks=2, kernel_size=3, num_layers=2, use_transposed_convs=False)
|
|
||||||
#v.eval()
|
|
||||||
o=v(torch.randn(1,80,256))
|
o=v(torch.randn(1,80,256))
|
||||||
print(v.get_debug_values(0, 0))
|
print(v.get_debug_values(0, 0))
|
||||||
print(o[-1].shape)
|
print(o[-1].shape)
|
||||||
|
'''
|
||||||
|
|
Loading…
Reference in New Issue
Block a user