some comments

mrq 2023-09-12 16:04:45 -05:00
parent d07c63b9d8
commit a6ae344e5b
2 changed files with 12 additions and 7 deletions


@@ -94,9 +94,9 @@ class AR_NAR(Base):
# is training
if n_levels == self.n_resp_levels:
quant_levels = torch.randint(0, self.n_resp_levels, (batch_size,))
targ_list = [r[..., l] for r, l in zip(resps_list, quant_levels)]
resps_list = [r if l == 0 else r[..., :l] for r, l in zip(resps_list, quant_levels)]
quant_levels = torch.randint(0, self.n_resp_levels, (batch_size,)) # randomly select a target RVQ-bin level (0 being AR, 1+ being NAR)
targ_list = [r[..., l] for r, l in zip(resps_list, quant_levels)] # ensures we only have 1 RVQ-bin (our target)
resps_list = [r if l == 0 else r[..., :l] for r, l in zip(resps_list, quant_levels)] # yes I can just do min(1, l)
quant_levels = quant_levels.to(device=device) # assign back: Tensor.to() is not in-place
return super().forward(

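As a reference for the comments added above, a minimal standalone sketch of the level-sampling logic in this hunk; the codebook size, sequence lengths, and the (t, n_resp_levels) resps shape are assumed here for illustration.

import torch

n_resp_levels, batch_size = 4, 2
# assumed: each resps tensor is (t, n_resp_levels) with codes in [0, 1024)
resps_list = [torch.randint(0, 1024, (t, n_resp_levels)) for t in (5, 7)]

# randomly pick a target RVQ-bin level per sample (0 = AR, 1+ = NAR)
quant_levels = torch.randint(0, n_resp_levels, (batch_size,))
# keep only the single target RVQ bin per sample
targ_list = [r[..., l] for r, l in zip(resps_list, quant_levels)]
# level 0 (AR) keeps resps as-is; NAR levels keep only the bins below the target level
resps_list = [r if l == 0 else r[..., :l] for r, l in zip(resps_list, quant_levels)]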

@@ -375,15 +375,20 @@ class Base(nn.Module):
x = x[:, -1, :].unsqueeze(1)
if self.arch_type == "transformer":
x = self.sin_emb.add_pe(x)
# ensures we specify a quant_level for the transformer implementation's AdaLN
l = torch.zeros((batch_size,), dtype=torch.int32) if quant_levels is None else quant_levels
l = l.to(device)
# inject position information
x = self.sin_emb.add_pe(x)
# pass our inputs through the transformer
for block in self.blocks:
x = block(x, m, l)
elif self.arch_type == "retnet":
# pass our inputs through the RetNet
x, _ = self.retnet(x, incremental_state=state, token_embeddings=x, features_only=True)
# output projection layer with masking
x = self.classifier(x) * m
# Remove padding
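A minimal sketch of the AdaLN quant-level fallback and the masked output projection shown in this hunk; batch_size, d_model, n_tokens, and the mask values are assumed for illustration.

import torch
import torch.nn as nn

batch_size, seq_len, d_model, n_tokens = 2, 6, 16, 1024
device = "cpu"
quant_levels = None  # e.g. no explicit RVQ level was passed in

# default to level 0 per sample so the transformer's AdaLN always receives a quant level
l = torch.zeros((batch_size,), dtype=torch.int32) if quant_levels is None else quant_levels
l = l.to(device)

classifier = nn.Linear(d_model, n_tokens)
x = torch.randn(batch_size, seq_len, d_model)                       # transformer / RetNet output
m = torch.tensor([[1.0] * 6, [1.0] * 4 + [0.0] * 2]).unsqueeze(-1)  # (b, t, 1) non-padding mask
logits = classifier(x) * m                                          # zero out logits at padded positions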
@@ -399,10 +404,10 @@ class Base(nn.Module):
# process each batch
for i in range(len(text_prom_list)):
# for the AR, shift the text/input prompt into the future by 1, and ignore the rolled back text token
# for the AR, shift the text/input prompt and target prompt into the future by 1, and ignore the rolled back text token
if quant_levels is None or quant_levels[i] == 0:
text_prom_list[i] = text_prom_list[i].roll(-1, dims=0)
targ_list[i] = targ_list[i].clone().roll(-1, dims=0)
targ_list[i] = targ_list[i].clone().roll(-1, dims=0) # clone ensures it's not an aliased copy/view of resps
text_prom_list[i][-1] = self.ignore_index
targ_list[i][-1] = self.stop_token
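A minimal sketch of the AR shift applied above; the ignore_index and stop_token values here are assumed sentinels, not necessarily the ones the repo uses.

import torch

ignore_index, stop_token = -100, 1024   # assumed sentinel values
text = torch.tensor([5, 6, 7, 8])
targ = torch.tensor([10, 11, 12, 13])

text = text.roll(-1, dims=0)            # shift the input prompt one step into the future
targ = targ.clone().roll(-1, dims=0)    # clone, as in the hunk above, so the target is not a view of resps
text[-1] = ignore_index                 # drop the wrapped-around text token from the loss
targ[-1] = stop_token                   # the AR's final target becomes the stop token
# text -> tensor([6, 7, 8, -100]); targ -> tensor([11, 12, 13, 1024])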