diff --git a/codes/models/arch_util.py b/codes/models/arch_util.py
index 19791ecd..090c946f 100644
--- a/codes/models/arch_util.py
+++ b/codes/models/arch_util.py
@@ -367,12 +367,14 @@ class ResBlock(nn.Module):
         up=False,
         down=False,
         kernel_size=3,
+        checkpointing_enabled=True,
     ):
         super().__init__()
         self.channels = channels
         self.dropout = dropout
         self.out_channels = out_channels or channels
         self.use_conv = use_conv
+        self.checkpointing_enabled = checkpointing_enabled
         padding = 1 if kernel_size == 3 else 2

         self.in_layers = nn.Sequential(
@@ -417,9 +419,12 @@ class ResBlock(nn.Module):
         :param x: an [N x C x ...] Tensor of features.
         :return: an [N x C x ...] Tensor of outputs.
         """
-        return checkpoint(
-            self._forward, x
-        )
+        if self.checkpointing_enabled:
+            return checkpoint(
+                self._forward, x
+            )
+        else:
+            return self._forward(x)

     def _forward(self, x):
         if self.updown:
@@ -1017,4 +1022,4 @@ def gather_2d(input, index):
     result = result.squeeze()
     if b == 1:
         result = result.unsqueeze(0)
-    return result
\ No newline at end of file
+    return result
diff --git a/codes/models/audio/music/cheater_gen_ar.py b/codes/models/audio/music/cheater_gen_ar.py
index bc8ce212..096e1619 100644
--- a/codes/models/audio/music/cheater_gen_ar.py
+++ b/codes/models/audio/music/cheater_gen_ar.py
@@ -31,9 +31,9 @@ class ConditioningEncoder(nn.Module):
                               rotary_pos_emb=True,
                               zero_init_branch_output=True,
                               ff_mult=2,
+                              do_checkpointing=do_checkpointing
                               )
         self.dim = embedding_dim
-        self.do_checkpointing = do_checkpointing

     def forward(self, x):
         h = self.init(x).permute(0,2,1)
@@ -122,4 +122,4 @@ def test_ar():


 if __name__ == '__main__':
-    test_ar()
\ No newline at end of file
+    test_ar()
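Taken together, the arch_util.py hunks make ResBlock's activation checkpointing optional rather than unconditional, and cheater_gen_ar.py now passes its do_checkpointing argument through to the underlying Encoder instead of just recording it on self. Below is a minimal, self-contained sketch of the toggle pattern, assuming the repo's checkpoint wrapper behaves like torch.utils.checkpoint.checkpoint (drop intermediate activations in forward, recompute them on backward); ToyBlock is a hypothetical stand-in, not a class from this codebase:

import torch
from torch import nn
from torch.utils.checkpoint import checkpoint


class ToyBlock(nn.Module):
    def __init__(self, channels, checkpointing_enabled=True):
        super().__init__()
        self.body = nn.Sequential(nn.Conv1d(channels, channels, 3, padding=1), nn.SiLU())
        self.checkpointing_enabled = checkpointing_enabled

    def forward(self, x):
        if self.checkpointing_enabled:
            # self.body's intermediate activations are freed after forward and
            # recomputed during backward: less memory held, one extra forward paid.
            return checkpoint(self.body, x)
        return self.body(x)


x = torch.randn(2, 16, 128, requires_grad=True)
ToyBlock(16, checkpointing_enabled=True)(x).sum().backward()  # gradients match the plain path

The flag matters most for frozen or inference-only submodules, where recomputation buys nothing; the new encoder added below threads the same switch through every ResBlock it builds.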
diff --git a/codes/models/audio/music/encoders.py b/codes/models/audio/music/encoders.py
new file mode 100644
index 00000000..c2fd3a66
--- /dev/null
+++ b/codes/models/audio/music/encoders.py
@@ -0,0 +1,44 @@
+import torch
+import torch.nn.functional as F
+from torch import nn
+from transformers import GPT2Config, GPT2Model
+
+from models.arch_util import AttentionBlock, ResBlock
+from models.audio.tts.lucidrains_dvae import DiscreteVAE
+from trainer.networks import register_model
+from utils.util import opt_get, ceil_multiple, print_network
+
+
+class ResEncoder16x(nn.Module):
+    def __init__(self,
+                 spec_dim,
+                 hidden_dim,
+                 embedding_dim,
+                 checkpointing_enabled=True,
+                 ):
+        super().__init__()
+        attn = []
+        def edim(m):
+            dd = min(spec_dim + m * 128, hidden_dim)
+            return ceil_multiple(dd, 8)
+        self.downsampler = nn.Sequential(
+            ResBlock(spec_dim, out_channels=edim(2), use_conv=True, dims=1, down=True, checkpointing_enabled=checkpointing_enabled),
+            ResBlock(edim(2), out_channels=edim(3), use_conv=True, dims=1, down=True, checkpointing_enabled=checkpointing_enabled),
+            ResBlock(edim(3), out_channels=edim(3), use_conv=True, dims=1, checkpointing_enabled=checkpointing_enabled),
+            ResBlock(edim(3), out_channels=edim(4), use_conv=True, dims=1, down=True, checkpointing_enabled=checkpointing_enabled),
+            ResBlock(edim(4), out_channels=edim(4), use_conv=True, dims=1, checkpointing_enabled=checkpointing_enabled),
+            ResBlock(edim(4), out_channels=hidden_dim, use_conv=True, dims=1, down=True, checkpointing_enabled=checkpointing_enabled))
+        self.encoder = nn.Sequential(
+            ResBlock(hidden_dim, out_channels=hidden_dim, use_conv=True, dims=1, checkpointing_enabled=checkpointing_enabled),
+            ResBlock(hidden_dim, out_channels=hidden_dim, use_conv=True, dims=1, checkpointing_enabled=checkpointing_enabled),
+            ResBlock(hidden_dim, out_channels=hidden_dim, use_conv=True, dims=1, checkpointing_enabled=checkpointing_enabled),
+            nn.GroupNorm(8, hidden_dim),
+            nn.SiLU(),
+            nn.Conv1d(hidden_dim, embedding_dim, 1),
+            nn.Tanh(),
+        )
+
+    def forward(self, x):
+        h = self.downsampler(x)
+        h = self.encoder(h)
+        return h
diff --git a/codes/models/audio/music/gpt_music2.py b/codes/models/audio/music/gpt_music2.py
index 1e2afdd5..acaad51f 100644
--- a/codes/models/audio/music/gpt_music2.py
+++ b/codes/models/audio/music/gpt_music2.py
@@ -14,6 +14,7 @@ class UpperEncoder(nn.Module):
                  spec_dim,
                  hidden_dim,
                  embedding_dim,
+                 checkpointing_enabled=True,
                  ):
         super().__init__()
         attn = []
@@ -21,18 +22,18 @@
             dd = min(spec_dim + m * 128, hidden_dim)
             return ceil_multiple(dd, 8)
         self.downsampler = nn.Sequential(
-            ResBlock(spec_dim, out_channels=edim(1), use_conv=True, dims=1, down=True),
-            ResBlock(edim(1), out_channels=edim(2), use_conv=True, dims=1, down=True),
-            ResBlock(edim(2), out_channels=edim(3), use_conv=True, dims=1, down=True),
-            ResBlock(edim(3), out_channels=edim(4), use_conv=True, dims=1),
-            ResBlock(edim(4), out_channels=hidden_dim, use_conv=True, dims=1, down=True))
+            ResBlock(spec_dim, out_channels=edim(1), use_conv=True, dims=1, down=True, checkpointing_enabled=checkpointing_enabled),
+            ResBlock(edim(1), out_channels=edim(2), use_conv=True, dims=1, down=True, checkpointing_enabled=checkpointing_enabled),
+            ResBlock(edim(2), out_channels=edim(3), use_conv=True, dims=1, down=True, checkpointing_enabled=checkpointing_enabled),
+            ResBlock(edim(3), out_channels=edim(4), use_conv=True, dims=1, checkpointing_enabled=checkpointing_enabled),
+            ResBlock(edim(4), out_channels=hidden_dim, use_conv=True, dims=1, down=True, checkpointing_enabled=checkpointing_enabled))
         self.encoder = nn.Sequential(
             AttentionBlock(hidden_dim, 4, do_activation=True),
-            ResBlock(hidden_dim, out_channels=hidden_dim, use_conv=True, dims=1),
+            ResBlock(hidden_dim, out_channels=hidden_dim, use_conv=True, dims=1, checkpointing_enabled=checkpointing_enabled),
             AttentionBlock(hidden_dim, 4, do_activation=True),
-            ResBlock(hidden_dim, out_channels=hidden_dim, use_conv=True, dims=1),
+            ResBlock(hidden_dim, out_channels=hidden_dim, use_conv=True, dims=1, checkpointing_enabled=checkpointing_enabled),
             AttentionBlock(hidden_dim, 4, do_activation=True),
-            ResBlock(hidden_dim, out_channels=hidden_dim, use_conv=True, dims=1),
+            ResBlock(hidden_dim, out_channels=hidden_dim, use_conv=True, dims=1, checkpointing_enabled=checkpointing_enabled),
             nn.GroupNorm(8, hidden_dim),
             nn.SiLU(),
             nn.Conv1d(hidden_dim, embedding_dim, 1),
@@ -45,6 +46,8 @@ class UpperEncoder(nn.Module):
         return h


+
+
 class GptMusicLower(nn.Module):
     def __init__(self, dim, layers, encoder_out_dim, dropout=0, num_target_vectors=8192, fp16=True, num_vaes=4, vqargs={}):
         super().__init__()
@@ -170,4 +173,4 @@ def test_lower():


 if __name__ == '__main__':
-    test_lower()
\ No newline at end of file
+    test_lower()
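A note on scale: the new ResEncoder16x and the updated UpperEncoder each contain four stride-2 down=True ResBlocks, so both compress the time axis 16x (hence the name); ResEncoder16x simply swaps UpperEncoder's AttentionBlocks for additional ResBlocks. A hedged usage sketch follows, meant to be run from inside codes/ with the repo on the path; the output shape is my reading of the strides, not a verified result:

import torch
from models.audio.music.encoders import ResEncoder16x

mel = torch.randn(2, 256, 400)   # [N, spec_dim, T] with T divisible by 16
enc = ResEncoder16x(256, 1024, 256, checkpointing_enabled=False)
latent = enc(mel)
print(latent.shape)              # expected: torch.Size([2, 256, 25]), Tanh-bounded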
diff --git a/codes/models/audio/music/transformer_diffusion12.py b/codes/models/audio/music/transformer_diffusion12.py
index 10ede292..f750d96c 100644
--- a/codes/models/audio/music/transformer_diffusion12.py
+++ b/codes/models/audio/music/transformer_diffusion12.py
@@ -456,15 +456,14 @@ class TransformerDiffusionWithMultiPretrainedVqvae(nn.Module):
                 if hasattr(p, 'grad') and p.grad is not None:
                     p.grad *= .2

-
 class TransformerDiffusionWithCheaterLatent(nn.Module):
-    def __init__(self, freeze_encoder_until=None, **kwargs):
+    def __init__(self, freeze_encoder_until=None, checkpoint_encoder=True, **kwargs):
         super().__init__()
         self.internal_step = 0
         self.freeze_encoder_until = freeze_encoder_until
         self.diff = TransformerDiffusion(**kwargs)
-        self.encoder = UpperEncoder(256, 1024, 256)
-        self.encoder = self.encoder.eval()
+        from models.audio.music.encoders import ResEncoder16x
+        self.encoder = ResEncoder16x(256, 1024, 256, checkpointing_enabled=checkpoint_encoder)

     def forward(self, x, timesteps, truth_mel, conditioning_input=None, disable_diversity=False, conditioning_free=False):
         unused_parameters = []
diff --git a/codes/trainer/injectors/audio_injectors.py b/codes/trainer/injectors/audio_injectors.py
index 53cb1815..ef8ecbe3 100644
--- a/codes/trainer/injectors/audio_injectors.py
+++ b/codes/trainer/injectors/audio_injectors.py
@@ -98,11 +98,24 @@ class RandomAudioCropInjector(Injector):
             self.max_crop_sz = opt['max_crop_size']
         self.lengths_key = opt['lengths_key']
         self.crop_start_key = opt['crop_start_key']
+        self.rand_buffer_ptr=9999
+        self.rand_buffer_sz=5000

     def forward(self, state):
         crop_sz = random.randint(self.min_crop_sz, self.max_crop_sz)
         inp = state[self.input]
+        if torch.distributed.get_world_size() > 1:
+            # All processes should agree, otherwise all processes wait to process max_crop_sz (effectively). But agreeing too often
+            # is expensive, so agree on a "chunk" at a time.
+            if self.rand_buffer_ptr >= self.rand_buffer_sz:
+                self.rand_buffer = torch.randint(self.min_crop_sz, self.max_crop_sz, (self.rand_buffer_sz,), dtype=torch.long, device=inp.device)
+                torch.distributed.broadcast(self.rand_buffer, 0)
+                self.rand_buffer_ptr = 0
+            crop_sz = self.rand_buffer[self.rand_buffer_ptr]
+            self.rand_buffer_ptr += 1
+        else:
+            crop_sz = random.randint(self.min_crop_sz, self.max_crop_sz)

         if self.lengths_key is not None:
             lens = state[self.lengths_key]
             len = torch.min(lens)
@@ -445,4 +458,4 @@ class MusicCheaterArInjector(Injector):
             self.needs_move = False
         with torch.no_grad():
             latents = self.cheater_ar(codes, cond, return_latent=True)
-        return {self.output: latents}
\ No newline at end of file
+        return {self.output: latents}
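The RandomAudioCropInjector change deserves a note: under DDP, each rank previously drew its own crop_sz, so ranks produced differently shaped tensors each step and the job effectively waited on whichever rank drew the largest crop. The fix pre-draws a chunk of 5000 sizes, broadcasts rank 0's draw to everyone, and has all ranks walk the shared buffer in lockstep, paying for one broadcast per 5000 steps. A self-contained sketch of the same trick, assuming the default process group is already initialized (torch.distributed.get_world_size() raises otherwise, which the injector implicitly relies on in multi-GPU runs):

import torch
import torch.distributed as dist


def next_crop_size(state, min_sz, max_sz, device, buffer_sz=5000):
    # `state` is a mutable dict standing in for the injector's attributes.
    if state.get('ptr', buffer_sz) >= buffer_sz:
        # Every rank draws locally; rank 0's draw then overwrites the others,
        # so all ranks end up holding an identical buffer.
        buf = torch.randint(min_sz, max_sz, (buffer_sz,), dtype=torch.long, device=device)
        dist.broadcast(buf, src=0)
        state['buf'], state['ptr'] = buf, 0
    sz = int(state['buf'][state['ptr']])
    state['ptr'] += 1
    return sz

Seeding every rank's RNG identically would avoid the broadcast entirely, but it is brittle: if any rank ever consumes random numbers at a different rate, the crops silently diverge again.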