@@ -5,6 +5,7 @@ import torch.nn as nn
 import torch.nn.functional as F

 from models.diffusion.nn import timestep_embedding
+from models.lucidrains.vq import VectorQuantize
 from models.lucidrains.x_transformers import FeedForward, Attention, Decoder, RMSScaleShiftNorm
 from trainer.networks import register_model
 from utils.util import checkpoint
@@ -16,55 +17,36 @@ class SelfClassifyingHead(nn.Module):
         self.seq_len = seq_len
         self.num_classes = classes
-        self.temperature = init_temperature
-        self.dec = Decoder(dim=dim, depth=head_depth, heads=2, ff_dropout=dropout, ff_mult=2, attn_dropout=dropout,
-                           use_rmsnorm=True, ff_glu=True, rotary_pos_emb=True)
-        self.to_classes = nn.Linear(dim, classes)
-        self.feedback_codebooks = nn.Linear(classes, dim, bias=False)
-        self.codebooks = nn.Linear(classes, out_dim, bias=False)
-
-    @staticmethod
-    def _compute_perplexity(probs, mask=None):
-        if mask is not None:
-            mask_extended = mask.flatten()[:, None, None].expand(probs.shape)
-            probs = torch.where(mask_extended, probs, torch.zeros_like(probs))
-            marginal_probs = probs.sum(dim=0) / mask.sum()
-        else:
-            marginal_probs = probs.mean(dim=0)
-        perplexity = torch.exp(-torch.sum(marginal_probs * torch.log(marginal_probs + 1e-7), dim=-1)).sum()
-        return perplexity
+        self.dec = Decoder(dim=dim, depth=head_depth, heads=4, ff_dropout=dropout, ff_mult=2, attn_dropout=dropout,
+                           use_rmsnorm=True, ff_glu=True, do_checkpointing=False)
+        self.quantizer = VectorQuantize(dim, classes, codebook_dim=32, use_cosine_sim=True, threshold_ema_dead_code=2,
+                                        sample_codebook_temp=init_temperature)
+        self.to_output = nn.Linear(dim, out_dim)

     def do_ar_step(self, x, used_codes):
         h = self.dec(x)
-        h = self.to_classes(h[:, -1])
-        for uc in used_codes:
-            mask = torch.arange(0, self.num_classes, device=x.device).unsqueeze(0).repeat(x.shape[0], 1) == uc.unsqueeze(1)
-            h[mask] = -torch.inf
-        c = F.gumbel_softmax(h, tau=self.temperature, hard=self.temperature == 1)
-        soft_c = torch.softmax(h, dim=-1)
-        perplexity = self._compute_perplexity(soft_c)
-        return c, perplexity
+        h, c, _ = self.quantizer(h[:, -1], used_codes)
+        return h, c

     def forward(self, x):
         with torch.no_grad():
             # Force one of the codebook weights to zero, allowing the model to "skip" any classes it chooses.
-            self.codebooks.weight.data[:, 0] = 0
+            self.quantizer._codebook.embed.data[0] = 0

         # manually perform ar regression over sequence_length=self.seq_len
         stack = [x]
         outputs = []
         results = []
         codes = []
-        total_perplexity = 0
         for i in range(self.seq_len):
-            nc, perp = checkpoint(functools.partial(self.do_ar_step, used_codes=codes), torch.stack(stack, dim=1))
-            codes.append(nc.argmax(-1))
-            stack.append(self.feedback_codebooks(nc))
-            outputs.append(self.codebooks(nc))
+            h, c = checkpoint(functools.partial(self.do_ar_step, used_codes=codes), torch.stack(stack, dim=1))
+            c_mask = c
+            c_mask[c == 0] = -1  # Mask this out because we want code=0 to be capable of being repeated.
+            codes.append(c)
+            stack.append(h.detach())  # Detach here to avoid piling up gradients from autoregression. We really just want the gradients to flow to the selected class embeddings and the selector for those classes.
+            outputs.append(self.to_output(h))
             results.append(torch.stack(outputs, dim=1).sum(1))
-            total_perplexity = total_perplexity + perp
-        return results, total_perplexity / self.seq_len, torch.cat(codes, dim=-1)
+        return results, torch.cat(codes, dim=0)


 class VectorResBlock(nn.Module):
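For context on the replaced selection logic: the head now relies on the local VectorQuantize fork to pick one class per autoregressive step, passing along the codes already emitted so they can be excluded from later steps; code 0 is remapped to -1 beforehand so that it alone may repeat. A minimal plain-torch sketch of that exclusion idea, with hypothetical shapes, assuming the fork uses `used_codes` to mask entries much like the removed gumbel path did:

    import torch

    batch, num_classes = 3, 8                          # hypothetical sizes
    logits = torch.randn(batch, num_classes)           # per-class scores for the next AR step
    used_codes = [torch.tensor([2, 5, -1]),            # codes picked at earlier steps; -1 (remapped
                  torch.tensor([4, 1, 3])]             # code 0) matches no index, so 0 stays selectable

    # Exclude every previously selected code per batch element before choosing the next one.
    for uc in used_codes:
        mask = torch.arange(num_classes).unsqueeze(0).repeat(batch, 1) == uc.unsqueeze(1)
        logits[mask] = float('-inf')

    next_code = logits.argmax(dim=-1)                  # greedy stand-in for the quantizer's sampled choice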
@@ -83,7 +65,6 @@ class InstrumentQuantizer(nn.Module):
     def __init__(self, op_dim, dim, num_classes, enc_depth, head_depth, class_seq_len=5, dropout=.1,
                  min_temp=1, max_temp=10, temp_decay=.999):
         """
         Args:
             op_dim:
             dim:
@@ -100,6 +81,7 @@ class InstrumentQuantizer(nn.Module):
         self.op_dim = op_dim
         self.proj = nn.Linear(op_dim, dim)
         self.encoder = nn.ModuleList([VectorResBlock(dim, dropout) for _ in range(enc_depth)])
+        self.final_bn = nn.BatchNorm1d(dim)
         self.heads = SelfClassifyingHead(dim, num_classes, op_dim, head_depth, class_seq_len, dropout, max_temp)
         self.min_gumbel_temperature = min_temp
         self.max_gumbel_temperature = max_temp
@@ -117,14 +99,15 @@ class InstrumentQuantizer(nn.Module):
         h = self.proj(f)
         for lyr in self.encoder:
             h = lyr(h)
-        reconstructions, perplexity, codes = self.heads(h)
+        h = self.final_bn(h.unsqueeze(-1)).squeeze(-1)
+        reconstructions, codes = self.heads(h)
         reconstruction_losses = torch.stack([F.mse_loss(r.reshape(b, s, c), px) for r in reconstructions])
         r_follow = torch.arange(1, reconstruction_losses.shape[0] + 1, device=x.device)
         reconstruction_losses = (reconstruction_losses * r_follow / r_follow.shape[0])
         self.log_codes(codes)
-        return reconstruction_losses, perplexity
+        return reconstruction_losses

     def log_codes(self, codes):
         if self.internal_step % 5 == 0:
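The loss weighting above multiplies the i-th step's MSE by i/seq_len, so reconstructions built from more of the selected codes count proportionally more. A small worked example with hypothetical per-step losses for class_seq_len=5:

    import torch

    losses = torch.tensor([0.9, 0.7, 0.5, 0.4, 0.3])   # hypothetical per-step reconstruction MSEs
    r_follow = torch.arange(1, losses.shape[0] + 1)     # 1..seq_len
    weighted = losses * r_follow / r_follow.shape[0]    # later (more complete) steps weigh more
    print(weighted)                                     # tensor([0.1800, 0.2800, 0.3000, 0.3200, 0.3000])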
@@ -139,13 +122,13 @@ class InstrumentQuantizer(nn.Module):
     def get_debug_values(self, step, __):
         if self.total_codes > 0:
             return {'histogram_codes': self.codes[:self.total_codes],
-                    'temperature': self.heads.temperature}
+                    'temperature': self.heads.quantizer._codebook.sample_codebook_temp}
         else:
             return {}

     def update_for_step(self, step, *args):
         self.internal_step = step
-        self.heads.temperature = max(
+        self.heads.quantizer._codebook.sample_codebook_temp = max(
             self.max_gumbel_temperature * self.gumbel_temperature_decay ** step,
             self.min_gumbel_temperature,
         )
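With the temperature now stored on the quantizer's codebook, update_for_step keeps the same exponential decay schedule, temperature = max(max_temp * decay**step, min_temp). A rough sketch of how the defaults above (max_temp=10, min_temp=1, temp_decay=.999) play out over training steps:

    max_temp, min_temp, decay = 10.0, 1.0, 0.999

    for step in (0, 1000, 2300, 5000):
        temp = max(max_temp * decay ** step, min_temp)
        print(step, round(temp, 3))
    # 0 -> 10.0, 1000 -> ~3.677, 2300 -> ~1.001, 5000 -> 1.0 (clamped at min_temp)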