un-tensor'd the quant_level marker since it doesn't need to be one (I forgot why I had it as a tensor, but nothing seems to need it as one that didn't already convert it itself)

mrq 2024-06-07 20:46:22 -05:00
parent b0158a61d5
commit 7d6fff24f9
3 changed files with 19 additions and 17 deletions
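The gist of the change: quant_levels is only ever read back per sample for Python-level branching and slicing, so a plain list of ints does the job without the tensor round-trip. A minimal sketch of the before/after, assuming stand-in values for batch_size, causal and n_resp_levels rather than the module's real attributes:

import random
import torch

batch_size = 4
causal = True
n_resp_levels = 8  # illustrative RVQ depth, not the repo's actual config

# before: a tensor that downstream code only ever read back as Python ints
quant_levels_tensor = torch.randint(0 if causal else 1, n_resp_levels, (batch_size,))

# after: a plain list; note random.randint's upper bound is inclusive,
# unlike torch.randint's, so the sampled range differs by one at the top
quant_levels = [ random.randint(0 if causal else 1, n_resp_levels) for _ in range(batch_size) ]

# the consumer-side pattern it feeds: per-sample branching/slicing like this
resps_list = [ torch.zeros(75, n_resp_levels, dtype=torch.int64) for _ in range(batch_size) ]
resps_list = [ r[..., 0] if l == 0 else r[..., :l+1] for r, l in zip(resps_list, quant_levels) ]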

View File

@@ -142,9 +142,11 @@ class AR_NAR(Base):
index = i
return int(index)
quant_levels = torch.Tensor([ generate(0 if self.causal else 1, self.n_resp_levels) for _ in range(batch_size) ]).to(dtype=torch.int16)
#quant_levels = torch.Tensor([ generate(0 if self.causal else 1, self.n_resp_levels) for _ in range(batch_size) ]).to(dtype=torch.int16)
quant_levels = [ generate(0 if self.causal else 1, self.n_resp_levels) for _ in range(batch_size) ]
else:
quant_levels = torch.randint(0 if self.causal else 1, self.n_resp_levels, (batch_size,)) # randomly select a target RVQ-bin level (0 being AR, 1+ being NAR)
#quant_levels = torch.randint(0 if self.causal else 1, self.n_resp_levels, (batch_size,)) # randomly select a target RVQ-bin level (0 being AR, 1+ being NAR)
quant_levels = [ random.randint(0 if self.causal else 1, self.n_resp_levels) for _ in range(batch_size) ] # randomly select a target RVQ-bin level (0 being AR, 1+ being NAR)
resps_list = [r[..., 0] if l == 0 else r[..., :l+1] for r, l in zip(resps_list, quant_levels)] # r if l == 0 is technically correct since only r[:, 0] is passed through the embedding, but this should save some VRAM
@@ -251,7 +253,7 @@ class AR_NAR(Base):
lang_list=lang_list,
tone_list=tone_list,
quant_levels=torch.Tensor( [ 0 for _ in range( max( batch_size, sampling_beam_width ) ) ] ).to( device=device, dtype=torch.int32 ),
quant_levels=[ 0 for _ in range( max( batch_size, sampling_beam_width ) ) ]
)
if recurrent_state is not None:
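The inference path in this hunk gets the same treatment: the AR pass only needs to mark every sequence (including beam-search candidates) as level 0, so a list of zeros stands in for the old int32 tensor, and since it never leaves Python the .to(device=..., dtype=torch.int32) call drops out with it. Roughly, with batch_size and sampling_beam_width as already defined in the surrounding method:

# one level-0 marker per sequence; max() covers beam search widening the effective batch
quant_levels = [ 0 for _ in range( max( batch_size, sampling_beam_width ) ) ]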

View File

@@ -73,14 +73,14 @@ class MultiEmbedding(nn.Module):
# to-do: select quant level from given quant_levels tensor if given (i.e. through the resp_emb)
# I imagine this is an oversight in the NAR.
def forward(self, x_list: list[Tensor], quant_levels: Tensor | None = None) -> list[Tensor]:
def forward(self, x_list: list[Tensor], quant_level: int | list[int] | Tensor | None = None) -> list[Tensor]:
if len(x_list) == 0:
return []
# this "strategy" will reserve the weight[0] for te AR and weight[1:] for the NAR
# the NAR cannot share RVQ-bin level 0 with the AR for the resp_emb
if self.monolithic:
w = self.weight[:1] if quant_levels is None else self.weight[1:]
w = self.weight[:1] if quant_level is None or quant_level == 0 else self.weight[1:]
else:
w = self.weight
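The monolithic branch above reserves weight[0] for the AR and weight[1:] for the NAR; the new check treats an explicit level of 0 the same as None now that callers may pass a plain int. A standalone sketch of that selection rule (not the module's actual method), assuming quant_level is the int this commit converts it to:

def select_monolithic_weight(weight, quant_level):
    # level 0, or an unspecified level, routes to the AR's reserved slice
    if quant_level is None or quant_level == 0:
        return weight[:1]
    # anything else routes to the NAR slices (RVQ-bin levels 1 and up)
    return weight[1:]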
@@ -115,12 +115,12 @@ class AudioEmbedding_Old(nn.Module):
# weight influencer for the influence for each level (desu this should be really useless because the weights in the embedding themselves should factor this)
self.weight = nn.ParameterList([nn.Parameter( torch.Tensor([1]) ) for i in range(levels)]) if levels is not None else None
def forward(self, xi: Tensor, quant_levels: Tensor | None = None ) -> Tensor:
def forward(self, xi: Tensor, quant_level: int | Tensor | None = None ) -> Tensor:
# prom
if quant_levels is None and xi.shape[-1] > 1:
if quant_level is None and xi.shape[-1] > 1:
x = sum( [ self.embeddings[k]( xi[:, k] ) * (self.weight[k] if self.weight is not None else 1) for k in range(xi.shape[-1]) ] )
# AR resp
elif quant_levels is None or quant_levels == 0:
elif quant_level is None or quant_level == 0:
x = self.embeddings[0]( xi if len(xi.shape) == 1 else xi[:, 0] )
# NAR resp
else:
@@ -147,7 +147,7 @@ class AudioEmbedding(nn.Module):
self.sums = sums
# maintaining compat is hard
def forward(self, xi: Tensor, quant_level: Tensor | None = None ) -> Tensor:
def forward(self, xi: Tensor, quant_level: int | Tensor | None = None ) -> Tensor:
if quant_level is None:
quant_level = 0 if xi.dim() == 1 else xi.shape[-1] - 1
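The fallback above infers the level from the input's shape when the caller passes nothing: a 1-D sequence of codes is treated as the AR (level 0), while a 2-D [seq_len, levels] input maps to its deepest RVQ bin. A quick illustration with made-up shapes:

import torch

xi_ar  = torch.randint(0, 1024, (75,))    # 1-D AR response        -> quant_level 0
xi_nar = torch.randint(0, 1024, (75, 4))  # codes up through bin 3 -> quant_level 3

for xi in (xi_ar, xi_nar):
    quant_level = 0 if xi.dim() == 1 else xi.shape[-1] - 1
    print(tuple(xi.shape), quant_level)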
@@ -624,7 +624,7 @@ class Base(nn.Module):
lang_list: list[Tensor] | None = None,
tone_list: list[Tensor] | None = None,
quant_levels: Tensor | None = None
quant_levels: int | list[int] | Tensor | None = None
):
device = text_list[0].device
batch_size = len(text_list)
@@ -649,7 +649,7 @@ class Base(nn.Module):
def inputs_to_embeddings(
self,
inputs: list,
quant_levels: Tensor | None = None
quant_levels: int | list[int] | Tensor | None = None
):
x_list = []
for batch_index, batch_input in enumerate(inputs):
@@ -685,7 +685,7 @@ class Base(nn.Module):
inputs: list,
logits,
quant_levels: Tensor | None = None,
quant_levels: int | list[int] | Tensor | None = None,
):
# old, "naive" way, no loss factoring
if not self.config.loss_factors:
@@ -776,7 +776,7 @@ class Base(nn.Module):
# for the AR, shift sequence so that it predicts the next token
# (the NAR predicts the next token in place, so it's not necessary to do any modifications for it)
if quant_level is None or quant_level == 0:
if quant_level == 0:
l = self.causal_size
logit = logit[..., :-l, :]
input = input[..., l:] # shift sequence to the right by one (or causal chunk size)
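The AR branch above performs the usual next-token alignment: drop the last causal_size positions from the logits and the first causal_size tokens from the targets, so the prediction at position t is scored against the token causal_size steps later. A toy illustration, with a causal_size of 1 assumed rather than taken from the config:

import torch

causal_size = 1                       # assumed chunk size for the toy example
logit = torch.randn(6, 1024)          # [seq_len, vocab] AR logits
input = torch.randint(0, 1024, (6,))  # [seq_len] target token ids

logit = logit[..., :-causal_size, :]  # keep predictions at positions 0..4
input = input[..., causal_size:]      # score them against tokens 1..5
assert logit.shape[0] == input.shape[0]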
@@ -816,7 +816,7 @@ class Base(nn.Module):
self,
inputs: list,
quant_levels: Tensor | None = None,
quant_levels: int | list[int] | Tensor | None = None,
state: dict | list | None = None,
):
@@ -874,7 +874,7 @@ class Base(nn.Module):
self,
logits: list[Tensor],
resps_list: list[Tensor],
quant_levels: Tensor | None = None,
quant_levels: int | list[int] | Tensor | None = None,
temperature: float = 1.0,
min_temperature: float = -1.0,

View File

@@ -29,10 +29,10 @@ def train_feeder(engine, batch):
if engine.hyper_config.experimental:
batch_size = len(batch["text"])
if cfg.model.interleave:
quant_levels = None
quant_levels = 0
resps_list = [ resp for resp in batch["resps"] ]
else:
quant_levels = torch.randint(0 if "ar" in cfg.model.capabilities else 1, cfg.model.max_levels, (batch_size,))
quant_levels = [ random.randint( 0 if "ar" in cfg.model.capabilities else 1, cfg.model.max_levels) for _ in range(batch_size) ]
resps_list = [ [] if l == 0 else resp for l, resp in zip(quant_levels, batch["resps"]) ]
input_ids, attention_mask = fold_inputs(
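As in the model code, this experimental training path now builds quant_levels as either a bare 0 (interleaved mode) or a per-sample list, then blanks the resps for any sample that landed on the AR level before folding the inputs. A rough standalone sketch of that selection, with a fake batch and stand-in values in place of cfg.model:

import random
import torch

max_levels = 8                 # stand-in for cfg.model.max_levels
capabilities = ["ar", "nar"]   # stand-in for cfg.model.capabilities
batch = { "resps": [ torch.randint(0, 1024, (75, max_levels)) for _ in range(4) ] }
batch_size = len(batch["resps"])

quant_levels = [ random.randint(0 if "ar" in capabilities else 1, max_levels) for _ in range(batch_size) ]
# samples that drew level 0 are handled as AR and get an empty resp here
resps_list = [ [] if l == 0 else resp for l, resp in zip(quant_levels, batch["resps"]) ]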