experimental implementation of STT (still needs to be tested on an actual model; the test trainer seems to work)

This commit is contained in:
mrq 2024-09-05 20:43:20 -05:00
parent d319d33368
commit 54547b74d8
6 changed files with 147 additions and 51 deletions

View File

@@ -960,6 +960,19 @@ class NaiveTokenizer:
 		# tokenize
 		return [*map(symmap.get, phones)]
 
+	def decode( self, t ):
+		s = ""
+		symmap = self.get_vocab()
+		reverse_symmap = {}
+		for k, v in symmap.items():
+			reverse_symmap[v] = k
+
+		for i, token in enumerate( t ):
+			s += reverse_symmap[token]
+
+		return s
+
 _logger = logging.getLogger(__name__)
 
 cfg = Config.from_cli()
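For context, a standalone sketch of the round-trip the new decode() enables, using a toy phoneme map in place of NaiveTokenizer.get_vocab(); the symbols and ids below are illustrative, not the project's actual vocabulary:

# toy stand-in for NaiveTokenizer.get_vocab(); real ids come from the config
toy_symmap = { "<s>": 1, "</s>": 2, " ": 3, "h": 4, "ɪ": 5 }
reverse_symmap = { v: k for k, v in toy_symmap.items() }

def encode( phones ):
	# mirrors the existing encode path: map each phone to its id
	return [ toy_symmap[p] for p in phones ]

def decode( t ):
	# mirrors the added decode(): invert the map and concatenate symbols
	return "".join( reverse_symmap[token] for token in t )

tokens = encode( ["<s>", "h", "ɪ", "</s>"] )
assert decode( tokens ) == "<s>hɪ</s>"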

View File

@@ -428,6 +428,7 @@ def get_task_symmap():
 		"<soe>": 5,
 		"<mask>": 6,
 		"<eoe>": 7,
+		"<stt>": 8,
 
 		"<nse>": 6, # fake
 		"<cse>": 6, # fake
@@ -1052,6 +1053,12 @@ class Dataset(_Dataset):
 				task,
 			]
+		# Speech-to-Text (<resp> => <text>)
+		elif task == "stt":
+			# easier to just keep the prompt instead of wrangling around trying to remove it
+			# it might also help to provide a guidance prompt, but who knows right now
+			proms = self.sample_prompts(spkr_name, ignore=path)
+
 		# noise suppression (<text>? <resp+noise> => <resp>)
 		# speech removal (<text>? <resp+noise> => <noise>)
 		elif task == "ns" or task == "sr":

View File

@@ -200,7 +200,7 @@ class AR(Base):
 			r = super().sample(
 				logits=logits,
-				resps_list=resps_list,
+				prev_list=resps_list,
 
 				temperature=sampling_temperature,
 				min_temperature=sampling_min_temperature,

View File

@@ -61,15 +61,23 @@ class AR_NAR(Base):
 		disable_tqdm=False,
 	):
-		device = text_list[0].device
-		batch_size = len(text_list)
+		text_task = [ "stt" ]
+
+		if text_list is not None:
+			default_task = "tts"
+			device = text_list[0].device
+			batch_size = len(text_list)
+		else:
+			default_task = "stt"
+			device = resps_list[0].device
+			batch_size = len(resps_list)
 
 		# generate task list if not provided
 		if task_list is None:
-			task_list = [ "tts" for _ in range(batch_size) ]
+			task_list = [ default_task for _ in range(batch_size) ]
 
 		# is training or NAR
-		if resps_list is not None:
+		if resps_list is not None and text_list is not None:
 			n_levels_set = {r.shape[-1] for r in resps_list}
 			n_levels = next(iter(n_levels_set))
@@ -102,12 +110,18 @@ class AR_NAR(Base):
 			# input RVQ levels
 			quant_levels = [ random.choice( p_rvq_levels ) for i in range(batch_size) ]
+			for i, task in enumerate( task_list ):
+				if task in text_task:
+					quant_levels[i] = 0 # self.n_resp_levels - 1
+
 			# trim resps to only contain all levels below the target level
-			resps_list = [r[..., :l+1] for r, l in zip(resps_list, quant_levels)]
+			resps_list = [r if t in text_task else r[..., :l+1] for r, l, t in zip(resps_list, quant_levels, task_list)]
 
 			# tensor to cat for RVQ level 0
-			stop_sequence = torch.tensor([[self.stop_token] * 1], device=device, dtype=torch.int16)
+			text_stop_sequence = torch.tensor([[2] * 1], device=device, dtype=torch.int16)
+			audio_stop_sequence = torch.tensor([[self.stop_token] * 1], device=device, dtype=torch.int16)
 
 			# I hate python's value/reference semantics so much
-			for i, quant_level, resps, proms in zip(range(batch_size), quant_levels, resps_list, proms_list):
+			for i, quant_level, resps, proms, task in zip(range(batch_size), quant_levels, resps_list, proms_list, task_list):
 				# cap quant_level if it exceeds its corresponding resp/prom
 				if quant_level >= resps.shape[-1]:
 					quant_levels[i] = resps.shape[-1] - 1
@@ -139,7 +153,11 @@ class AR_NAR(Base):
 				# only apply stop token for RVQ level 0
 				if quant_level <= 0:
 					# append stop tokens for AR
-					resps_list[i] = torch.cat([ resps, stop_sequence ])
+					if task in text_task:
+						#text_list[i] = torch.cat([ resps, text_stop_sequence ])
+						...
+					else:
+						resps_list[i] = torch.cat([ resps, audio_stop_sequence ])
 
 			inputs = self.inputs(
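Pulling the pieces of this training path together: STT samples get pinned to RVQ level 0, keep all of their codebook levels, and do not get the audio stop token appended (their text target carries its own end token instead). A minimal, self-contained sketch of that per-sample preparation for a mixed batch; the text stop id of 2 follows the hunks above, everything else (stop token value, level count, shapes) is illustrative:

import random
import torch

text_task = [ "stt" ]
audio_stop_token = 1024          # stand-in for self.stop_token
n_resp_levels = 8

task_list  = [ "tts", "stt" ]
resps_list = [ torch.randint(0, 1024, (75, n_resp_levels)) for _ in task_list ]

quant_levels = [ random.randint(0, n_resp_levels - 1) for _ in task_list ]
for i, task in enumerate( task_list ):
	if task in text_task:
		quant_levels[i] = 0      # STT always trains against level 0

# keep every level for STT, trim TTS samples down to the sampled level
resps_list = [ r if t in text_task else r[..., :l+1]
	for r, l, t in zip( resps_list, quant_levels, task_list ) ]

# only TTS samples at level 0 get the audio stop token appended
for i, (task, l) in enumerate( zip( task_list, quant_levels ) ):
	if l == 0 and task not in text_task:
		stop = torch.tensor([[audio_stop_token]], dtype=resps_list[i].dtype)
		resps_list[i] = torch.cat([ resps_list[i], stop ])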
@@ -195,7 +213,7 @@ class AR_NAR(Base):
 				resps_list = super().sample(
 					logits=logits,
-					resps_list=prev_list,
+					prev_list=prev_list,
 
 					quant_levels=quant_levels,
 					temperature=sampling_temperature,
@@ -220,11 +238,11 @@ class AR_NAR(Base):
 		if cfg.lora is not None:
 			enable_lora( self, cfg.lora.active_level( 0 ) )
 
+		# STT
 		sequence_list = [ torch.zeros(0, device=device).to(torch.int16) for _ in range(batch_size) ]
 		stopped = torch.zeros(batch_size, device=device).bool()
 
-		stop_token = self.stop_token
+		stop_token = self.stop_token if task_list[0] != "stt" else 2 # to-do: derive from tokenizer
 
 		state = None
 		mirostat = [
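The hard-coded text stop token of 2 above carries a to-do about deriving it from the tokenizer instead. One possible shape for that, assuming the tokenizer behaves like the `bos_id, space_id, eos_id = cfg.tokenizer.encode( " " )` unpack used later in example_usage; this is a sketch of the suggestion, not what the commit does:

# sketch only: derive the text stop token instead of hard-coding 2
def get_text_stop_token( tokenizer, fallback=2 ):
	try:
		# encoding a lone space is assumed to yield (<bos>, <space>, <eos>)
		bos_id, space_id, eos_id = tokenizer.encode( " " )
		return eos_id
	except Exception:
		return fallback

# stop_token = self.stop_token if task_list[0] != "stt" else get_text_stop_token( cfg.tokenizer )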
@@ -233,9 +251,17 @@ class AR_NAR(Base):
 			scores = [ 1.0 ] * sampling_beam_width
 
+		# add <bos> to text for STT
+		for i, sequence in enumerate( sequence_list ):
+			if task_list[i] in text_task:
+				sequence_list[i] = torch.cat([sequence_list[i], torch.tensor([1], dtype=torch.int16, device=device)])
+
 		# get next in sequence
 		for n in trange(max_steps // max(1, self.causal_size), desc="AR", disable=disable_tqdm):
-			resps_list = [x.unsqueeze(dim=-1) for x in sequence_list]
+			if task_list[0] in text_task:
+				text_list = [x for x in sequence_list]
+			else:
+				resps_list = [x.unsqueeze(dim=-1) for x in sequence_list]
 
 			inputs = self.inputs(
 				text_list=text_list,
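Condensed, the AR loop now maintains one running token sequence per batch entry and feeds it back in two different views: as the 1-D text sequence (`text_list`, seeded with a `<bos>` of 1) for STT, or unsqueezed into a single RVQ level (`resps_list`) for TTS. A toy version of that control flow, with a dummy step standing in for the model forward/sample; the bos/stop values follow the hunks above, the rest is illustrative:

import torch

def toy_decode( task, max_steps=8, bos=1, stop=2 ):
	sequence = torch.zeros(0, dtype=torch.int16)

	if task == "stt":
		# seed the text sequence with <bos>
		sequence = torch.cat([ sequence, torch.tensor([bos], dtype=torch.int16) ])

	for _ in range( max_steps ):
		if task == "stt":
			model_input = sequence                      # fed back as text_list[i]
		else:
			model_input = sequence.unsqueeze(dim=-1)    # fed back as resps_list[i], RVQ level 0 only

		# dummy "model" step; the real loop runs self.inputs() / forward / sample here
		next_token = torch.tensor([stop if len(model_input) >= 4 else 3], dtype=torch.int16)
		sequence = torch.cat([ sequence, next_token ])
		if next_token.item() == stop:
			break

	return sequence

print( toy_decode("stt") )   # text token ids, terminated by the text stop token
print( toy_decode("tts") )   # audio token ids, terminated by the (toy) stop token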
@@ -261,7 +287,7 @@ class AR_NAR(Base):
 			r = super().sample(
 				logits=logits,
-				resps_list=resps_list,
+				prev_list=resps_list,
 
 				temperature=sampling_temperature,
 				min_temperature=sampling_min_temperature,
@@ -398,10 +424,10 @@ def example_usage():
 	"""
 	bos_id, space_id, eos_id = cfg.tokenizer.encode( " " )
 
-	tasks = cfg.dataset.tasks_list
+	available_tasks = ["tts", "stt"] # cfg.dataset.tasks_list
 
 	model = AR_NAR(**kwargs).to(device)
-	steps = 150 * len(tasks) # * cfg.model.experimental.causal_size
+	steps = 150 * len(available_tasks) # * cfg.model.experimental.causal_size
 
 	optimizer = cfg.hyperparameters.optimizer.lower() if cfg.yaml_path is not None else "prodigy"
 	scheduler = cfg.hyperparameters.scheduler.lower() if cfg.yaml_path is not None else ""
@@ -486,14 +512,14 @@ def example_usage():
 	_logger.info(f"AR+NAR ({cfg.model.arch_type}, {cfg.audio_backend}) parameter count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
 
 	@torch.no_grad()
-	def sample_data(task=None):
+	def sample_data(t=None):
 		texts = []
 		proms = []
 		resps = []
+		tasks = []
 
 		for i in range(batch_size):
-			if task is None:
-				task = random.choice(tasks)
+			task = random.choice(available_tasks) if t is None else t
 
 			text = text_list[i]
 			prom = proms_list[i]
@@ -502,6 +528,8 @@ def example_usage():
 			# do nothing
 			if task == "tts":
 				...
+			elif task == "stt":
+				...
 			elif task == "tts-c":
 				trim_length = int(random.uniform(cfg.dataset.prompt_duration_range[0], cfg.dataset.prompt_duration_range[1]) * cfg.dataset.frames_per_second)
@@ -523,25 +551,35 @@ def example_usage():
 			texts.append( text.to(device) )
 			proms.append( prom.to(device) )
 			resps.append( resp.to(device) )
+			tasks.append( task )
 
-		return texts, proms, resps
+		return texts, proms, resps, tasks
 
 	@torch.inference_mode()
-	def sample( name, steps=1000, task=None ):
+	def sample( name, steps=500, task=None ):
 		engine.eval()
 
-		texts, proms, resps = sample_data( task )
+		texts, proms, resps, tasks = sample_data( task )
 
-		if "ar" in cfg.model.capabilities:
-			resps = engine( texts, proms, max_steps=steps, sampling_temperature=0.95 )
+		if tasks[0] == "stt":
+			text = engine( None, proms, resps, task_list=tasks, max_steps=steps, sampling_temperature=0.95 )
+			"""
+			# to-do: STT for NAR
+			text = engine( text, proms, resps, task_list=tasks, max_steps=steps, sampling_temperature=0.95 )
+			"""
+			text = [ cfg.tokenizer.decode( t ) for t in text ]
+			print( text )
 		else:
-			resps = [ resp[:, 0] for resp in resps ]
+			if "ar" in cfg.model.capabilities:
+				resps = engine( texts, proms, max_steps=steps, sampling_temperature=0.95 )
+			else:
+				resps = [ resp[:, 0] for resp in resps ]
 
-		if "nar" in cfg.model.capabilities:
-			resps = engine( texts, proms, resps, sampling_temperature=0.2 )
+			if "nar" in cfg.model.capabilities:
+				resps = engine( texts, proms, resps, sampling_temperature=0.2 )
 
-		for i, o in enumerate(resps):
-			_ = decode_to_file(o.to(dtype=torch.int32), f"data/{cfg.model.arch_type}.{cfg.audio_backend}.{i}.{task}.{name}.wav", device=device)
+			for i, o in enumerate(resps):
+				_ = decode_to_file(o.to(dtype=torch.int32), f"data/{cfg.model.arch_type}.{cfg.audio_backend}.{i}.{task}.{name}.wav", device=device)
 
 		unload_model()
@@ -549,10 +587,10 @@ def example_usage():
 		engine.train()
 		t = trange(steps)
 		for i in t:
-			texts, proms, resps = sample_data()
+			texts, proms, resps, tasks = sample_data()
 
 			stats = {"step": i}
-			stats |= engine.traverse(text_list=texts, proms_list=proms, resps_list=resps)
+			stats |= engine.traverse(text_list=texts, proms_list=proms, resps_list=resps, task_list=tasks)
 			stats |= {"grad_norm": engine.get_global_grad_norm()}
 
 			tqdm.write(f"{stats}")
@@ -571,7 +609,7 @@ def example_usage():
 		model = ml.compile_model(model, backend=cfg.optimizations.compile)
 	"""
 
-	for task in tasks:
+	for task in available_tasks:
 		sample("final", task=task)
 
 	engines.quit()

View File

@@ -259,7 +259,8 @@ class AudioEmbedding(nn.Module):
 		return x
 
 # per-level classification
-class AudioClassifier(nn.Module):
+# it might actually be "better" in the long run to only have one output head like a traditional LM, and just de-stitch it here instead of doing modulus math and whatever like the HF/experimental impl
+class Classifiers(nn.Module):
 	def __init__(
 		self,
 		l_tokens: list[int], # list of number of tokens (needed because AR resps includes stop token)
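The new comment above floats an alternative to per-level heads: one traditional LM-style output head whose logits get sliced ("de-stitched") back into per-level vocabularies. A minimal sketch of that idea, assuming the per-level vocab sizes are known up front; this is what the comment describes, not code that exists in the repo:

import torch
import torch.nn as nn

class SingleHeadClassifier(nn.Module):
	# one big projection over the concatenation of every level's vocabulary
	def __init__( self, l_tokens, d_model ):
		super().__init__()
		self.l_tokens = l_tokens                             # e.g. [1025, 1024, ..., n_text_tokens]
		self.offsets = [ sum(l_tokens[:i]) for i in range(len(l_tokens)) ]
		self.head = nn.Linear( d_model, sum(l_tokens) )

	def forward( self, x, level ):
		# slice ("de-stitch") the shared logits back down to the requested level's vocab
		logits = self.head( x )
		start = self.offsets[level]
		return logits[..., start : start + self.l_tokens[level]]

head = SingleHeadClassifier( l_tokens=[1025, 1024, 1024], d_model=16 )
x = torch.randn( 4, 16 )            # 4 token positions of model width 16
print( head( x, level=1 ).shape )   # torch.Size([4, 1024])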
@@ -783,10 +784,10 @@ class Base(nn.Module):
 			self.metrics = None
 		else:
 			self.classifier = None
-			self.classifiers = AudioClassifier( l_tokens, d_model )
+			self.classifiers = Classifiers( l_tokens + [ n_text_tokens ], d_model )
 			self.accuracy_metric = None
 			self.precision_metric = None
-			self.metrics = Metrics( l_tokens )
+			self.metrics = Metrics( l_tokens + [ n_text_tokens ] )
 
 		"""
 		if tie_classifier_to_embedding:
@@ -907,6 +908,8 @@ class Base(nn.Module):
 		device = text_list[0].device
 		batch_size = len(text_list)
 
+		special_tasks = ["stt", "len"]
+
 		inputs = [ [] for _ in range(batch_size) ]
 		for i in range(batch_size):
 			quant_level = quant_levels[i] if quant_levels is not None else 0
@@ -921,7 +924,7 @@ class Base(nn.Module):
 			# Base-line TTS task
 			# Sequence: <text><sep><rvq lvl><sep><prom><sep><resp>
 			# prom /may/ include <task> tokens inside to help guide things, per SpeechX
-			if f'<{task_type}>' in get_task_symmap():
+			if f'<{task_type}>' in get_task_symmap() and task_type not in special_tasks:
 				# insert the text prompt
 				if text_list is not None and text_list[i] is not None:
 					inputs[i].append( ( "text", text_list[i] ) )
@@ -973,6 +976,21 @@ class Base(nn.Module):
 				elif resps_list is not None and resps_list[i] is not None:
 					# yes this could be encoded better
 					inputs[i].append( ( "len", torch.tensor([ 0 ] + [ int(i) for i in str( resps_list[i].shape[0]) ] + [ 10 ], device=device, dtype=torch.int16) ) )
+			# Speech-to-Text prediction task
+			# Sequence: <resp><sep><rvq lvl><sep><text>
+			elif task_type == "stt":
+				# insert the input response
+				if resps_list is not None and resps_list[i] is not None:
+					inputs[i].append( ( "resp", resps_list[i] ) )
+				# insert lang token if we're trained for it
+				if "lang" in self.capabilities and lang_list is not None and lang_list[i] is not None:
+					inputs[i].append( ( "lang", lang_list[i] ) )
+				# insert RVQ level guidance token if the model is versioned for it
+				if self.rvq_l_emb is not None and not self.interleave:
+					inputs[i].append( ( "quant_level", torch.tensor([ quant_level ], device=device, dtype=torch.int16) ) )
+				# insert the output text prompt
+				if text_list is not None and text_list[i] is not None:
+					inputs[i].append( ( "text", text_list[i] ) )
 			else:
 				raise Exception(f'Unrecognized task: {task_type}')
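Concretely, for one "stt" batch entry the hunk above assembles the input list in this order — response codes, optional lang token, RVQ level token, then the text target — mirroring the `<resp><sep><rvq lvl><sep><text>` sequence in the comment. A toy reconstruction of that list, with made-up tensors in place of real batch data:

import torch

device = "cpu"
quant_level = 0
resp = torch.randint(0, 1024, (75, 8), dtype=torch.int16)   # input audio codes
text = torch.tensor([1, 4, 5, 2], dtype=torch.int16)        # target phoneme ids

entry = [ ( "task", "stt" ) ]
entry.append( ( "resp", resp ) )                             # audio comes first for STT
# a lang token would be appended here if the model is trained for it
entry.append( ( "quant_level", torch.tensor([ quant_level ], device=device, dtype=torch.int16) ) )
entry.append( ( "text", text ) )                             # text is last, so it is what the AR predicts

print( [ name for name, _ in entry ] )   # ['task', 'resp', 'quant_level', 'text']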
@@ -1010,6 +1028,8 @@ class Base(nn.Module):
 		if not token_dropout_rvq_levels:
 			token_dropout_rvq_levels = [1, self.resp_levels]
 
+		summed_embeddings_task = [ "stt" ]
+
 		x_list = []
 		for batch_index, batch_input in enumerate(inputs):
 			batch = []
@@ -1071,7 +1091,13 @@ class Base(nn.Module):
 						offset = 0,
 						quant_level = 0,
 					)
+				# cheat-y way to handle performing STT across all levels
+				elif task_type in summed_embeddings_task:
+					embedding = sum([ self.resps_emb(
+						input[:, :l+1],
+						offset = 0 if l == 0 else 1, # or maybe set to 1
+						quant_level = l
+					) for l in range( input.shape[-1] - 1 ) ])
 				else:
 					# get RVQ level 0, or up to targetted RVQ level inference
 					if self.version <= 4:
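The "cheat-y" summed embedding above collapses every RVQ level of the input audio into one embedding stream, so the STT pass sees the whole codebook stack at once instead of a single level. A self-contained approximation with a plain per-level embedding table standing in for self.resps_emb (the real AudioEmbedding also handles offsets and cumulative level slices, which are omitted here):

import torch
import torch.nn as nn

n_levels, vocab, d_model = 8, 1024, 16
level_embs = nn.ModuleList([ nn.Embedding( vocab, d_model ) for _ in range(n_levels) ])

resp = torch.randint(0, vocab, (75, n_levels))   # (timesteps, RVQ levels)

# sum one embedding per level across the whole stack; every level lands on the
# same (timesteps, d_model) grid, so the sum stays time-aligned
embedding = sum( level_embs[l]( resp[:, l] ) for l in range( resp.shape[-1] ) )
print( embedding.shape )   # torch.Size([75, 16])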
@@ -1171,7 +1197,9 @@ class Base(nn.Module):
 		quant_levels: int | list[int] | Tensor | None = None,
 	):
 		device = logits[0].device
-		classifier_quant_levels = quant_levels if self.classifier is not None else [ -1 if inputs[i][0][-1] == "len" else l for i, l in enumerate( quant_levels ) ]
+		special_tasks = [ "len", "stt" ]
+		summed_embeddings_task = [ "stt" ]
+		classifier_quant_levels = quant_levels if self.classifier is not None else [ -1 if inputs[i][0][-1] in special_tasks else l for i, l in enumerate( quant_levels ) ]
 
 		# handles tasks where the prompt has task tokens injected in the middle
 		def prompt_input_to_token( input, quant_level ):
@@ -1192,8 +1220,10 @@ class Base(nn.Module):
 		for batch_index, batch in enumerate(inputs):
 			quant_level = quant_levels[batch_index]
 			target = []
+			task_type = "tts"
 			for name, input in batch:
 				if name == "task":
+					task_type = input
 					task_list.append( input )
 				elif name == "prom":
 					proms = [ input ] if isinstance(input, torch.Tensor) else input
@@ -1201,6 +1231,8 @@ class Base(nn.Module):
 				elif name == "resp":
 					if self.interleave:
 						target.append( _interleave_sequence_flatten( [ input[:, l] for l in range( input.shape[-1] ) ] ) )
+					elif task_type in summed_embeddings_task:
+						target.append( torch.full_like(input[..., 0], self.ignore_index) )
 					else:
 						target.append( input if input.dim() == 1 else input[:, quant_level] )
 				elif name in ["text", "quant_level", "lang", "tone", "len"]:
@@ -1273,7 +1305,12 @@ class Base(nn.Module):
 			for name, input in batch:
 				# do not use resp
 				if name == "resp":
-					input = input if input.dim() == 1 else input[:, quant_level]
+					if self.interleave:
+						input = _interleave_sequence_flatten( [ input[:, l] for l in range( input.shape[-1] ) ] )
+					elif task_type in summed_embeddings_task:
+						input = torch.full_like(input[..., 0], self.ignore_index)
+					else:
+						input = input if input.dim() == 1 else input[:, quant_level]
 				# select prom level
 				elif name == "prom":
 					proms = [ input ] if isinstance(input, torch.Tensor) else input
@@ -1383,7 +1420,8 @@ class Base(nn.Module):
 		)
 
 		if self.classifiers is not None:
-			classifier_quant_levels = quant_levels if self.classifier is not None else [ -1 if inputs[i][0][-1] == "len" else l for i, l in enumerate( quant_levels ) ]
+			special_tasks = [ "len", "stt" ]
+			classifier_quant_levels = [ -1 if inputs[i][0][-1] in special_tasks else l for i, l in enumerate( quant_levels ) ]
 			x = self.classifiers(x, levels = classifier_quant_levels) * m
 
 		# Remove padding
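Because the classifier list was built as `l_tokens + [ n_text_tokens ]`, indexing a batch entry's level with -1 routes "stt" (and "len") outputs through the last, text-sized head. A toy illustration of that selection logic, with plain Linear layers standing in for the Classifiers module and made-up vocab sizes:

import torch
import torch.nn as nn

d_model = 16
l_tokens = [1025, 1024, 1024]       # per-RVQ-level vocab sizes (AR level includes a stop token)
n_text_tokens = 256

heads = nn.ModuleList([ nn.Linear( d_model, n ) for n in l_tokens + [ n_text_tokens ] ])

def classify( x, level ):
	# level -1 falls through to the appended text head
	return heads[level]( x )

x = torch.randn( 4, d_model )
print( classify( x, 0 ).shape )     # torch.Size([4, 1025]) -> audio, RVQ level 0
print( classify( x, -1 ).shape )    # torch.Size([4, 256])  -> text, used for "stt"/"len"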
@@ -1402,7 +1440,7 @@ class Base(nn.Module):
 	def sample(
 		self,
 		logits: list[Tensor], # logit scores
-		resps_list: list[Tensor], # previous tokens
+		prev_list: list[Tensor], # previous tokens
 		quant_levels: int | list[int] | Tensor | None = None,
 		# base sampling parameters
 		temperature: float = 1.0,
@@ -1429,7 +1467,7 @@ class Base(nn.Module):
 		# (NAR) return the entire generated response
 		# Parallel decoding relies on the last N tokens in the logits, because each token predicts the next RVQ layer in the same place (forgetfully obviously)
 		if quant_levels is not None: # and "nar" in self.capabilities: # for when I get around to coping about dropping the NAR entirely
-			logits = [ logit[-l:] for logit, l in zip(logits, map(len, resps_list)) ]
+			logits = [ logit[-l:] for logit, l in zip(logits, map(len, prev_list)) ]
 		# (AR chunkwise) return the last chunkwise piece
 		elif self.causal:
 			logits = [ logit[-self.causal_size:] for logit in logits ]
@@ -1439,22 +1477,22 @@ class Base(nn.Module):
 		# (NAR) disable stop token
 		if quant_levels is not None and "ar" in self.capabilities:
-			logits = [ ban_tokens(logit, tokens=[self.stop_token]) for logit, l in zip( logits, map(len, resps_list) ) ]
+			logits = [ ban_tokens(logit, tokens=[self.stop_token]) for logit, l in zip( logits, map(len, prev_list) ) ]
 		# (AR-len) disable extraneous tokens
 		if quant_levels is None and "len" in self.capabilities:
-			logits = [ ban_tokens(logit, tokens=[*range(11, logit.shape[-1])]) for logit, l in zip( logits, map(len, resps_list) ) ]
+			logits = [ ban_tokens(logit, tokens=[*range(11, logit.shape[-1])]) for logit, l in zip( logits, map(len, prev_list) ) ]
 
 		# argmax instead
 		if temperature <= 0.0:
 			return [ logit.argmax(dim=1) for logit in logits ]
 
 		# perform repetition penalizing
-		if "len" not in self.capabilities:
-			logits = [ reptition_penalize(logit, previous=resps[:, -1].tolist(), factor=repetition_penalty, decay=repetition_penalty_decay) for logit, resps in zip( logits, resps_list ) ]
+		if "len" not in self.capabilities and repetition_penalty != 1.0:
+			logits = [ reptition_penalize(logit, previous=resps[:, -1].tolist(), factor=repetition_penalty, decay=repetition_penalty_decay) for logit, resps in zip( logits, prev_list ) ]
 
 		# (AR) perform length penalizing
 		if quant_levels is None and self.causal:
-			logits = [ length_penalize(logit, length=l + 1, factor=length_penalty, token=self.stop_token) for logit, l in zip( logits, map(len, resps_list) ) ]
+			logits = [ length_penalize(logit, length=l + 1, factor=length_penalty, token=self.stop_token) for logit, l in zip( logits, map(len, prev_list) ) ]
 
 		# perform top_k/top_p filtering of our logits
 		if top_k > 0 or top_p < 1.0:
@@ -1469,7 +1507,7 @@ class Base(nn.Module):
 		# do DRY sampling
 		if dry_multiplier > 0.0:
-			logits = [ dry_sampling(logit, previous=resps[:, -1].tolist(), factor=dry_multiplier, base=dry_base, allowed_length=dry_allowed_length) for logit, resps in zip( logits, resps_list ) ]
+			logits = [ dry_sampling(logit, previous=resps[:, -1].tolist(), factor=dry_multiplier, base=dry_base, allowed_length=dry_allowed_length) for logit, resps in zip( logits, prev_list ) ]
 
 		# do mirostat sampling
 		# currently incompatible with beam searching with the way the two are implemented, perhaps a night of brain bashing can make the two work

View File

@@ -183,7 +183,7 @@ class NAR(Base):
 			resps_list = super().sample(
 				logits=logits,
-				resps_list=prev_list,
+				prev_list=prev_list,
 
 				quant_levels=quant_levels,
 				temperature=1.0 if n == 0 else sampling_temperature,