forked from mrq/DL-Art-School
checkpointing bugs, smh
This commit is contained in:
parent
c61cd64bc9
commit
dee2b72786
|
@ -27,7 +27,7 @@ class ConditioningEncoder(nn.Module):
|
|||
|
||||
def forward(self, x):
|
||||
h = checkpoint(self.init, x)
|
||||
h = checkpoint(self.attn, h)
|
||||
h = self.attn(h
|
||||
return h.mean(dim=2)
|
||||
|
||||
|
||||
|
@ -58,7 +58,7 @@ class UpperConditioningEncoder(nn.Module):
|
|||
|
||||
|
||||
class GptMusicLower(nn.Module):
|
||||
def __init__(self, dim, layers, dropout=0, num_target_vectors=512, num_target_groups=2, num_upper_vectors=64, num_upper_groups=4):
|
||||
def __init__(self, dim, layers, dropout=0, num_target_vectors=512, num_target_groups=2, num_upper_vectors=64, num_upper_groups=4, fp16=True):
|
||||
super().__init__()
|
||||
self.internal_step = 0
|
||||
self.num_groups = num_target_groups
|
||||
|
@ -73,6 +73,7 @@ class GptMusicLower(nn.Module):
|
|||
max(512,dim-512),
|
||||
max(512,dim-512)], codevector_dim=dim,
|
||||
codebook_size=num_upper_vectors, codebook_groups=num_upper_groups, expressive_downsamples=True)
|
||||
self.fp16 = fp16
|
||||
# Following are unused quantizer constructs we delete to avoid DDP errors (and to be efficient.. of course..)
|
||||
del self.target_quantizer.decoder
|
||||
del self.target_quantizer.up
|
||||
|
@ -116,7 +117,7 @@ class GptMusicLower(nn.Module):
|
|||
h = [embedding(inputs[:, :, i]) for i, embedding in enumerate(self.embeddings)]
|
||||
h = torch.cat(h, dim=-1) + upper_vector
|
||||
|
||||
with torch.autocast(mel.device.type):
|
||||
with torch.autocast(mel.device.type, enabled=self.fp16):
|
||||
# Stick the conditioning embedding on the front of the input sequence.
|
||||
# The transformer will learn how to integrate it.
|
||||
# This statement also serves to pre-pad the inputs by one token, which is the basis of the next-token-prediction task. IOW: this is the "START" token.
|
||||
|
@ -162,10 +163,11 @@ class GptMusicLower(nn.Module):
|
|||
|
||||
|
||||
class GptMusicUpper(nn.Module):
|
||||
def __init__(self, dim, layers, dropout=0, num_upper_vectors=64, num_upper_groups=4):
|
||||
def __init__(self, dim, layers, dropout=0, num_upper_vectors=64, num_upper_groups=4, fp16=True):
|
||||
super().__init__()
|
||||
self.internal_step = 0
|
||||
self.num_groups = num_upper_groups
|
||||
self.fp16 = fp16
|
||||
self.config = GPT2Config(vocab_size=1, n_positions=8192, n_embd=dim, n_layer=layers, n_head=dim//64,
|
||||
n_inner=dim*2, attn_pdrop=dropout, resid_pdrop=dropout, gradient_checkpointing=True,
|
||||
use_cache=False)
|
||||
|
@ -203,7 +205,7 @@ class GptMusicUpper(nn.Module):
|
|||
h = [embedding(inputs[:, :, i]) for i, embedding in enumerate(self.embeddings)]
|
||||
h = torch.cat(h, dim=-1)
|
||||
|
||||
with torch.autocast(mel.device.type):
|
||||
with torch.autocast(mel.device.type, enabled=self.fp16):
|
||||
# Stick the conditioning embedding on the front of the input sequence.
|
||||
# The transformer will learn how to integrate it.
|
||||
# This statement also serves to pre-pad the inputs by one token, which is the basis of the next-token-prediction task. IOW: this is the "START" token.
|
||||
|
@ -251,9 +253,9 @@ def test_lower():
|
|||
base_diff = TransformerDiffusionWithQuantizer(in_channels=256, out_channels=512, model_channels=2048, block_channels=1024,
|
||||
prenet_channels=1024, prenet_layers=6, num_layers=16, input_vec_dim=1024,
|
||||
dropout=.1, unconditioned_percentage=0, freeze_quantizer_until=6000)
|
||||
base_diff.load_state_dict(torch.load('x:/dlas/experiments/train_music_diffusion_tfd8/models/28000_generator.pth', map_location=torch.device('cpu')))
|
||||
base_diff.load_state_dict(torch.load('x:/dlas/experiments/train_music_diffusion_tfd8/models/47500_generator.pth', map_location=torch.device('cpu')))
|
||||
|
||||
model = GptMusicLower(512, 12)
|
||||
model = GptMusicLower(512, 8, fp16=False)
|
||||
model.target_quantizer.load_state_dict(base_diff.quantizer.state_dict(), strict=False)
|
||||
torch.save(model.state_dict(), "sample.pth")
|
||||
mel = torch.randn(2,256,400)
|
||||
|
@ -273,4 +275,4 @@ def test_upper():
|
|||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_upper()
|
||||
test_lower()
|
Loading…
Reference in New Issue
Block a user