mrq 2024-12-22 16:15:24 -06:00
parent 0d4329d2e3
commit 5f289db275
4 changed files with 33 additions and 25 deletions

View File

@@ -14,7 +14,7 @@
 #include <unordered_map>
 #include <iostream>
-#define LLAMA_CPP_EXTENDED 1 // whether the underlying llama.cpp has some extra functions
+#define LLAMA_CPP_EXTENDED 0 // whether the underlying llama.cpp has some extra functions
 #define LLAMA_CPP_USE_VALL_E_ARCH 0 // whether the underlying llama.cpp is to use the VALL_E arch
 #if !LLAMA_CPP_EXTENDED

File diff suppressed because one or more lines are too long

View File

@@ -105,7 +105,8 @@ def convert_to_hf_llama( state_dict, config = None, save_path = None ):
 classifier = torch.nn.Linear( model_dim, n_tokens, bias=classifier_bias )
 if not split_classifiers:
-    classifier.weight[:] = state_dict['module']['classifier.weight'][:]
+    src = state_dict['module']['classifier.weight'][:]
+    classifier.weight[:src.shape[0], ] = src
 # update ranges
 start = 0
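For context, the new assignment only fills the rows that actually exist in the checkpoint, leaving any padded rows of the exported head at their initialization. A minimal standalone sketch of that idea (the sizes below are assumptions, not values from the repo):

import torch

model_dim = 1024   # assumed embedding width
n_tokens = 17702   # padded flattened-vocab size for the exported head
ckpt_rows = 17685  # assumed row count of the checkpoint's classifier

classifier = torch.nn.Linear(model_dim, n_tokens, bias=False)
src = torch.randn(ckpt_rows, model_dim)  # stand-in for state_dict['module']['classifier.weight']

with torch.no_grad():
    # copy only the rows present in the checkpoint; the extra rows keep their init
    classifier.weight[:src.shape[0], :] = src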

View File

@@ -523,7 +523,7 @@ class Base(nn.Module):
 classifier_l_tokens += [ 11 ]
 classifier_l_names += ["len"]
-n_vocab = 17701 if not split_classifiers else n_resp_tokens + 1
+n_vocab = 17702 if not split_classifiers else n_resp_tokens + 1
 self.n_vocab = n_vocab
 self.unified_position_ids = unified_position_ids
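A plausible reading of the hard-coded constant: with a single (non-split) classifier, the output head spans every level's vocabulary laid out back to back, so its size tracks the sum of the per-level token counts. A sketch under that assumption (the counts below are placeholders, not the repo's real values):

# Placeholder per-level vocabulary sizes; the real ones come from the model config.
classifier_l_tokens = [256, 1025, 1025, 11]  # e.g. text, AR, NAR, len

split_classifiers = False
n_resp_tokens = 1025  # assumed

# single flattened head vs. one head per response level
n_vocab = sum(classifier_l_tokens) if not split_classifiers else n_resp_tokens + 1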
@@ -1436,6 +1436,9 @@ class Base(nn.Module):
 continue
 # offset to flattened vocab ranges
+if self.classifier is not None:
+    compute_acc = False
+"""
 if self.classifier is not None:
     offsets = _get_offsets()
@@ -1454,6 +1457,7 @@ class Base(nn.Module):
 if t == self.ignore_index:
     continue
 token[i] += start
+"""
 if token.is_floating_point():
     ignored = True
@@ -1511,7 +1515,7 @@ class Base(nn.Module):
 # perform loss calculation on the entire sequence
 if not self.config.loss_factors:
-    target = _join( target, torch.tensor(self.ignore_index if self.classifier is None else 17685, device=target[-1].device) )
+    target = _join( target, torch.tensor(self.ignore_index, device=target[-1].device) )
 logit = logits[batch_index]
 # shift if causal
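The replacement line always pads the joined target with `self.ignore_index`, relying on the loss function to skip those positions rather than pointing them at a dedicated token (17685) in the flattened vocab. A minimal sketch of that masking (values and shapes are assumptions):

import torch
import torch.nn.functional as F

IGNORE_INDEX = -100  # assumed value of self.ignore_index

# toy flattened target: real tokens plus padded positions marked with the ignore index
target = torch.tensor([5, 12, IGNORE_INDEX, 7, IGNORE_INDEX])
logits = torch.randn(5, 32)  # (sequence length, vocab size)

# positions equal to ignore_index contribute nothing to the loss
loss = F.cross_entropy(logits, target, ignore_index=IGNORE_INDEX)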
@@ -1679,31 +1683,30 @@ class Base(nn.Module):
 for i, state in enumerate( hidden_states ):
     hidden_states[i] = [ hi[:li] for hi, li in zip(hidden_states[i], map(len, x_list)) ]
-# de-offset if needed
-if self.classifier is not None:
-    offsets = _get_offsets()
-    for batch_index, classifier_level in enumerate( classifier_levels ):
-        if classifier_level == "stt":
-            k = "text"
-        elif classifier_level == "len":
-            k = "len"
-        else:
-            k = f'resps|{classifier_level}'
-        if k not in offsets:
-            continue
-        start, end = offsets[k]
-        logits[batch_index] = logits[batch_index][:, start:end]
 if not training:
     loss = None
     stats = None
     self.loss = None
     self.stats = None
+    # de-offset if needed
+    if self.classifier is not None:
+        offsets = _get_offsets()
+        for batch_index, classifier_level in enumerate( classifier_levels ):
+            if classifier_level == "stt":
+                k = "text"
+            elif classifier_level == "len":
+                k = "len"
+            else:
+                k = f'resps|{classifier_level}'
+            if k not in offsets:
+                continue
+            start, end = offsets[k]
+            logits[batch_index] = logits[batch_index][:, start:start+end]
 # compute loss if the target is given
 else:
     loss, stats = self.calc_loss( inputs=inputs, logits=logits, quant_levels=quant_levels )
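The relocated block trims each sequence's logits from the flattened vocabulary back to its own classifier level's range, and the new slice `start:start+end` treats the second offset value as a width rather than an absolute end. A standalone sketch under that reading (the offset table and level names below are placeholders, not real `_get_offsets()` output):

import torch

# Placeholder flattened-vocab layout: level name -> (start, width).
offsets = {
    "text": (0, 256),
    "len": (256, 11),
    "resps|NAR:0:0": (267, 1025),
}

classifier_levels = ["stt", "len", "NAR:0:0"]
logits = [torch.randn(10, 1292) for _ in classifier_levels]  # (seq, flattened vocab)

for batch_index, classifier_level in enumerate(classifier_levels):
    if classifier_level == "stt":
        k = "text"
    elif classifier_level == "len":
        k = "len"
    else:
        k = f"resps|{classifier_level}"
    if k not in offsets:
        continue
    start, end = offsets[k]
    # keep only this level's slice of the flattened vocabulary
    logits[batch_index] = logits[batch_index][:, start:start + end]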