From dee2b727865bcc8f95d6ad266819000e615e39a4 Mon Sep 17 00:00:00 2001
From: James Betker
Date: Wed, 8 Jun 2022 11:53:10 -0600
Subject: [PATCH] checkpointing bugs, smh

---
 codes/models/audio/music/gpt_music.py | 18 ++++++++++--------
 1 file changed, 10 insertions(+), 8 deletions(-)

diff --git a/codes/models/audio/music/gpt_music.py b/codes/models/audio/music/gpt_music.py
index d8b6c287..ac05f1dd 100644
--- a/codes/models/audio/music/gpt_music.py
+++ b/codes/models/audio/music/gpt_music.py
@@ -27,7 +27,7 @@ class ConditioningEncoder(nn.Module):
 
     def forward(self, x):
         h = checkpoint(self.init, x)
-        h = checkpoint(self.attn, h)
+        h = self.attn(h)
         return h.mean(dim=2)
 
 
@@ -58,7 +58,7 @@ class UpperConditioningEncoder(nn.Module):
 
 
 class GptMusicLower(nn.Module):
-    def __init__(self, dim, layers, dropout=0, num_target_vectors=512, num_target_groups=2, num_upper_vectors=64, num_upper_groups=4):
+    def __init__(self, dim, layers, dropout=0, num_target_vectors=512, num_target_groups=2, num_upper_vectors=64, num_upper_groups=4, fp16=True):
        super().__init__()
        self.internal_step = 0
        self.num_groups = num_target_groups
@@ -73,6 +73,7 @@ class GptMusicLower(nn.Module):
                                                                           max(512,dim-512),
                                                                           max(512,dim-512)], codevector_dim=dim,
                                               codebook_size=num_upper_vectors, codebook_groups=num_upper_groups, expressive_downsamples=True)
+        self.fp16 = fp16
         # Following are unused quantizer constructs we delete to avoid DDP errors (and to be efficient.. of course..)
         del self.target_quantizer.decoder
         del self.target_quantizer.up
@@ -116,7 +117,7 @@ class GptMusicLower(nn.Module):
         h = [embedding(inputs[:, :, i]) for i, embedding in enumerate(self.embeddings)]
         h = torch.cat(h, dim=-1) + upper_vector
 
-        with torch.autocast(mel.device.type):
+        with torch.autocast(mel.device.type, enabled=self.fp16):
             # Stick the conditioning embedding on the front of the input sequence.
             # The transformer will learn how to integrate it.
             # This statement also serves to pre-pad the inputs by one token, which is the basis of the next-token-prediction task. IOW: this is the "START" token.
@@ -162,10 +163,11 @@ class GptMusicLower(nn.Module):
 
 
 class GptMusicUpper(nn.Module):
-    def __init__(self, dim, layers, dropout=0, num_upper_vectors=64, num_upper_groups=4):
+    def __init__(self, dim, layers, dropout=0, num_upper_vectors=64, num_upper_groups=4, fp16=True):
         super().__init__()
         self.internal_step = 0
         self.num_groups = num_upper_groups
+        self.fp16 = fp16
         self.config = GPT2Config(vocab_size=1, n_positions=8192, n_embd=dim, n_layer=layers, n_head=dim//64,
                                  n_inner=dim*2, attn_pdrop=dropout, resid_pdrop=dropout,
                                  gradient_checkpointing=True, use_cache=False)
@@ -203,7 +205,7 @@ class GptMusicUpper(nn.Module):
         h = [embedding(inputs[:, :, i]) for i, embedding in enumerate(self.embeddings)]
         h = torch.cat(h, dim=-1)
 
-        with torch.autocast(mel.device.type):
+        with torch.autocast(mel.device.type, enabled=self.fp16):
             # Stick the conditioning embedding on the front of the input sequence.
             # The transformer will learn how to integrate it.
             # This statement also serves to pre-pad the inputs by one token, which is the basis of the next-token-prediction task. IOW: this is the "START" token.
@@ -251,9 +253,9 @@ def test_lower():
     base_diff = TransformerDiffusionWithQuantizer(in_channels=256, out_channels=512, model_channels=2048, block_channels=1024,
                                                   prenet_channels=1024, prenet_layers=6, num_layers=16, input_vec_dim=1024,
                                                   dropout=.1, unconditioned_percentage=0, freeze_quantizer_until=6000)
-    base_diff.load_state_dict(torch.load('x:/dlas/experiments/train_music_diffusion_tfd8/models/28000_generator.pth', map_location=torch.device('cpu')))
+    base_diff.load_state_dict(torch.load('x:/dlas/experiments/train_music_diffusion_tfd8/models/47500_generator.pth', map_location=torch.device('cpu')))
 
-    model = GptMusicLower(512, 12)
+    model = GptMusicLower(512, 8, fp16=False)
     model.target_quantizer.load_state_dict(base_diff.quantizer.state_dict(), strict=False)
     torch.save(model.state_dict(), "sample.pth")
     mel = torch.randn(2,256,400)
@@ -273,4 +275,4 @@ def test_upper():
 
 
 if __name__ == '__main__':
-    test_upper()
\ No newline at end of file
+    test_lower()
\ No newline at end of file
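Note: the core change here is threading an `fp16` flag into `torch.autocast(..., enabled=self.fp16)`, so the same forward pass can run with or without mixed precision (the patched `test_lower()` constructs the model with `fp16=False` to keep everything in fp32 on CPU). Below is a minimal standalone sketch of that pattern; the `TinyModel` module and tensor shapes are illustrative assumptions, not the repo's actual classes.

    import torch
    import torch.nn as nn

    class TinyModel(nn.Module):
        def __init__(self, fp16=True):
            super().__init__()
            # Store the toggle so forward() can decide whether autocast does anything.
            self.fp16 = fp16
            self.proj = nn.Linear(256, 512)

        def forward(self, x):
            # With enabled=False the autocast context is a no-op and the math stays fp32;
            # with enabled=True, eligible ops run in reduced precision on the active device.
            with torch.autocast(x.device.type, enabled=self.fp16):
                return self.proj(x)

    if __name__ == '__main__':
        model = TinyModel(fp16=False)
        out = model(torch.randn(2, 400, 256))
        print(out.dtype)  # torch.float32, because autocast was disabled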