diff --git a/codes/models/gpt_voice/dvae_arch_playground/dvae_channel_attention.py b/codes/models/gpt_voice/dvae_arch_playground/dvae_channel_attention.py
index e4d9266f..13b1c01d 100644
--- a/codes/models/gpt_voice/dvae_arch_playground/dvae_channel_attention.py
+++ b/codes/models/gpt_voice/dvae_arch_playground/dvae_channel_attention.py
@@ -9,6 +9,7 @@
 from einops import rearrange
 from torch import einsum
 from models.diffusion.unet_diffusion import AttentionBlock
+from models.gpt_voice.lucidrains_dvae import DiscreteVAE
 from models.stylegan.stylegan2_rosinality import EqualLinear
 from models.vqvae.vqvae import Quantize
 from trainer.networks import register_model
@@ -142,7 +143,7 @@ class UpsampledConv(nn.Module):
         return self.conv(up)
 
 
-class DiscreteVAE(nn.Module):
+class ChannelAttentionDVAE(nn.Module):
     def __init__(
         self,
         positional_dims=2,
@@ -151,6 +152,7 @@ class DiscreteVAE(nn.Module):
         num_layers = 3,
         num_resnet_blocks = 0,
         hidden_dim = 64,
+        channel_attention_dim = 64,
         channels = 3,
         stride = 2,
         kernel_size = 4,
@@ -218,7 +220,7 @@ class DiscreteVAE(nn.Module):
             dec_layers.insert(0, conv(codebook_dim, dec_chans[1], 1))
 
         enc_layers.append(conv(enc_chans[-1], codebook_dim, 1))
-        dec_layers.append(ChannelAttentionModule(dec_chans[-1], channels, hidden_dim, layers=3, num_heads=1))
+        dec_layers.append(ChannelAttentionModule(dec_chans[-1], channels, channel_attention_dim, layers=3, num_heads=1))
 
         self.encoder = nn.Sequential(*enc_layers)
         self.decoder = nn.Sequential(*dec_layers)
@@ -320,18 +322,42 @@ class DiscreteVAE(nn.Module):
         return recon_loss, commitment_loss, out
 
 
+
+def convert_from_dvae(dvae_state_dict_file):
+    params = {
+        'channels': 80,
+        'positional_dims': 1,
+        'num_tokens': 8192,
+        'codebook_dim': 2048,
+        'hidden_dim': 512,
+        'stride': 2,
+        'num_resnet_blocks': 3,
+        'num_layers': 2,
+        'record_codes': True,
+    }
+    dvae = DiscreteVAE(**params)
+    dvae.load_state_dict(torch.load(dvae_state_dict_file), strict=True)
+    cdvae = ChannelAttentionDVAE(channel_attention_dim=256, **params)
+    mk, uk = cdvae.load_state_dict(dvae.state_dict(), strict=False)
+    for k in mk:
+        assert 'decoder.6' in k
+    for k in uk:
+        assert 'decoder.6' in k
+    cdvae.decoder[-1].bypass.load_state_dict(dvae.decoder[-1].state_dict())
+    torch.save(cdvae.state_dict(), 'converted_cdvae.pth')
+
+
 @register_model
 def register_dvae_channel_attention(opt_net, opt):
-    return DiscreteVAE(**opt_get(opt_net, ['kwargs'], {}))
+    return ChannelAttentionDVAE(**opt_get(opt_net, ['kwargs'], {}))
 
 
 if __name__ == '__main__':
-    #v = DiscreteVAE()
-    #o=v(torch.randn(1,3,256,256))
-    #print(o.shape)
-    v = DiscreteVAE(channels=80, normalization=None, positional_dims=1, num_tokens=4096, codebook_dim=4096,
-                    hidden_dim=256, stride=2, num_resnet_blocks=2, kernel_size=3, num_layers=2, use_transposed_convs=False)
-    #v.eval()
+    convert_from_dvae('D:\\dlas\\experiments\\train_dvae_clips\\models\\20000_generator.pth')
+    '''
+    v = ChannelAttentionDVAE(channels=80, normalization=None, positional_dims=1, num_tokens=4096, codebook_dim=4096,
+                             hidden_dim=256, stride=2, num_resnet_blocks=2, kernel_size=3, num_layers=2, use_transposed_convs=False)
     o=v(torch.randn(1,80,256))
     print(v.get_debug_values(0, 0))
     print(o[-1].shape)
+    '''
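
Note on the conversion path above: convert_from_dvae() uses the standard PyTorch idiom of loading a checkpoint into a structurally similar model with strict=False, which returns the missing and unexpected keys so the caller can verify that every mismatch is confined to the layer that was intentionally swapped (here decoder.6, the new ChannelAttentionModule, whose bypass is then seeded from the old final conv). A minimal self-contained sketch of that idiom follows; OldModel/NewModel are hypothetical stand-ins, not classes from this repo:

import torch.nn as nn

# Two models identical except for the final layer, mirroring the
# decoder.6 swap performed by convert_from_dvae().
class OldModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(8, 8)
        self.head = nn.Linear(8, 4)  # the layer being replaced

class NewModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(8, 8)  # identical; transfers directly
        self.head = nn.Sequential(nn.Linear(8, 16), nn.Linear(16, 4))

old, new = OldModel(), NewModel()

# strict=False returns (missing_keys, unexpected_keys); assert that
# every mismatch belongs to the intentionally replaced layer.
mk, uk = new.load_state_dict(old.state_dict(), strict=False)
for k in list(mk) + list(uk):
    assert k.startswith('head'), f'unexpected mismatch: {k}'

The assertion plays the same role as the 'decoder.6' checks in the diff: it turns a silent partial load into a loud failure if the two architectures have drifted anywhere other than the replaced module.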