ugh
This commit is contained in:
parent
0d4329d2e3
commit
5f289db275
|
@ -14,7 +14,7 @@
|
|||
#include <unordered_map>
|
||||
#include <iostream>
|
||||
|
||||
#define LLAMA_CPP_EXTENDED 1 // whether the underlying llama.cpp has some extra functions
|
||||
#define LLAMA_CPP_EXTENDED 0 // whether the underlying llama.cpp has some extra functions
|
||||
#define LLAMA_CPP_USE_VALL_E_ARCH 0 // whether the underlying llama.cpp is to use the VALL_E arch
|
||||
|
||||
#if !LLAMA_CPP_EXTENDED
|
||||
|
|
File diff suppressed because one or more lines are too long
|
@ -105,7 +105,8 @@ def convert_to_hf_llama( state_dict, config = None, save_path = None ):
|
|||
classifier = torch.nn.Linear( model_dim, n_tokens, bias=classifier_bias )
|
||||
|
||||
if not split_classifiers:
|
||||
classifier.weight[:] = state_dict['module']['classifier.weight'][:]
|
||||
src = state_dict['module']['classifier.weight'][:]
|
||||
classifier.weight[:src.shape[0], ] = src
|
||||
|
||||
# update ranges
|
||||
start = 0
|
||||
|
|
|
@ -523,7 +523,7 @@ class Base(nn.Module):
|
|||
classifier_l_tokens += [ 11 ]
|
||||
classifier_l_names += ["len"]
|
||||
|
||||
n_vocab = 17701 if not split_classifiers else n_resp_tokens + 1
|
||||
n_vocab = 17702 if not split_classifiers else n_resp_tokens + 1
|
||||
|
||||
self.n_vocab = n_vocab
|
||||
self.unified_position_ids = unified_position_ids
|
||||
|
@ -1436,6 +1436,9 @@ class Base(nn.Module):
|
|||
continue
|
||||
|
||||
# offset to flattened vocab ranges
|
||||
if self.classifier is not None:
|
||||
compute_acc = False
|
||||
"""
|
||||
if self.classifier is not None:
|
||||
offsets = _get_offsets()
|
||||
|
||||
|
@ -1454,6 +1457,7 @@ class Base(nn.Module):
|
|||
if t == self.ignore_index:
|
||||
continue
|
||||
token[i] += start
|
||||
"""
|
||||
|
||||
if token.is_floating_point():
|
||||
ignored = True
|
||||
|
@ -1511,7 +1515,7 @@ class Base(nn.Module):
|
|||
|
||||
# perofrm loss calculation on the entire sequence
|
||||
if not self.config.loss_factors:
|
||||
target = _join( target, torch.tensor(self.ignore_index if self.classifier is None else 17685, device=target[-1].device) )
|
||||
target = _join( target, torch.tensor(self.ignore_index, device=target[-1].device) )
|
||||
logit = logits[batch_index]
|
||||
|
||||
# shift if causal
|
||||
|
@ -1679,31 +1683,30 @@ class Base(nn.Module):
|
|||
for i, state in enumerate( hidden_states ):
|
||||
hidden_states[i] = [ hi[:li] for hi, li in zip(hidden_states[i], map(len, x_list)) ]
|
||||
|
||||
# de-offset if needed
|
||||
if self.classifier is not None:
|
||||
offsets = _get_offsets()
|
||||
for batch_index, classifier_level in enumerate( classifier_levels ):
|
||||
if classifier_level == "stt":
|
||||
k = "text"
|
||||
elif classifier_level == "len":
|
||||
k = "len"
|
||||
else:
|
||||
k = f'resps|{classifier_level}'
|
||||
|
||||
if k not in offsets:
|
||||
continue
|
||||
|
||||
start, end = offsets[k]
|
||||
|
||||
logits[batch_index] = logits[batch_index][:, start:end]
|
||||
|
||||
if not training:
|
||||
loss = None
|
||||
stats = None
|
||||
|
||||
self.loss = None
|
||||
self.stats = None
|
||||
|
||||
# de-offset if needed
|
||||
if self.classifier is not None:
|
||||
offsets = _get_offsets()
|
||||
for batch_index, classifier_level in enumerate( classifier_levels ):
|
||||
if classifier_level == "stt":
|
||||
k = "text"
|
||||
elif classifier_level == "len":
|
||||
k = "len"
|
||||
else:
|
||||
k = f'resps|{classifier_level}'
|
||||
|
||||
if k not in offsets:
|
||||
continue
|
||||
|
||||
start, end = offsets[k]
|
||||
|
||||
logits[batch_index] = logits[batch_index][:, start:start+end]
|
||||
|
||||
# compute loss if the target is given
|
||||
else:
|
||||
loss, stats = self.calc_loss( inputs=inputs, logits=logits, quant_levels=quant_levels )
|
||||
|
|
Loading…
Reference in New Issue
Block a user