mucked around with the loss calculation, this seems better?
This commit is contained in:
parent
fb467b19ba
commit
a539f6889f
|
@ -118,17 +118,25 @@ class AR_NAR(Base):
|
||||||
|
|
||||||
# is training
|
# is training
|
||||||
if n_levels == self.n_resp_levels:
|
if n_levels == self.n_resp_levels:
|
||||||
|
# might be better to have this decided on the dataloader level
|
||||||
if cfg.models.ar_nar.p_ar_level == "auto" or cfg.models.ar_nar.p_ar_level is None:
|
if cfg.models.ar_nar.p_ar_level == "auto" or cfg.models.ar_nar.p_ar_level is None:
|
||||||
quant_levels = torch.randint(0, self.n_resp_levels, (batch_size,)) # randomly select a target RVQ-bin level (0 being AR, 1+ being NAR)
|
quant_levels = torch.randint(0, self.n_resp_levels, (batch_size,)) # randomly select a target RVQ-bin level (0 being AR, 1+ being NAR)
|
||||||
else:
|
else:
|
||||||
quant_levels = torch.Tensor([ [ 0 if random.random() < cfg.models.ar_nar.p_ar_level else random.randint(1, self.n_resp_levels) ] for _ in range(batch_size) ])
|
quant_levels = torch.Tensor([ [ 0 if random.random() < cfg.models.ar_nar.p_ar_level else random.randint(1, self.n_resp_levels) ] for _ in range(batch_size) ])
|
||||||
|
|
||||||
targ_list = [r[..., l] for r, l in zip(resps_list, quant_levels)] # ensures we only have 1 RVQ-bin (our target)
|
targ_list = [r[..., l] for r, l in zip(resps_list, quant_levels)] # ensures we only have 1 RVQ-bin (our target)
|
||||||
resps_list = [r if l == 0 else r[..., :l] for r, l in zip(resps_list, quant_levels)] # yes I can just do min(1, l)
|
resps_list = [r if l == 0 else r[..., :l] for r, l in zip(resps_list, quant_levels)] # r[..., 0] is technically correct, but only r[:, 0] gets passed through the embedding
|
||||||
|
|
||||||
if cfg.experimental:
|
if cfg.experimental:
|
||||||
proms_list = [ r if l == 0 else trim(r, 75 * 3) for r, l in zip(proms_list, quant_levels) ] # trim input prompt to 3 seconds
|
proms_list = [ r if l == 0 else trim(r, 75 * 3) for r, l in zip(proms_list, quant_levels) ] # trim input prompt to 3 seconds
|
||||||
|
# append stop tokens for AR
|
||||||
|
for i in range(batch_size):
|
||||||
|
if quant_levels[i] > 0:
|
||||||
|
continue
|
||||||
|
|
||||||
|
resps_list[i] = torch.cat([resps_list[i], torch.Tensor([[self.stop_token] * n_levels]).to(device=device, dtype=torch.int16) ])
|
||||||
|
targ_list[i] = torch.cat([targ_list[i], torch.Tensor([self.stop_token]).to(device=device, dtype=torch.int16) ])
|
||||||
|
|
||||||
return super().forward(
|
return super().forward(
|
||||||
text_list=text_list,
|
text_list=text_list,
|
||||||
proms_list=proms_list,
|
proms_list=proms_list,
|
||||||
|
@ -294,6 +302,8 @@ def example_usage():
|
||||||
|
|
||||||
qnt = torch.load("data/qnt.pt")[0].t()[:, :cfg.models.prom_levels].to(device)
|
qnt = torch.load("data/qnt.pt")[0].t()[:, :cfg.models.prom_levels].to(device)
|
||||||
|
|
||||||
|
cfg.hyperparameters.gradient_accumulation_steps = 1
|
||||||
|
|
||||||
text_list = [
|
text_list = [
|
||||||
tokenize("ˈ a ɪ w ɪ l nˌ ɑː t ˈ æ s k ɐ sˈ ɛ k ə n d tˈ a ɪ m").to(device),
|
tokenize("ˈ a ɪ w ɪ l nˌ ɑː t ˈ æ s k ɐ sˈ ɛ k ə n d tˈ a ɪ m").to(device),
|
||||||
]
|
]
|
||||||
|
@ -323,10 +333,9 @@ def example_usage():
|
||||||
"""
|
"""
|
||||||
|
|
||||||
model = AR_NAR(**kwargs).to(device)
|
model = AR_NAR(**kwargs).to(device)
|
||||||
#steps = 500
|
steps = 250
|
||||||
#optimizer = ml.Prodigy(model.parameters(), lr=1.0)
|
optimizer = ml.Prodigy(model.parameters(), lr=1.0)
|
||||||
steps = 1000
|
#optimizer = ml.AdamW(model.parameters(), lr=1.0e-4)
|
||||||
optimizer = ml.AdamW(model.parameters(), lr=1.0e-4)
|
|
||||||
engine = Engine(model=model, optimizer=optimizer)
|
engine = Engine(model=model, optimizer=optimizer)
|
||||||
|
|
||||||
torch.save( {
|
torch.save( {
|
||||||
|
|
|
@ -351,30 +351,24 @@ class Base(nn.Module):
|
||||||
|
|
||||||
# compute loss if the target is given
|
# compute loss if the target is given
|
||||||
if targ_list is not None:
|
if targ_list is not None:
|
||||||
ignore_sep = torch.tensor(self.ignore_index, device=device)
|
|
||||||
# create a tensor sequence with one RVQ-bin of the input prompt, but with `ignore_index`, as the prompt is not neeeded for computing the loss against
|
target_list = self._samplewise_merge_tensors(
|
||||||
prom_list = [ torch.full_like(t[..., 0], self.ignore_index) for t in proms_list ]
|
|
||||||
# remake input sequence
|
|
||||||
text_prom_list = self._samplewise_merge_tensors(
|
|
||||||
text_list,
|
text_list,
|
||||||
lang_list,
|
lang_list,
|
||||||
prom_list,
|
[ torch.full_like(t[..., 0], self.ignore_index) for t in proms_list ], # create a tensor sequence with one RVQ-bin of the input prompt, but with `ignore_index`, as the prompt is not neeeded for computing the loss against
|
||||||
sep=ignore_sep
|
targ_list,
|
||||||
|
sep=torch.tensor(self.ignore_index, device=device)
|
||||||
)
|
)
|
||||||
|
|
||||||
# process each batch
|
# modify only for the AR so it can properly behave like a transformer
|
||||||
for i in range(len(text_prom_list)):
|
for i in range(len(target_list)):
|
||||||
# for the AR and NAR, shift the text/input prompt into the future by 1, and ignore the rolled back token
|
if quant_levels is not None and quant_levels[i] > 0:
|
||||||
text_prom_list[i] = text_prom_list[i].roll(-1, dims=0)
|
continue
|
||||||
text_prom_list[i][-1] = self.ignore_index
|
|
||||||
|
|
||||||
# for the AR, shift the target response into the future by 1, and ignore the rolled back text token
|
logits[i] = logits[i][..., :-1, :] # shift the target so that token n...
|
||||||
if quant_levels is None or quant_levels[i] == 0:
|
target_list[i] = target_list[i][..., 1:] # predicts token n + 1
|
||||||
targ_list[i] = targ_list[i].clone().roll(-1, dims=0) # clone ensures it's not an aliased copy/view of resps
|
|
||||||
targ_list[i][-1] = self.stop_token
|
|
||||||
|
|
||||||
# create the new target sequence to compute the loss against
|
target = torch.cat( target_list )
|
||||||
target = torch.cat( self._samplewise_merge_tensors( text_prom_list, targ_list, sep=ignore_sep ) )
|
|
||||||
inputs = torch.cat( logits )
|
inputs = torch.cat( logits )
|
||||||
|
|
||||||
self.loss = dict(
|
self.loss = dict(
|
||||||
|
|
Loading…
Reference in New Issue
Block a user