added lots of sampling options (top-k/top-p, repetition penalty, length penalty)

mrq 2023-09-08 20:30:54 -05:00
parent f69aad9c65
commit 14c78bae39
6 changed files with 139 additions and 58 deletions

View File

@@ -17,11 +17,17 @@ def main():
parser.add_argument("--max-ar-steps", type=int, default=6 * 75)
parser.add_argument("--ar-temp", type=float, default=1.0)
parser.add_argument("--nar-temp", type=float, default=1.0)
parser.add_argument("--top-p", type=float, default=1.0)
parser.add_argument("--top-k", type=int, default=0)
parser.add_argument("--repetition-penalty", type=float, default=1.0)
parser.add_argument("--length-penalty", type=float, default=0.0)
parser.add_argument("--device", default="cuda")
args = parser.parse_args()
tts = TTS( config=args.yaml, ar_ckpt=args.ar_ckpt, nar_ckpt=args.nar_ckpt, device=args.device )
tts.inference( text=args.text, references=args.references, out_path=args.out_path, max_ar_steps=args.max_ar_steps, ar_temp=args.ar_temp, nar_temp=args.nar_temp )
tts.inference( text=args.text, references=args.references, out_path=args.out_path, max_ar_steps=args.max_ar_steps, ar_temp=args.ar_temp, nar_temp=args.nar_temp, top_p=args.top_p, top_k=args.top_k, repetition_penalty=args.repetition_penalty, length_penalty=args.length_penalty )
if __name__ == "__main__":
main()
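Worth noting: every new flag defaults to a value that leaves sampling unchanged (top_p=1.0, top_k=0, repetition_penalty=1.0, length_penalty=0.0), so existing command lines keep their old behavior. A minimal sketch of just the new flags in isolation (argument names copied from the parser above, everything else omitted):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--top-p", type=float, default=1.0)
parser.add_argument("--top-k", type=int, default=0)
parser.add_argument("--repetition-penalty", type=float, default=1.0)
parser.add_argument("--length-penalty", type=float, default=0.0)

# defaults are no-ops; only explicitly passed flags change sampling
args = parser.parse_args(["--top-k", "64", "--top-p", "0.9"])
print(args.top_k, args.top_p, args.repetition_penalty, args.length_penalty)  # 64 0.9 1.0 1.0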

View File

@@ -125,9 +125,9 @@ class TTS():
return res
@torch.inference_mode()
def inference( self, text, references, max_ar_steps=6 * 75, ar_temp=1.0, nar_temp=1.0, out_path=None ):
def inference( self, text, references, max_ar_steps=6 * 75, ar_temp=0.95, nar_temp=0.5, top_p=1.0, top_k=0, repetition_penalty=1.0, length_penalty=0.0, out_path=None ):
if out_path is None:
out_path = f"./data/{text}.wav"
out_path = f"./data/{cfg.start_time}.wav"
prom = self.encode_audio( references )
phns = self.encode_text( text )
@@ -136,9 +136,9 @@ class TTS():
phns = to_device(phns, self.device).to(torch.uint8 if len(self.symmap) < 256 else torch.int16)
with torch.autocast(self.device, dtype=cfg.inference.dtype, enabled=cfg.inference.amp):
resps_list = self.ar(text_list=[phns], proms_list=[prom], max_steps=max_ar_steps, sampling_temperature=ar_temp)
resps_list = self.ar(text_list=[phns], proms_list=[prom], max_steps=max_ar_steps, sampling_temperature=ar_temp, sampling_top_p=top_p, sampling_top_k=top_k, sampling_repetition_penalty=repetition_penalty, sampling_length_penalty=length_penalty)
resps_list = [r.unsqueeze(-1) for r in resps_list]
resps_list = self.nar(text_list=[phns], proms_list=[prom], resps_list=resps_list, sampling_temperature=nar_temp)
resps_list = self.nar(text_list=[phns], proms_list=[prom], resps_list=resps_list, sampling_temperature=nar_temp, sampling_top_p=top_p, sampling_top_k=top_k, sampling_repetition_penalty=repetition_penalty, sampling_length_penalty=length_penalty)
wav, sr = qnt.decode_to_file(resps_list[0], out_path)
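For context, a sketch of calling the updated inference signature; the import path, config path, checkpoints, and reference audio below are placeholders, not values from this commit:

from vall_e.inference import TTS  # assumed module path

tts = TTS( config="./data/config.yaml", ar_ckpt="./data/ar.pt", nar_ckpt="./data/nar.pt", device="cuda" )
tts.inference(
    text="Hello world.",
    references=["./data/reference.wav"],
    out_path="./data/output.wav",
    max_ar_steps=6 * 75,
    ar_temp=0.95, nar_temp=0.5,      # new defaults introduced by this commit
    top_k=64, top_p=0.9,             # top-k / nucleus filtering
    repetition_penalty=1.25,         # > 1.0 discourages tokens already emitted
    length_penalty=0.5,              # > 0.0 scales the stop-token logit with sequence length
)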

View File

@@ -86,6 +86,10 @@ class AR(Base):
resps_list: list[Tensor] | None = None,
max_steps: int = 1000,
sampling_temperature: float = 1.0,
sampling_top_k: int = -100,
sampling_top_p: float = 1.0,
sampling_repetition_penalty: float = 1.0,
sampling_length_penalty: float = 0.0,
):
if resps_list is not None:
if self.interleave:
@@ -121,6 +125,10 @@ class AR(Base):
resps_list=self._unsqueeze_list(resps_list),
quant_levels=None,
sampling_temperature=sampling_temperature,
sampling_top_p=sampling_top_p,
sampling_top_k=sampling_top_k,
sampling_repetition_penalty=sampling_repetition_penalty,
sampling_length_penalty=sampling_length_penalty,
state=state
)
@@ -193,9 +201,9 @@ def example_usage():
"""
model = AR(**kwargs).to(device)
steps = 500
optimizer = ml.Prodigy(model.parameters(), lr=1.0)
engine = Engine(model=model, optimizer=optimizer)
steps = 500
def sample( name, steps=600 ):
engine.eval()

View File

@@ -71,6 +71,10 @@ class AR_NAR(Base):
resps_list: list[Tensor] | None = None,
max_steps: int = 1000,
sampling_temperature: float = 0.0,
sampling_top_k: int = -100,
sampling_top_p: float = 1.0,
sampling_repetition_penalty: float = 1.0,
sampling_length_penalty: float = 0.0,
):
device = text_list[0].device
batch_size = len(text_list)
@@ -120,6 +124,10 @@ class AR_NAR(Base):
prev_list,
quant_levels=quant_levels,
sampling_temperature=sampling_temperature,
sampling_top_p=sampling_top_p,
sampling_top_k=sampling_top_k,
sampling_repetition_penalty=sampling_repetition_penalty,
sampling_length_penalty=sampling_length_penalty,
)
prev_list = [
@@ -146,6 +154,10 @@ class AR_NAR(Base):
proms_list,
self._unsqueeze_list(resps_list),
sampling_temperature=sampling_temperature,
sampling_top_p=sampling_top_p,
sampling_top_k=sampling_top_k,
sampling_repetition_penalty=sampling_repetition_penalty,
sampling_length_penalty=sampling_length_penalty,
state=state
)
@@ -221,6 +233,7 @@ def example_usage():
print(f"AR+NAR parameter count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
@torch.inference_mode()
def sample( name, steps=600 ):
engine.eval()
resps_list = engine(text_list, proms_list, max_steps=steps, sampling_temperature=0.95 )
@@ -245,7 +258,7 @@ def example_usage():
tqdm.write(f"{stats}")
#sample("init", 75)
sample("init", 75)
train()
sample("final")

View File

@@ -49,6 +49,63 @@ def list_to_tensor(x_list: list[Tensor], pattern="t b c -> b t c"):
m = m.to(x)
return x, m
# Simple filter to modify a token's probability if it shows up in the past
# To-do: have its effect decay based on distance
def repetition_penalize( logits, previous, factor=1.0 ):
if factor == 1.0:
return logits
priors = set(previous.tolist())
for token in priors:
logits[:, token] /= factor
return logits
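A quick illustration of the penalty on dummy logits of shape (1, vocab); note that dividing by a factor > 1 pushes negative logits toward zero, so this naive form can make an already-unlikely repeated token slightly more likely, a known caveat of logit-division penalties:

import torch

logits = torch.tensor([[ 2.0, 1.0, -0.5 ]])
previous = torch.tensor([0, 0, 2])                       # tokens 0 and 2 were generated earlier
out = repetition_penalize(logits.clone(), previous, factor=1.5)
# token 0: 2.0 / 1.5 ≈ 1.33 (less likely), token 2: -0.5 / 1.5 ≈ -0.33 (slightly more likely), token 1 untouched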
# Simple "filter" that modifies the logit for the stop token, based on the sequence length
def length_penalize( logits, length, factor=0.0, token=-1 ):
if factor == 0.0:
return logits
logits[:, token] /= (length ** factor)
return logits
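And a small example of the length adjustment: with a negative stop-token logit, a positive factor divides it by length ** factor, pulling it toward zero and so making termination more likely as the sequence grows. The token index below is a placeholder:

import torch

stop_token = 2
logits = torch.tensor([[ 1.0, 0.5, -4.0 ]])              # stop token starts out unlikely
out = length_penalize(logits.clone(), length=10, factor=1.0, token=stop_token)
# stop-token logit: -4.0 / 10**1.0 = -0.4, i.e. much closer to zero than before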
# Credit to https://github.com/microsoft/unilm/blob/master/xtune/src/transformers/modeling_utils.py#L1145 / https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
def top_k_top_p_filtering( logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens=1 ):
"""Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
Args:
logits: logits distribution shape (batch size, vocabulary size)
if top_k > 0: keep only top k tokens with highest probability (top-k filtering).
if top_p < 1.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
Make sure we keep at least min_tokens per batch example in the output
"""
if top_k > 0:
top_k = min(max(top_k, min_tokens), logits.size(-1)) # Safety check
# Remove all tokens with a probability less than the last token of the top-k
indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
logits[indices_to_remove] = filter_value
if top_p < 1.0:
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
# Remove tokens with cumulative probability above the threshold (tokens marked 0 are kept)
sorted_indices_to_remove = cumulative_probs > top_p
if min_tokens > 1:
# Keep at least min_tokens (set to min_tokens-1 because we add the first one below)
sorted_indices_to_remove[..., :min_tokens] = 0
# Shift the indices to the right to keep also the first token above the threshold
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
sorted_indices_to_remove[..., 0] = 0
# scatter sorted tensors to original indexing
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
logits[indices_to_remove] = filter_value
return logits
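A short check of the combined filter: top-k first masks everything outside the k highest logits, then nucleus filtering drops tokens past the top_p cumulative-probability cutoff, always keeping the first token that crosses it:

import torch
import torch.nn.functional as F

logits = torch.tensor([[ 2.0, 1.0, 0.5, -1.0 ]])
filtered = top_k_top_p_filtering(logits.clone(), top_k=2, top_p=0.9)
probs = F.softmax(filtered, dim=-1)
# only tokens 0 and 1 survive; the masked tokens end up with probability 0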
# automagically parses a batch-list and returns it as a list
class Embedding(nn.Embedding):
def forward(self, x_list: list[Tensor]) -> list[Tensor]:
if len(x_list) == 0:
@@ -56,7 +113,6 @@ class Embedding(nn.Embedding):
return super().forward(torch.cat(x_list)).split([*map(len, x_list)])
class MultiEmbedding(nn.Embedding):
"""
This embedding sums embeddings on different levels.
@@ -257,30 +313,6 @@ class Base(nn.Module):
ignore_index=self.ignore_index,
)
@overload
def forward(
self,
text_list: list[Tensor],
proms_list: list[Tensor],
resps_list: list[Tensor],
targ_list: list[Tensor] | None = None,
quant_levels: Tensor | None = None,
sampling_temperature: float = 1.0,
) -> Tensor:
...
@overload
def forward(
self,
text_list: list[Tensor],
proms_list: list[Tensor],
resps_list: list[Tensor],
targ_list: list[Tensor] | None = None,
quant_levels: Tensor | None = None,
sampling_temperature: float = 1.0,
) -> list[Tensor]:
...
def forward(
self,
text_list: list[Tensor],
@@ -290,6 +322,10 @@ class Base(nn.Module):
quant_levels: Tensor | None = None,
sampling_temperature: float = 1.0,
sampling_top_k: int = -100,
sampling_top_p: float = 1.0,
sampling_repetition_penalty: float = 1.0,
sampling_length_penalty: float = 0.0,
state: dict | None = None,
):
@@ -330,13 +366,10 @@ class Base(nn.Module):
x = self.classifier(x) * m
# Remove padding
h_list = [hi[:li] for hi, li in zip(x, map(len, x_list))]
logits = [ hi[:li] for hi, li in zip(x, map(len, x_list)) ]
# compute loss if the target is given
if targ_list is not None:
if any([l == 0 for l in map(len, targ_list)]):
raise ValueError("Cannot compute loss given empty targ_list.")
ignore_sep = torch.tensor(self.ignore_index, device=device)
# create a tensor sequence with one RVQ-bin of the input prompt, but filled with `ignore_index`, as the prompt is not needed when computing the loss
@@ -365,36 +398,49 @@ class Base(nn.Module):
targ_list[i][-1] = self.stop_token
# create the new target sequence to compute the loss against
y_list = self._samplewise_merge_tensors( text_prom_list, targ_list, sep=ignore_sep )
target = torch.cat( self._samplewise_merge_tensors( text_prom_list, targ_list, sep=ignore_sep ) )
inputs = torch.cat( logits )
self.loss = dict(
# "nll" was in the original implementation and should actually just be called something else
nll=F.cross_entropy(
torch.cat(h_list), # input / predicted logits
torch.cat(y_list), # target / ground truth
ignore_index=self.ignore_index,
)
nll = F.cross_entropy( inputs, target, ignore_index=self.ignore_index )
)
self.stats = dict(
acc = self.accuracy_metric( torch.cat(h_list), torch.cat(y_list) ),
precision = self.precision_metric( torch.cat(h_list), torch.cat(y_list) ),
acc = self.accuracy_metric( inputs, target ),
precision = self.precision_metric( inputs, target ),
)
return logits
# return the entire generated token string
return_all = False
if return_all:
logits = [hi[:] for hi, li in zip(h_list, map(len, resps_list))]
# return the entire generated response
elif quant_levels is not None:
logits = [hi[-li:] for hi, li in zip(h_list, map(len, resps_list))]
# return the last chunkwise piece
elif self.causal and self.recurrent_chunk_size > 0:
logits = [hi[-self.recurrent_chunk_size:] for hi, li in zip(h_list, map(len, resps_list))]
# return just the last code
else:
logits = [ hi[-1:] for hi in h_list ]
return [ Categorical(logits=hi / sampling_temperature).sample() for hi in logits ]
# (NAR) return the entire generated response
if quant_levels is not None:
logits = [ logit[-l:] for logit, l in zip(logits, map(len, resps_list)) ]
# (AR chunkwise) return the last chunkwise piece
elif self.causal and self.recurrent_chunk_size > 0:
logits = [ logit[-self.recurrent_chunk_size:] for logit in logits ]
# (AR) return just the last code
else:
logits = [ logit[-1:] for logit in logits ]
# perform repetition penalizing
logits = [ repetition_penalize(logit, previous=resps[:, 0], factor=sampling_repetition_penalty) for logit, resps in zip( logits, resps_list ) ]
# perform length penalizing
if quant_levels is None and self.causal:
logits = [ length_penalize(logit, length=l + 1, factor=sampling_length_penalty, token=self.stop_token) for logit, l in zip( logits, map(len, resps_list) ) ]
# scale our logits by the temp
logits = [ logit / sampling_temperature for logit in logits ]
# perform top_k/top_p filtering of our logits
if sampling_top_k > 0:
logits = [ top_k_top_p_filtering(logit, top_k=sampling_top_k, top_p=sampling_top_p) for logit in logits ]
# and sample
# the original implementation sampled from the distribution instead of taking the argmax; any benefit may be placebo, but it seems to perform better in practice
return [ Categorical(logits=logit).sample() for logit in logits ]
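Putting it together, the sampling path now runs: repetition penalty over codes already emitted, a length-based adjustment of the stop-token logit (AR only), temperature scaling, top-k/top-p filtering, and finally a draw from the categorical distribution. Note the filter above is gated on sampling_top_k > 0, so top_p alone has no effect unless top_k is also set. A standalone sketch of the same ordering, reusing the helpers from this file; the wrapper name and stop-token value are placeholders:

import torch
from torch.distributions import Categorical

def sample_last_token( logit, resps, temperature=0.95, top_k=0, top_p=1.0,
                       repetition_penalty=1.0, length_penalty=0.0, stop_token=1024 ):
    # logit: (1, vocab) for the last position; resps: (t, levels) codes emitted so far
    logit = repetition_penalize( logit, previous=resps[:, 0], factor=repetition_penalty )
    logit = length_penalize( logit, length=resps.shape[0] + 1, factor=length_penalty, token=stop_token )
    logit = logit / temperature
    if top_k > 0:  # mirrors the gating above: top_p only applies when top_k is set
        logit = top_k_top_p_filtering( logit, top_k=top_k, top_p=top_p )
    return Categorical( logits=logit ).sample()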
def example_usage():
from ..config import cfg

View File

@@ -57,6 +57,10 @@ class NAR(Base):
proms_list: list[Tensor],
resps_list: list[Tensor],
sampling_temperature: float = 0.2,
sampling_top_k: int = -100,
sampling_top_p: float = 1.0,
sampling_repetition_penalty: float = 1.0,
sampling_length_penalty: float = 0.0,
):
"""
Args:
@@ -112,6 +116,10 @@ class NAR(Base):
prev_list,
quant_levels=quant_levels,
sampling_temperature=sampling_temperature,
sampling_top_p=sampling_top_p,
sampling_top_k=sampling_top_k,
sampling_repetition_penalty=sampling_repetition_penalty,
sampling_length_penalty=sampling_length_penalty,
)
prev_list = [