what has science done
This commit is contained in:
parent
cebacc2303
commit
02fc67d00c
|
@ -396,11 +396,11 @@ def load_engines(training=True, **model_kwargs):
|
||||||
if cfg.lora is not None:
|
if cfg.lora is not None:
|
||||||
key_name = cfg.lora.full_name
|
key_name = cfg.lora.full_name
|
||||||
|
|
||||||
kwargs['id'] = 'job'
|
kwargs['id'] = f'{key_name}-job'
|
||||||
kwargs['resume'] = 'allow'
|
kwargs['resume'] = 'allow'
|
||||||
if world_size() > 1:
|
if world_size() > 1:
|
||||||
kwargs["group"] = "DDP"
|
kwargs["group"] = "DDP"
|
||||||
kwargs['id'] = f'job-{global_rank()}'
|
kwargs['id'] = f'{key_name}-job-{global_rank()}'
|
||||||
|
|
||||||
|
|
||||||
engine.wandb = wandb.init(project=key_name, **kwargs)
|
engine.wandb = wandb.init(project=key_name, **kwargs)
|
||||||
|
|
|
@ -171,7 +171,7 @@ class AR_NAR(Base):
|
||||||
resps_list[i][t, l] = clamp(token + offset, 1, 1022) # +- 1
|
resps_list[i][t, l] = clamp(token + offset, 1, 1022) # +- 1
|
||||||
|
|
||||||
# only apply stop token for RVQ level 0
|
# only apply stop token for RVQ level 0
|
||||||
if quant_level <= 0 and timesteps[i] is None:
|
if (self.version < 7 and quant_level <= 0 and timesteps[i] is None) or (self.version >= 7 and timesteps[i] is None):
|
||||||
# append stop tokens for AR
|
# append stop tokens for AR
|
||||||
if task not in text_task:
|
if task not in text_task:
|
||||||
resps_list[i] = torch.cat([ resps, audio_stop_sequence ])
|
resps_list[i] = torch.cat([ resps, audio_stop_sequence ])
|
||||||
|
|
|
@ -735,6 +735,7 @@ class Base(nn.Module):
|
||||||
if self.version >= 6:
|
if self.version >= 6:
|
||||||
self.raw_text_emb = Embedding(self.n_raw_text_tokens, d_model)
|
self.raw_text_emb = Embedding(self.n_raw_text_tokens, d_model)
|
||||||
|
|
||||||
|
self.resp_parallel_training = True #
|
||||||
if self.version >= 7:
|
if self.version >= 7:
|
||||||
pd_model = d_model // 4
|
pd_model = d_model // 4
|
||||||
pd_ffn = pd_model * 4
|
pd_ffn = pd_model * 4
|
||||||
|
@ -763,7 +764,7 @@ class Base(nn.Module):
|
||||||
self.n_resp_levels,
|
self.n_resp_levels,
|
||||||
d_model,
|
d_model,
|
||||||
dict(
|
dict(
|
||||||
vocab_size=n_audio_tokens,
|
vocab_size=n_audio_tokens + 1,
|
||||||
hidden_size=pd_model,
|
hidden_size=pd_model,
|
||||||
max_position_embeddings=max_position_embeddings,
|
max_position_embeddings=max_position_embeddings,
|
||||||
intermediate_size=pd_ffn,
|
intermediate_size=pd_ffn,
|
||||||
|
@ -1134,7 +1135,7 @@ class Base(nn.Module):
|
||||||
inputs[i].append( ( "resp", resps_list[i] ) )
|
inputs[i].append( ( "resp", resps_list[i] ) )
|
||||||
|
|
||||||
if self.version >= 7:
|
if self.version >= 7:
|
||||||
classifier_level = f"NAR:{quant_level}:{quant_level}"
|
classifier_level = f"{'N' if timestep is not None else ''}AR:{quant_level}:{quant_level}"
|
||||||
|
|
||||||
inputs[i].append( ("classifier_level", classifier_level) )
|
inputs[i].append( ("classifier_level", classifier_level) )
|
||||||
# Audio length prediction task
|
# Audio length prediction task
|
||||||
|
@ -1562,7 +1563,7 @@ class Base(nn.Module):
|
||||||
else:
|
else:
|
||||||
accuracy_metric = MulticlassAccuracy(
|
accuracy_metric = MulticlassAccuracy(
|
||||||
logit.shape[-1],
|
logit.shape[-1],
|
||||||
top_k = 10,
|
top_k = min(logit.shape[0], 10),
|
||||||
average="micro",
|
average="micro",
|
||||||
multidim_average="global",
|
multidim_average="global",
|
||||||
ignore_index = -100
|
ignore_index = -100
|
||||||
|
@ -1663,8 +1664,23 @@ class Base(nn.Module):
|
||||||
if loss_factor == 0.0:
|
if loss_factor == 0.0:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
# cringe way to deduce "requested" level
|
||||||
|
level = quant_level
|
||||||
|
for i in range( self.n_resp_levels ):
|
||||||
|
if classifier_level == f'NAR:{i}:{i}':
|
||||||
|
level = i
|
||||||
|
break
|
||||||
|
|
||||||
if logits[batch_index].dim() < 3:
|
if logits[batch_index].dim() < 3:
|
||||||
nll, metrics = _calc_loss( logits[batch_index][start:end], token.long(), causal )
|
nll, metrics = _calc_loss( logits[batch_index][start:end], token.long(), causal )
|
||||||
|
|
||||||
|
if name == "resp":
|
||||||
|
name = f'{name}[{quant_level}]'
|
||||||
|
elif not self.resp_parallel_training:
|
||||||
|
if name == "resp":
|
||||||
|
name = f'{name}[{level}]'
|
||||||
|
sequence = token if token.dim() <= 1 else token[:, level]
|
||||||
|
nll, metrics = _calc_loss( logits[batch_index][level][start:end], sequence.long(), causal )
|
||||||
else:
|
else:
|
||||||
nlls = []
|
nlls = []
|
||||||
accs = []
|
accs = []
|
||||||
|
@ -1673,15 +1689,29 @@ class Base(nn.Module):
|
||||||
sequence = token if token.dim() <= 1 else token[:, level]
|
sequence = token if token.dim() <= 1 else token[:, level]
|
||||||
nll, metrics = _calc_loss( logit[start:end], sequence.long(), causal )
|
nll, metrics = _calc_loss( logit[start:end], sequence.long(), causal )
|
||||||
|
|
||||||
if nll:
|
if name == "resp":
|
||||||
nlls.append( nll )
|
if nll is not None:
|
||||||
if metrics:
|
if f'{name}[{level}].nll' not in loss:
|
||||||
accs.append( metrics )
|
loss[f'{name}[{level}].nll'] = []
|
||||||
|
loss[f"{name}[{level}].nll"].append( nll * loss_factor )
|
||||||
|
|
||||||
if nlls:
|
if metrics is not None:
|
||||||
nll = sum(nlls) / len(nlls)
|
if f'{name}[{level}].acc' not in stats:
|
||||||
if accs:
|
stats[f'{name}[{level}].acc'] = []
|
||||||
accs = sum(accs) / len(accs)
|
stats[f"{name}[{level}].acc"].append( metrics )
|
||||||
|
|
||||||
|
nll = None
|
||||||
|
metrics = None
|
||||||
|
else:
|
||||||
|
if nll:
|
||||||
|
nlls.append( nll )
|
||||||
|
if metrics:
|
||||||
|
accs.append( metrics )
|
||||||
|
else:
|
||||||
|
if nlls:
|
||||||
|
nll = sum(nlls) / len(nlls)
|
||||||
|
if accs:
|
||||||
|
accs = sum(accs) / len(accs)
|
||||||
|
|
||||||
if nll is not None:
|
if nll is not None:
|
||||||
if f'{name}.nll' not in loss:
|
if f'{name}.nll' not in loss:
|
||||||
|
@ -1701,6 +1731,16 @@ class Base(nn.Module):
|
||||||
if logits[batch_index].dim() < 3:
|
if logits[batch_index].dim() < 3:
|
||||||
sequence = _join( target, torch.tensor(self.ignore_index, device=target[-1].device) )
|
sequence = _join( target, torch.tensor(self.ignore_index, device=target[-1].device) )
|
||||||
nll, metrics = _calc_loss( logits[batch_index], sequence, causal )
|
nll, metrics = _calc_loss( logits[batch_index], sequence, causal )
|
||||||
|
elif not self.resp_parallel_training:
|
||||||
|
# cringe way to deduce "requested" level
|
||||||
|
level = 0
|
||||||
|
for i in range( self.n_resp_levels ):
|
||||||
|
if classifier_level == f'NAR:{i}:{i}':
|
||||||
|
level = i
|
||||||
|
break
|
||||||
|
sequence = [ x if x.dim() <= 1 else x[:, level] for x in target ]
|
||||||
|
sequence = _join( sequence, torch.tensor(self.ignore_index, device=sequence[-1].device) )
|
||||||
|
nll, metrics = _calc_loss( logits[batch_index][level], sequence.long(), causal )
|
||||||
else:
|
else:
|
||||||
nlls = []
|
nlls = []
|
||||||
accs = []
|
accs = []
|
||||||
|
@ -1803,7 +1843,7 @@ class Base(nn.Module):
|
||||||
if self.version >= 7:
|
if self.version >= 7:
|
||||||
logits = [ logit for logit in logits ]
|
logits = [ logit for logit in logits ]
|
||||||
|
|
||||||
audio_decoder_levels = [ f"NAR:{i}:{i}" for i in range(self.n_resp_levels) ]
|
audio_decoder_levels = [ f"AR:{i}:{i}" for i in range(self.n_resp_levels) ] + [ f"NAR:{i}:{i}" for i in range(self.n_resp_levels) ]
|
||||||
|
|
||||||
decoders_indices = [ batch_index for batch_index, level in enumerate( classifier_levels ) if level in audio_decoder_levels ]
|
decoders_indices = [ batch_index for batch_index, level in enumerate( classifier_levels ) if level in audio_decoder_levels ]
|
||||||
classifiers_indices = [ batch_index for batch_index, level in enumerate( classifier_levels ) if level not in audio_decoder_levels ]
|
classifiers_indices = [ batch_index for batch_index, level in enumerate( classifier_levels ) if level not in audio_decoder_levels ]
|
||||||
|
|
Loading…
Reference in New Issue
Block a user