un-tensor'd the quant_level marker, since it doesn't need to be one (I forget why I had it as a tensor, but nothing seems to need it as one that didn't already convert it itself)

mrq 2024-06-07 20:46:22 -05:00
parent b0158a61d5
commit 7d6fff24f9
3 changed files with 19 additions and 17 deletions
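As a rough sketch of what the change amounts to (the names mirror the diffs below; the surrounding class code is omitted): the per-sample quantization levels become a plain list of ints, since downstream code only ever compares and iterates over them.

import random

# minimal sketch of the change, assuming the same names as the diffs below
# (batch_size, n_resp_levels, causal); not the full AR_NAR.forward
batch_size = 4
n_resp_levels = 8
causal = True

# before: a tensor, even though nothing downstream needed tensor semantics
# quant_levels = torch.randint(0 if causal else 1, n_resp_levels, (batch_size,))

# after: a plain list of ints; `l == 0` checks and zip() behave the same
# note: random.randint is inclusive of both endpoints, while torch.randint
# excludes the high end
quant_levels = [ random.randint(0 if causal else 1, n_resp_levels) for _ in range(batch_size) ]

assert all(isinstance(l, int) for l in quant_levels)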


@@ -142,9 +142,11 @@ class AR_NAR(Base):
                         index = i
                     return int(index)
-                quant_levels = torch.Tensor([ generate(0 if self.causal else 1, self.n_resp_levels) for _ in range(batch_size) ]).to(dtype=torch.int16)
+                #quant_levels = torch.Tensor([ generate(0 if self.causal else 1, self.n_resp_levels) for _ in range(batch_size) ]).to(dtype=torch.int16)
+                quant_levels = [ generate(0 if self.causal else 1, self.n_resp_levels) for _ in range(batch_size) ]
             else:
-                quant_levels = torch.randint(0 if self.causal else 1, self.n_resp_levels, (batch_size,)) # randomly select a target RVQ-bin level (0 being AR, 1+ being NAR)
+                quant_levels = [ random.randint(0 if self.causal else 1, self.n_resp_levels) for _ in range(batch_size) ] # randomly select a target RVQ-bin level (0 being AR, 1+ being NAR)
             resps_list = [r[..., 0] if l == 0 else r[..., :l+1] for r, l in zip(resps_list, quant_levels)] # r if l == 0 is technically correct since only r[:, 0] is passed through the embedding, but this should save some VRAM
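For reference, a toy run of the resps_list slicing above; the (seq_len, n_bins) shapes are assumptions for the example:

import torch

# toy illustration of the resps_list slicing above
resps_list = [ torch.zeros(10, 8, dtype=torch.long), torch.zeros(12, 8, dtype=torch.long) ]
quant_levels = [ 0, 3 ]

# level 0 keeps only RVQ bin 0 as a 1-D sequence; level l keeps bins 0..l
resps_list = [ r[..., 0] if l == 0 else r[..., :l+1] for r, l in zip(resps_list, quant_levels) ]
print([ tuple(r.shape) for r in resps_list ])  # [(10,), (12, 4)]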
@@ -251,7 +253,7 @@ class AR_NAR(Base):
                 lang_list=lang_list,
                 tone_list=tone_list,
-                quant_levels=torch.Tensor( [ 0 for _ in range( max( batch_size, sampling_beam_width ) ) ] ).to( device=device, dtype=torch.int32 ),
+                quant_levels=[ 0 for _ in range( max( batch_size, sampling_beam_width ) ) ]
             )
             if recurrent_state is not None:


@@ -73,14 +73,14 @@ class MultiEmbedding(nn.Module):
     # to-do: select quant level from given quant_levels tensor if given (i.e. through the resp_emb)
     # I imagine this is an oversight in the NAR.
-    def forward(self, x_list: list[Tensor], quant_levels: Tensor | None = None) -> list[Tensor]:
+    def forward(self, x_list: list[Tensor], quant_level: int | list[int] | Tensor | None = None) -> list[Tensor]:
         if len(x_list) == 0:
             return []
         # this "strategy" will reserve the weight[0] for te AR and weight[1:] for the NAR
         # the NAR cannot share RVQ-bin level 0 with the AR for the resp_emb
         if self.monolithic:
-            w = self.weight[:1] if quant_levels is None else self.weight[1:]
+            w = self.weight[:1] if quant_level is None or quant_level == 0 else self.weight[1:]
         else:
             w = self.weight
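In other words, under the monolithic setup weight[0] serves the AR (level 0, or no level given) and weight[1:] serves the NAR. A simplified, standalone sketch of just that selection (how MultiEmbedding then applies the selected weights to the token codes is elided here):

import torch
import torch.nn as nn

# assumed sizes, for illustration only
n_resp_levels, n_tokens, token_dim = 4, 1024, 512
weight = nn.Parameter(torch.randn(n_resp_levels, n_tokens, token_dim))

def select_weight(quant_level: int | None) -> torch.Tensor:
    # weight[:1] serves the AR (level 0 or no level); weight[1:] serves the NAR
    return weight[:1] if quant_level is None or quant_level == 0 else weight[1:]

assert select_weight(0).shape[0] == 1
assert select_weight(2).shape[0] == n_resp_levels - 1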
@@ -115,12 +115,12 @@ class AudioEmbedding_Old(nn.Module):
         # weight influencer for the influence for each level (desu this should be really useless because the weights in the embedding themselves should factor this)
         self.weight = nn.ParameterList([nn.Parameter( torch.Tensor([1]) ) for i in range(levels)]) if levels is not None else None
-    def forward(self, xi: Tensor, quant_levels: Tensor | None = None ) -> Tensor:
+    def forward(self, xi: Tensor, quant_level: int | Tensor | None = None ) -> Tensor:
         # prom
-        if quant_levels is None and xi.shape[-1] > 1:
+        if quant_level is None and xi.shape[-1] > 1:
             x = sum( [ self.embeddings[k]( xi[:, k] ) * (self.weight[k] if self.weight is not None else 1) for k in range(xi.shape[-1]) ] )
         # AR resp
-        elif quant_levels is None or quant_levels == 0:
+        elif quant_level is None or quant_level == 0:
             x = self.embeddings[0]( xi if len(xi.shape) == 1 else xi[:, 0] )
         # NAR resp
         else:
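For the prompt path above, the embedding of every RVQ bin is summed (optionally scaled by the per-level weights). A toy version with assumed sizes:

import torch
import torch.nn as nn

# toy version of the prom path above: embed each RVQ bin and sum
n_bins, n_tokens, dim = 8, 1024, 512
embeddings = nn.ModuleList([ nn.Embedding(n_tokens, dim) for _ in range(n_bins) ])
xi = torch.randint(0, n_tokens, (10, n_bins))  # (seq_len, n_bins), an assumed shape

x = sum( embeddings[k]( xi[:, k] ) for k in range(xi.shape[-1]) )
assert x.shape == (10, dim)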
@@ -147,7 +147,7 @@ class AudioEmbedding(nn.Module):
         self.sums = sums
     # maintaining compat is hard
-    def forward(self, xi: Tensor, quant_level: Tensor | None = None ) -> Tensor:
+    def forward(self, xi: Tensor, quant_level: int | Tensor | None = None ) -> Tensor:
         if quant_level is None:
             quant_level = 0 if xi.dim() == 1 else xi.shape[-1] - 1
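So the None default now just means "infer the level from the input's shape". A small standalone check of that fallback:

import torch

def infer_quant_level(xi: torch.Tensor) -> int:
    # mirrors the fallback above: a 1-D sequence is treated as level 0,
    # otherwise the last dim indexes the highest RVQ bin present
    return 0 if xi.dim() == 1 else xi.shape[-1] - 1

assert infer_quant_level(torch.zeros(10, dtype=torch.long)) == 0
assert infer_quant_level(torch.zeros(10, 4, dtype=torch.long)) == 3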
@@ -624,7 +624,7 @@ class Base(nn.Module):
         lang_list: list[Tensor] | None = None,
         tone_list: list[Tensor] | None = None,
-        quant_levels: Tensor | None = None
+        quant_levels: int | list[int] | Tensor | None = None
     ):
         device = text_list[0].device
         batch_size = len(text_list)
@@ -649,7 +649,7 @@ class Base(nn.Module):
     def inputs_to_embeddings(
         self,
         inputs: list,
-        quant_levels: Tensor | None = None
+        quant_levels: int | list[int] | Tensor | None = None
     ):
         x_list = []
         for batch_index, batch_input in enumerate(inputs):
@@ -685,7 +685,7 @@ class Base(nn.Module):
         inputs: list,
         logits,
-        quant_levels: Tensor | None = None,
+        quant_levels: int | list[int] | Tensor | None = None,
     ):
         # old, "naive" way, no loss factoring
         if not self.config.loss_factors:
@@ -776,7 +776,7 @@ class Base(nn.Module):
             # for the AR, shift sequence so that it predicts the next token
             # (the NAR predicts the next token in place, so it's not necessary to do any modifications for it)
-            if quant_level is None or quant_level == 0:
+            if quant_level == 0:
                 l = self.causal_size
                 logit = logit[..., :-l, :]
                 input = input[..., l:] # shift sequence to the right by one (or causal chunk size)
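The AR shift above, on a toy logit/target pair (the sizes and causal_size = 1 are assumptions for the example):

import torch

# toy illustration of the causal shift above, assuming causal_size = 1:
# the prediction at the last position has no target, and the first token
# is never predicted, so both ends get trimmed
causal_size = 1
logit = torch.randn(10, 1024)           # (seq_len, vocab)
target = torch.randint(0, 1024, (10,))  # (seq_len,)

logit = logit[..., :-causal_size, :]
target = target[..., causal_size:]
assert logit.shape[0] == target.shape[0] == 9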
@@ -816,7 +816,7 @@ class Base(nn.Module):
         self,
         inputs: list,
-        quant_levels: Tensor | None = None,
+        quant_levels: int | list[int] | Tensor | None = None,
         state: dict | list | None = None,
     ):
@@ -874,7 +874,7 @@ class Base(nn.Module):
         self,
         logits: list[Tensor],
         resps_list: list[Tensor],
-        quant_levels: Tensor | None = None,
+        quant_levels: int | list[int] | Tensor | None = None,
         temperature: float = 1.0,
         min_temperature: float = -1.0,


@@ -29,10 +29,10 @@ def train_feeder(engine, batch):
     if engine.hyper_config.experimental:
         batch_size = len(batch["text"])
         if cfg.model.interleave:
-            quant_levels = None
+            quant_levels = 0
             resps_list = [ resp for resp in batch["resps"] ]
         else:
-            quant_levels = torch.randint(0 if "ar" in cfg.model.capabilities else 1, cfg.model.max_levels, (batch_size,))
+            quant_levels = [ random.randint( 0 if "ar" in cfg.model.capabilities else 1, cfg.model.max_levels) for _ in range(batch_size) ]
             resps_list = [ [] if l == 0 else resp for l, resp in zip(quant_levels, batch["resps"]) ]
         input_ids, attention_mask = fold_inputs(
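A compact, mocked sketch of the experimental train_feeder path above (the cfg.model fields and batch contents are stand-ins, not the real config):

import random
import torch

# mocked stand-ins for cfg.model and batch, just to exercise the selection
# logic above; note random.randint includes max_levels itself, where the
# old torch.randint excluded its high end
capabilities = ["ar", "nar"]
max_levels = 8
batch = { "text": [0, 1, 2], "resps": [ torch.zeros(10, 8), torch.zeros(12, 8), torch.zeros(7, 8) ] }

batch_size = len(batch["text"])
quant_levels = [ random.randint(0 if "ar" in capabilities else 1, max_levels) for _ in range(batch_size) ]
# level 0 (the AR) gets no NAR targets, so its resps entry is emptied
resps_list = [ [] if l == 0 else resp for l, resp in zip(quant_levels, batch["resps"]) ]
assert len(resps_list) == batch_size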