un-tensor'd quant_level marker since it doesn't need to be one (I forgot why I had it as one but nothing seems to need it as a tensor that didn't already make it one)
This commit is contained in:
parent
b0158a61d5
commit
7d6fff24f9
|
@ -142,9 +142,11 @@ class AR_NAR(Base):
|
|||
index = i
|
||||
return int(index)
|
||||
|
||||
quant_levels = torch.Tensor([ generate(0 if self.causal else 1, self.n_resp_levels) for _ in range(batch_size) ]).to(dtype=torch.int16)
|
||||
#quant_levels = torch.Tensor([ generate(0 if self.causal else 1, self.n_resp_levels) for _ in range(batch_size) ]).to(dtype=torch.int16)
|
||||
quant_levels = [ generate(0 if self.causal else 1, self.n_resp_levels) for _ in range(batch_size) ]
|
||||
else:
|
||||
quant_levels = torch.randint(0 if self.causal else 1, self.n_resp_levels, (batch_size,)) # randomly select a target RVQ-bin level (0 being AR, 1+ being NAR)
|
||||
#quant_levels = torch.randint(0 if self.causal else 1, self.n_resp_levels, (batch_size,)) # randomly select a target RVQ-bin level (0 being AR, 1+ being NAR)
|
||||
quant_levels = [ random.randint(0 if self.causal else 1, self.n_resp_levels) for _ in range(batch_size) ] # randomly select a target RVQ-bin level (0 being AR, 1+ being NAR)
|
||||
|
||||
resps_list = [r[..., 0] if l == 0 else r[..., :l+1] for r, l in zip(resps_list, quant_levels)] # r if l == 0 is technically correct since only r[:, 0] is passed through the embedding, but this should save some VRAM
|
||||
|
||||
|
@ -251,7 +253,7 @@ class AR_NAR(Base):
|
|||
lang_list=lang_list,
|
||||
tone_list=tone_list,
|
||||
|
||||
quant_levels=torch.Tensor( [ 0 for _ in range( max( batch_size, sampling_beam_width ) ) ] ).to( device=device, dtype=torch.int32 ),
|
||||
quant_levels=[ 0 for _ in range( max( batch_size, sampling_beam_width ) ) ]
|
||||
)
|
||||
|
||||
if recurrent_state is not None:
|
||||
|
|
|
@ -73,14 +73,14 @@ class MultiEmbedding(nn.Module):
|
|||
|
||||
# to-do: select quant level from given quant_levels tensor if given (i.e. through the resp_emb)
|
||||
# I imagine this is an oversight in the NAR.
|
||||
def forward(self, x_list: list[Tensor], quant_levels: Tensor | None = None) -> list[Tensor]:
|
||||
def forward(self, x_list: list[Tensor], quant_level: int | list[int] | Tensor | None = None) -> list[Tensor]:
|
||||
if len(x_list) == 0:
|
||||
return []
|
||||
|
||||
# this "strategy" will reserve the weight[0] for te AR and weight[1:] for the NAR
|
||||
# the NAR cannot share RVQ-bin level 0 with the AR for the resp_emb
|
||||
if self.monolithic:
|
||||
w = self.weight[:1] if quant_levels is None else self.weight[1:]
|
||||
w = self.weight[:1] if quant_level is None or quant_level == 0 else self.weight[1:]
|
||||
else:
|
||||
w = self.weight
|
||||
|
||||
|
@ -115,12 +115,12 @@ class AudioEmbedding_Old(nn.Module):
|
|||
# weight influencer for the influence for each level (desu this should be really useless because the weights in the embedding themselves should factor this)
|
||||
self.weight = nn.ParameterList([nn.Parameter( torch.Tensor([1]) ) for i in range(levels)]) if levels is not None else None
|
||||
|
||||
def forward(self, xi: Tensor, quant_levels: Tensor | None = None ) -> Tensor:
|
||||
def forward(self, xi: Tensor, quant_level: int | Tensor | None = None ) -> Tensor:
|
||||
# prom
|
||||
if quant_levels is None and xi.shape[-1] > 1:
|
||||
if quant_level is None and xi.shape[-1] > 1:
|
||||
x = sum( [ self.embeddings[k]( xi[:, k] ) * (self.weight[k] if self.weight is not None else 1) for k in range(xi.shape[-1]) ] )
|
||||
# AR resp
|
||||
elif quant_levels is None or quant_levels == 0:
|
||||
elif quant_level is None or quant_level == 0:
|
||||
x = self.embeddings[0]( xi if len(xi.shape) == 1 else xi[:, 0] )
|
||||
# NAR resp
|
||||
else:
|
||||
|
@ -147,7 +147,7 @@ class AudioEmbedding(nn.Module):
|
|||
self.sums = sums
|
||||
|
||||
# maintaining compat is hard
|
||||
def forward(self, xi: Tensor, quant_level: Tensor | None = None ) -> Tensor:
|
||||
def forward(self, xi: Tensor, quant_level: int | Tensor | None = None ) -> Tensor:
|
||||
if quant_level is None:
|
||||
quant_level = 0 if xi.dim() == 1 else xi.shape[-1] - 1
|
||||
|
||||
|
@ -624,7 +624,7 @@ class Base(nn.Module):
|
|||
lang_list: list[Tensor] | None = None,
|
||||
tone_list: list[Tensor] | None = None,
|
||||
|
||||
quant_levels: Tensor | None = None
|
||||
quant_levels: int | list[int] | Tensor | None = None
|
||||
):
|
||||
device = text_list[0].device
|
||||
batch_size = len(text_list)
|
||||
|
@ -649,7 +649,7 @@ class Base(nn.Module):
|
|||
def inputs_to_embeddings(
|
||||
self,
|
||||
inputs: list,
|
||||
quant_levels: Tensor | None = None
|
||||
quant_levels: int | list[int] | Tensor | None = None
|
||||
):
|
||||
x_list = []
|
||||
for batch_index, batch_input in enumerate(inputs):
|
||||
|
@ -685,7 +685,7 @@ class Base(nn.Module):
|
|||
inputs: list,
|
||||
logits,
|
||||
|
||||
quant_levels: Tensor | None = None,
|
||||
quant_levels: int | list[int] | Tensor | None = None,
|
||||
):
|
||||
# old, "naive" way, no loss factoring
|
||||
if not self.config.loss_factors:
|
||||
|
@ -776,7 +776,7 @@ class Base(nn.Module):
|
|||
|
||||
# for the AR, shift sequence so that it predicts the next token
|
||||
# (the NAR predicts the next token in place, so it's not necessary to do any modifications for it)
|
||||
if quant_level is None or quant_level == 0:
|
||||
if quant_level == 0:
|
||||
l = self.causal_size
|
||||
logit = logit[..., :-l, :]
|
||||
input = input[..., l:] # shift sequence to the right by one (or causal chunk size)
|
||||
|
@ -816,7 +816,7 @@ class Base(nn.Module):
|
|||
self,
|
||||
inputs: list,
|
||||
|
||||
quant_levels: Tensor | None = None,
|
||||
quant_levels: int | list[int] | Tensor | None = None,
|
||||
state: dict | list | None = None,
|
||||
):
|
||||
|
||||
|
@ -874,7 +874,7 @@ class Base(nn.Module):
|
|||
self,
|
||||
logits: list[Tensor],
|
||||
resps_list: list[Tensor],
|
||||
quant_levels: Tensor | None = None,
|
||||
quant_levels: int | list[int] | Tensor | None = None,
|
||||
|
||||
temperature: float = 1.0,
|
||||
min_temperature: float = -1.0,
|
||||
|
|
|
@ -29,10 +29,10 @@ def train_feeder(engine, batch):
|
|||
if engine.hyper_config.experimental:
|
||||
batch_size = len(batch["text"])
|
||||
if cfg.model.interleave:
|
||||
quant_levels = None
|
||||
quant_levels = 0
|
||||
resps_list = [ resp for resp in batch["resps"] ]
|
||||
else:
|
||||
quant_levels = torch.randint(0 if "ar" in cfg.model.capabilities else 1, cfg.model.max_levels, (batch_size,))
|
||||
quant_levels = [ random.randint( 0 if "ar" in cfg.model.capabilities else 1, cfg.model.max_levels) for _ in range(batch_size) ]
|
||||
resps_list = [ [] if l == 0 else resp for l, resp in zip(quant_levels, batch["resps"]) ]
|
||||
|
||||
input_ids, attention_mask = fold_inputs(
|
||||
|
|
Loading…
Reference in New Issue
Block a user