some comments
commit a6ae344e5b
parent d07c63b9d8
@@ -94,9 +94,9 @@ class AR_NAR(Base):
 		# is training
 		if n_levels == self.n_resp_levels:
-			quant_levels = torch.randint(0, self.n_resp_levels, (batch_size,))
-			targ_list = [r[..., l] for r, l in zip(resps_list, quant_levels)]
-			resps_list = [r if l == 0 else r[..., :l] for r, l in zip(resps_list, quant_levels)]
+			quant_levels = torch.randint(0, self.n_resp_levels, (batch_size,)) # randomly select a target RVQ-bin level (0 being AR, 1+ being NAR)
+			targ_list = [r[..., l] for r, l in zip(resps_list, quant_levels)] # ensures we only have 1 RVQ-bin (our target)
+			resps_list = [r if l == 0 else r[..., :l] for r, l in zip(resps_list, quant_levels)] # yes I can just do min(1, l)
 			quant_levels.to(device=device)

 			return super().forward(
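For context, the training branch above samples one RVQ codebook level per batch item: level 0 trains the autoregressive (AR) stage and levels 1+ train the non-autoregressive (NAR) stage, the target keeps only the codes at the sampled level, and the input keeps the levels below it (or everything when level 0 is drawn). A minimal standalone sketch of that selection logic, using invented shapes and codebook size rather than the repository's real dataloader output; note that Tensor.to() is not in-place, so the diff's bare quant_levels.to(device=device) has no effect unless its result is reassigned:

import torch

n_resp_levels = 4   # assumed codebook depth, for illustration only
batch_size = 2
# hypothetical responses: (timesteps, n_resp_levels) of RVQ codes per utterance
resps_list = [torch.randint(0, 1024, (75, n_resp_levels)) for _ in range(batch_size)]

# randomly select a target RVQ-bin level per sample (0 being AR, 1+ being NAR)
quant_levels = torch.randint(0, n_resp_levels, (batch_size,))

# target: only the selected RVQ bin; input: all bins for level 0, else only the bins below the target
targ_list = [r[..., l] for r, l in zip(resps_list, quant_levels)]
resps_list = [r if l == 0 else r[..., :l] for r, l in zip(resps_list, quant_levels)]

# .to() returns a new tensor, so reassign if moving devices actually matters
quant_levels = quant_levels.to(device="cpu")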
@@ -375,15 +375,20 @@ class Base(nn.Module):
 			x = x[:, -1, :].unsqueeze(1)

 		if self.arch_type == "transformer":
-			x = self.sin_emb.add_pe(x)
+			# ensures we specify a quant_level for the transformer implementation's AdaLN
+			l = torch.zeros((batch_size,), dtype=torch.int32) if quant_levels is None else quant_levels
+			l = l.to(device)
+			# inject position information
+			x = self.sin_emb.add_pe(x)
+			# pass our inputs through the transformer
 			for block in self.blocks:
 				x = block(x, m, l)

 		elif self.arch_type == "retnet":
 			# pass our inputs through the RetNet
 			x, _ = self.retnet(x, incremental_state=state, token_embeddings=x, features_only=True)

 		# output projection layer with masking
 		x = self.classifier(x) * m

 		# Remove padding
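The transformer branch above now always hands its AdaLN-conditioned blocks a per-sample quantizer level, defaulting to level 0 (the AR level) when quant_levels is not supplied, before adding positional information and projecting through the masked classifier. A condensed, runnable sketch of that flow with toy stand-ins (plain Linear layers in place of the real blocks, the positional embedding folded away, invented dimensions), not the repository's actual modules:

import torch
import torch.nn as nn

batch_size, seq_len, d_model, n_tokens = 2, 8, 16, 1024   # invented dimensions
device = "cpu"

x = torch.randn(batch_size, seq_len, d_model)   # summed input embeddings
m = torch.ones(batch_size, seq_len, 1)          # padding mask (1 = keep, 0 = pad)
quant_levels = None                             # None during a plain AR pass

# AdaLN needs a level index for every sample; default to level 0 (AR) when absent
l = torch.zeros((batch_size,), dtype=torch.int32) if quant_levels is None else quant_levels
l = l.to(device)

# stand-in blocks; the real transformer blocks are called as block(x, m, l)
blocks = nn.ModuleList([nn.Linear(d_model, d_model) for _ in range(2)])
for block in blocks:
    x = block(x)

# output projection with masking: padded positions produce zeroed logits
classifier = nn.Linear(d_model, n_tokens)
logits = classifier(x) * m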
@@ -399,10 +404,10 @@ class Base(nn.Module):

 			# process each batch
 			for i in range(len(text_prom_list)):
-				# for the AR, shift the text/input prompt into the future by 1, and ignore the rolled back text token
+				# for the AR, shift the text/input prompt and target prompt into the future by 1, and ignore the rolled back text token
 				if quant_levels is None or quant_levels[i] == 0:
 					text_prom_list[i] = text_prom_list[i].roll(-1, dims=0)
-					targ_list[i] = targ_list[i].clone().roll(-1, dims=0)
+					targ_list[i] = targ_list[i].clone().roll(-1, dims=0) # clone ensures it's not an aliased copy/view of resps

 					text_prom_list[i][-1] = self.ignore_index
 					targ_list[i][-1] = self.stop_token
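The hunk above handles teacher forcing for the AR level: the text prompt and the target are rolled one step into the future so each position predicts the next token, the clone keeps the target from being an aliased view of the response list, and the wrapped-around last positions are overwritten with the ignore index (so the loss can skip them) and the stop token. A small standalone sketch of that roll-and-overwrite step on toy tensors, with illustrative ignore_index and stop_token values rather than the model's real ones:

import torch

ignore_index = -100   # illustrative: position excluded from the loss
stop_token = 1024     # illustrative: one id past the codebook

text = torch.tensor([5, 6, 7, 8])      # toy text prompt
targ = torch.tensor([10, 11, 12, 13])  # toy AR target (RVQ level 0)

# shift both one step into the future for next-token prediction;
# clone, as the diff's comment notes, ensures the target is its own storage
# rather than an aliased view of the responses before it is modified in place
text = text.roll(-1, dims=0)
targ = targ.clone().roll(-1, dims=0)

# the rolled-back final elements are stale: drop the text position from the loss
# and make the target end on the stop token
text[-1] = ignore_index
targ[-1] = stop_token

print(text)  # tensor([   6,    7,    8, -100])
print(targ)  # tensor([  11,   12,   13, 1024])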