added lots of sampling options (top-k/top-p, repetition penalty, length penalty)
This commit is contained in:
parent
f69aad9c65
commit
14c78bae39
|
@ -17,11 +17,17 @@ def main():
|
||||||
parser.add_argument("--max-ar-steps", type=int, default=6 * 75)
|
parser.add_argument("--max-ar-steps", type=int, default=6 * 75)
|
||||||
parser.add_argument("--ar-temp", type=float, default=1.0)
|
parser.add_argument("--ar-temp", type=float, default=1.0)
|
||||||
parser.add_argument("--nar-temp", type=float, default=1.0)
|
parser.add_argument("--nar-temp", type=float, default=1.0)
|
||||||
|
|
||||||
|
parser.add_argument("--top-p", type=float, default=1.0)
|
||||||
|
parser.add_argument("--top-k", type=int, default=0)
|
||||||
|
parser.add_argument("--repetition-penalty", type=float, default=1.0)
|
||||||
|
parser.add_argument("--length-penalty", type=float, default=0.0)
|
||||||
|
|
||||||
parser.add_argument("--device", default="cuda")
|
parser.add_argument("--device", default="cuda")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
tts = TTS( config=args.yaml, ar_ckpt=args.ar_ckpt, nar_ckpt=args.nar_ckpt, device=args.device )
|
tts = TTS( config=args.yaml, ar_ckpt=args.ar_ckpt, nar_ckpt=args.nar_ckpt, device=args.device )
|
||||||
tts.inference( text=args.text, references=args.references, out_path=args.out_path, max_ar_steps=args.max_ar_steps, ar_temp=args.ar_temp, nar_temp=args.nar_temp )
|
tts.inference( text=args.text, references=args.references, out_path=args.out_path, max_ar_steps=args.max_ar_steps, ar_temp=args.ar_temp, nar_temp=args.nar_temp, top_p=args.top_p, top_k=args.top_k, repetition_penalty=args.repetition_penalty, length_penalty=args.length_penalty )
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
|
|
@ -125,9 +125,9 @@ class TTS():
|
||||||
return res
|
return res
|
||||||
|
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def inference( self, text, references, max_ar_steps=6 * 75, ar_temp=1.0, nar_temp=1.0, out_path=None ):
|
def inference( self, text, references, max_ar_steps=6 * 75, ar_temp=0.95, nar_temp=0.5, top_p=1.0, top_k=0, repetition_penalty=1.0, length_penalty=0.0, out_path=None ):
|
||||||
if out_path is None:
|
if out_path is None:
|
||||||
out_path = f"./data/{text}.wav"
|
out_path = f"./data/{cfg.start_time}.wav"
|
||||||
|
|
||||||
prom = self.encode_audio( references )
|
prom = self.encode_audio( references )
|
||||||
phns = self.encode_text( text )
|
phns = self.encode_text( text )
|
||||||
|
@ -136,9 +136,9 @@ class TTS():
|
||||||
phns = to_device(phns, self.device).to(torch.uint8 if len(self.symmap) < 256 else torch.int16)
|
phns = to_device(phns, self.device).to(torch.uint8 if len(self.symmap) < 256 else torch.int16)
|
||||||
|
|
||||||
with torch.autocast(self.device, dtype=cfg.inference.dtype, enabled=cfg.inference.amp):
|
with torch.autocast(self.device, dtype=cfg.inference.dtype, enabled=cfg.inference.amp):
|
||||||
resps_list = self.ar(text_list=[phns], proms_list=[prom], max_steps=max_ar_steps, sampling_temperature=ar_temp)
|
resps_list = self.ar(text_list=[phns], proms_list=[prom], max_steps=max_ar_steps, sampling_temperature=ar_temp, sampling_top_p=top_p, sampling_top_k=top_k, sampling_repetition_penalty=repetition_penalty, sampling_length_penalty=length_penalty)
|
||||||
resps_list = [r.unsqueeze(-1) for r in resps_list]
|
resps_list = [r.unsqueeze(-1) for r in resps_list]
|
||||||
resps_list = self.nar(text_list=[phns], proms_list=[prom], resps_list=resps_list, sampling_temperature=nar_temp)
|
resps_list = self.nar(text_list=[phns], proms_list=[prom], resps_list=resps_list, sampling_temperature=nar_temp, sampling_top_p=top_p, sampling_top_k=top_k, sampling_repetition_penalty=repetition_penalty, sampling_length_penalty=length_penalty)
|
||||||
|
|
||||||
wav, sr = qnt.decode_to_file(resps_list[0], out_path)
|
wav, sr = qnt.decode_to_file(resps_list[0], out_path)
|
||||||
|
|
||||||
|
|
|
@ -86,6 +86,10 @@ class AR(Base):
|
||||||
resps_list: list[Tensor] | None = None,
|
resps_list: list[Tensor] | None = None,
|
||||||
max_steps: int = 1000,
|
max_steps: int = 1000,
|
||||||
sampling_temperature: float = 1.0,
|
sampling_temperature: float = 1.0,
|
||||||
|
sampling_top_k: int = -100,
|
||||||
|
sampling_top_p: float = 1.0,
|
||||||
|
sampling_repetition_penalty: float = 1.0,
|
||||||
|
sampling_length_penalty: float = 0.0,
|
||||||
):
|
):
|
||||||
if resps_list is not None:
|
if resps_list is not None:
|
||||||
if self.interleave:
|
if self.interleave:
|
||||||
|
@ -121,6 +125,10 @@ class AR(Base):
|
||||||
resps_list=self._unsqueeze_list(resps_list),
|
resps_list=self._unsqueeze_list(resps_list),
|
||||||
quant_levels=None,
|
quant_levels=None,
|
||||||
sampling_temperature=sampling_temperature,
|
sampling_temperature=sampling_temperature,
|
||||||
|
sampling_top_p=sampling_top_p,
|
||||||
|
sampling_top_k=sampling_top_k,
|
||||||
|
sampling_repetition_penalty=sampling_repetition_penalty,
|
||||||
|
sampling_length_penalty=sampling_length_penalty,
|
||||||
state=state
|
state=state
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -193,9 +201,9 @@ def example_usage():
|
||||||
"""
|
"""
|
||||||
|
|
||||||
model = AR(**kwargs).to(device)
|
model = AR(**kwargs).to(device)
|
||||||
|
steps = 500
|
||||||
optimizer = ml.Prodigy(model.parameters(), lr=1.0)
|
optimizer = ml.Prodigy(model.parameters(), lr=1.0)
|
||||||
engine = Engine(model=model, optimizer=optimizer)
|
engine = Engine(model=model, optimizer=optimizer)
|
||||||
steps = 500
|
|
||||||
|
|
||||||
def sample( name, steps=600 ):
|
def sample( name, steps=600 ):
|
||||||
engine.eval()
|
engine.eval()
|
||||||
|
|
|
@ -71,6 +71,10 @@ class AR_NAR(Base):
|
||||||
resps_list: list[Tensor] | None = None,
|
resps_list: list[Tensor] | None = None,
|
||||||
max_steps: int = 1000,
|
max_steps: int = 1000,
|
||||||
sampling_temperature: float = 0.0,
|
sampling_temperature: float = 0.0,
|
||||||
|
sampling_top_k: int = -100,
|
||||||
|
sampling_top_p: float = 1.0,
|
||||||
|
sampling_repetition_penalty: float = 1.0,
|
||||||
|
sampling_length_penalty: float = 0.0,
|
||||||
):
|
):
|
||||||
device = text_list[0].device
|
device = text_list[0].device
|
||||||
batch_size = len(text_list)
|
batch_size = len(text_list)
|
||||||
|
@ -120,6 +124,10 @@ class AR_NAR(Base):
|
||||||
prev_list,
|
prev_list,
|
||||||
quant_levels=quant_levels,
|
quant_levels=quant_levels,
|
||||||
sampling_temperature=sampling_temperature,
|
sampling_temperature=sampling_temperature,
|
||||||
|
sampling_top_p=sampling_top_p,
|
||||||
|
sampling_top_k=sampling_top_k,
|
||||||
|
sampling_repetition_penalty=sampling_repetition_penalty,
|
||||||
|
sampling_length_penalty=sampling_length_penalty,
|
||||||
)
|
)
|
||||||
|
|
||||||
prev_list = [
|
prev_list = [
|
||||||
|
@ -146,6 +154,10 @@ class AR_NAR(Base):
|
||||||
proms_list,
|
proms_list,
|
||||||
self._unsqueeze_list(resps_list),
|
self._unsqueeze_list(resps_list),
|
||||||
sampling_temperature=sampling_temperature,
|
sampling_temperature=sampling_temperature,
|
||||||
|
sampling_top_p=sampling_top_p,
|
||||||
|
sampling_top_k=sampling_top_k,
|
||||||
|
sampling_repetition_penalty=sampling_repetition_penalty,
|
||||||
|
sampling_length_penalty=sampling_length_penalty,
|
||||||
state=state
|
state=state
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -221,6 +233,7 @@ def example_usage():
|
||||||
|
|
||||||
print(f"AR+NAR parameter count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
|
print(f"AR+NAR parameter count: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
|
||||||
|
|
||||||
|
@torch.inference_mode()
|
||||||
def sample( name, steps=600 ):
|
def sample( name, steps=600 ):
|
||||||
engine.eval()
|
engine.eval()
|
||||||
resps_list = engine(text_list, proms_list, max_steps=steps, sampling_temperature=0.95 )
|
resps_list = engine(text_list, proms_list, max_steps=steps, sampling_temperature=0.95 )
|
||||||
|
@ -245,7 +258,7 @@ def example_usage():
|
||||||
|
|
||||||
tqdm.write(f"{stats}")
|
tqdm.write(f"{stats}")
|
||||||
|
|
||||||
#sample("init", 75)
|
sample("init", 75)
|
||||||
train()
|
train()
|
||||||
sample("final")
|
sample("final")
|
||||||
|
|
||||||
|
|
|
@ -49,6 +49,63 @@ def list_to_tensor(x_list: list[Tensor], pattern="t b c -> b t c"):
|
||||||
m = m.to(x)
|
m = m.to(x)
|
||||||
return x, m
|
return x, m
|
||||||
|
|
||||||
|
# Simple filter to modify a token's probability if it shows up in the past
|
||||||
|
# To-do: have its effect decay based on distance
|
||||||
|
def reptition_penalize( logits, previous, factor=1.0 ):
    """Penalize tokens that already appeared in the generated sequence.

    The penalty is sign-aware: a non-negative logit is divided by ``factor``
    while a negative logit is multiplied by it, so that ``factor > 1`` always
    lowers the score of a repeated token. (Plain division would move a
    negative logit toward zero and make the token *more* likely.)

    Args:
        logits: tensor indexed as ``logits[:, token]``; modified in place.
        previous: tensor of previously emitted token ids (anything exposing
            ``.tolist()`` that yields integers).
        factor: penalty strength; ``1.0`` disables the filter.

    Returns:
        The same ``logits`` tensor, with the penalty applied.
    """
    if factor == 1.0:
        return logits

    # Each distinct token is penalized exactly once, however often it occurred.
    priors = sorted(set(previous.tolist()))
    if not priors:
        return logits

    # Index with an explicit long tensor: `previous` may use a narrow integer
    # dtype, and uint8 index tensors are interpreted as masks by PyTorch.
    index = torch.tensor(priors, dtype=torch.long, device=logits.device)
    seen = logits[:, index]
    logits[:, index] = torch.where(seen < 0, seen * factor, seen / factor)

    return logits
|
||||||
|
|
||||||
|
# Simple "filter" that modifies the logit for the stop token, based on the sequence length
|
||||||
|
def length_penalize( logits, length, factor=0.0, token=-1 ):
    """Rescale one token's logit according to the current sequence length.

    The column ``logits[:, token]`` is divided in place by
    ``length ** factor``. A ``factor`` of ``0.0`` leaves ``logits`` untouched.

    Args:
        logits: tensor indexed as ``logits[:, token]``; modified in place.
        length: current sequence length used as the base of the divisor.
        factor: exponent applied to ``length``; ``0.0`` disables the filter.
        token: index of the logit to rescale (defaults to the last column).

    Returns:
        The same ``logits`` tensor.
    """
    if factor != 0.0:
        divisor = length ** factor
        logits[:, token] /= divisor
    return logits
|
||||||
|
|
||||||
|
# Credit to https://github.com/microsoft/unilm/blob/master/xtune/src/transformers/modeling_utils.py#L1145 / https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
|
||||||
|
def top_k_top_p_filtering( logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens=1 ):
    """Mask out unlikely tokens with top-k and/or nucleus (top-p) filtering.

    ``logits`` is edited in place and also returned.

    Args:
        logits: logits of shape (batch size, vocabulary size).
        top_k: when > 0, only the ``top_k`` highest logits per row survive.
        top_p: when < 1.0, keep the smallest set of highest-probability tokens
            whose cumulative probability reaches ``top_p``
            (Holtzman et al., http://arxiv.org/abs/1904.09751).
        filter_value: value written over every removed logit.
        min_tokens: lower bound on the number of tokens kept per row.

    Returns:
        The filtered ``logits`` tensor.
    """
    if top_k > 0:
        # Clamp k so it is at least min_tokens and at most the vocabulary size.
        k = min(max(top_k, min_tokens), logits.size(-1))
        # Everything strictly below the k-th best logit of its row is dropped.
        kth_best = torch.topk(logits, k)[0][..., -1, None]
        logits[logits < kth_best] = filter_value

    if top_p < 1.0:
        ranked_logits, ranked_indices = torch.sort(logits, descending=True)
        cumulative = torch.cumsum(F.softmax(ranked_logits, dim=-1), dim=-1)

        # Flag ranked positions whose cumulative probability exceeds top_p.
        drop_ranked = cumulative > top_p
        if min_tokens > 1:
            # Never drop the leading min_tokens ranked positions.
            drop_ranked[..., :min_tokens] = 0
        # Shift right by one so the first token crossing the threshold is kept.
        drop_ranked[..., 1:] = drop_ranked[..., :-1].clone()
        drop_ranked[..., 0] = 0

        # Map the mask from ranked order back to vocabulary order.
        drop = drop_ranked.scatter(1, ranked_indices, drop_ranked)
        logits[drop] = filter_value

    return logits
|
||||||
|
|
||||||
|
|
||||||
|
# automagically parses a batch-list and returns it as a list
|
||||||
class Embedding(nn.Embedding):
|
class Embedding(nn.Embedding):
|
||||||
def forward(self, x_list: list[Tensor]) -> list[Tensor]:
|
def forward(self, x_list: list[Tensor]) -> list[Tensor]:
|
||||||
if len(x_list) == 0:
|
if len(x_list) == 0:
|
||||||
|
@ -56,7 +113,6 @@ class Embedding(nn.Embedding):
|
||||||
|
|
||||||
return super().forward(torch.cat(x_list)).split([*map(len, x_list)])
|
return super().forward(torch.cat(x_list)).split([*map(len, x_list)])
|
||||||
|
|
||||||
|
|
||||||
class MultiEmbedding(nn.Embedding):
|
class MultiEmbedding(nn.Embedding):
|
||||||
"""
|
"""
|
||||||
This embedding sums embeddings on different levels.
|
This embedding sums embeddings on different levels.
|
||||||
|
@ -257,30 +313,6 @@ class Base(nn.Module):
|
||||||
ignore_index=self.ignore_index,
|
ignore_index=self.ignore_index,
|
||||||
)
|
)
|
||||||
|
|
||||||
@overload
|
|
||||||
def forward(
|
|
||||||
self,
|
|
||||||
text_list: list[Tensor],
|
|
||||||
proms_list: list[Tensor],
|
|
||||||
resps_list: list[Tensor],
|
|
||||||
targ_list: list[Tensor] | None = None,
|
|
||||||
quant_levels: Tensor | None = None,
|
|
||||||
sampling_temperature: float = 1.0,
|
|
||||||
) -> Tensor:
|
|
||||||
...
|
|
||||||
|
|
||||||
@overload
|
|
||||||
def forward(
|
|
||||||
self,
|
|
||||||
text_list: list[Tensor],
|
|
||||||
proms_list: list[Tensor],
|
|
||||||
resps_list: list[Tensor],
|
|
||||||
targ_list: list[Tensor] | None = None,
|
|
||||||
quant_levels: Tensor | None = None,
|
|
||||||
sampling_temperature: float = 1.0,
|
|
||||||
) -> list[Tensor]:
|
|
||||||
...
|
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
text_list: list[Tensor],
|
text_list: list[Tensor],
|
||||||
|
@ -290,6 +322,10 @@ class Base(nn.Module):
|
||||||
|
|
||||||
quant_levels: Tensor | None = None,
|
quant_levels: Tensor | None = None,
|
||||||
sampling_temperature: float = 1.0,
|
sampling_temperature: float = 1.0,
|
||||||
|
sampling_top_k: int = -100,
|
||||||
|
sampling_top_p: float = 1.0,
|
||||||
|
sampling_repetition_penalty: float = 1.0,
|
||||||
|
sampling_length_penalty: float = 0.0,
|
||||||
|
|
||||||
state: dict | None = None,
|
state: dict | None = None,
|
||||||
):
|
):
|
||||||
|
@ -330,13 +366,10 @@ class Base(nn.Module):
|
||||||
x = self.classifier(x) * m
|
x = self.classifier(x) * m
|
||||||
|
|
||||||
# Remove padding
|
# Remove padding
|
||||||
h_list = [hi[:li] for hi, li in zip(x, map(len, x_list))]
|
logits = [ hi[:li] for hi, li in zip(x, map(len, x_list)) ]
|
||||||
|
|
||||||
# compute loss if the target is given
|
# compute loss if the target is given
|
||||||
if targ_list is not None:
|
if targ_list is not None:
|
||||||
if any([l == 0 for l in map(len, targ_list)]):
|
|
||||||
raise ValueError("Cannot compute loss given empty targ_list.")
|
|
||||||
|
|
||||||
ignore_sep = torch.tensor(self.ignore_index, device=device)
|
ignore_sep = torch.tensor(self.ignore_index, device=device)
|
||||||
|
|
||||||
# create a tensor sequence with one RVQ-bin of the input prompt, filled with `ignore_index`, as the prompt is not needed when computing the loss
|
# create a tensor sequence with one RVQ-bin of the input prompt, filled with `ignore_index`, as the prompt is not needed when computing the loss
|
||||||
|
@ -365,36 +398,49 @@ class Base(nn.Module):
|
||||||
targ_list[i][-1] = self.stop_token
|
targ_list[i][-1] = self.stop_token
|
||||||
|
|
||||||
# create the new target sequence to compute the loss against
|
# create the new target sequence to compute the loss against
|
||||||
y_list = self._samplewise_merge_tensors( text_prom_list, targ_list, sep=ignore_sep )
|
target = torch.cat( self._samplewise_merge_tensors( text_prom_list, targ_list, sep=ignore_sep ) )
|
||||||
|
inputs = torch.cat( logits )
|
||||||
|
|
||||||
self.loss = dict(
|
self.loss = dict(
|
||||||
# "nll" was in the original implementation and should actually just be called something else
|
# "nll" was in the original implementation and should actually just be called something else
|
||||||
nll=F.cross_entropy(
|
nll = F.cross_entropy( inputs, target, ignore_index=self.ignore_index )
|
||||||
torch.cat(h_list), # input / predicted logits
|
|
||||||
torch.cat(y_list), # target / ground truth
|
|
||||||
ignore_index=self.ignore_index,
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
self.stats = dict(
|
self.stats = dict(
|
||||||
acc = self.accuracy_metric( torch.cat(h_list), torch.cat(y_list) ),
|
acc = self.accuracy_metric( inputs, target ),
|
||||||
precision = self.precision_metric( torch.cat(h_list), torch.cat(y_list) ),
|
precision = self.precision_metric( inputs, target ),
|
||||||
)
|
)
|
||||||
|
|
||||||
# return the entire generated token string
|
return logits
|
||||||
return_all = False
|
|
||||||
if return_all:
|
|
||||||
logits = [hi[:] for hi, li in zip(h_list, map(len, resps_list))]
|
|
||||||
# return the entire generated response
|
|
||||||
elif quant_levels is not None:
|
|
||||||
logits = [hi[-li:] for hi, li in zip(h_list, map(len, resps_list))]
|
|
||||||
# return the last chunkwise piece
|
|
||||||
elif self.causal and self.recurrent_chunk_size > 0:
|
|
||||||
logits = [hi[-self.recurrent_chunk_size:] for hi, li in zip(h_list, map(len, resps_list))]
|
|
||||||
# return just the last code
|
|
||||||
else:
|
|
||||||
logits = [ hi[-1:] for hi in h_list ]
|
|
||||||
|
|
||||||
return [ Categorical(logits=hi / sampling_temperature).sample() for hi in logits ]
|
|
||||||
|
|
||||||
|
# (NAR) return the entire generated response
|
||||||
|
if quant_levels is not None:
|
||||||
|
logits = [ logit[-l:] for logit, l in zip(logits, map(len, resps_list)) ]
|
||||||
|
# (AR chunkwise) return the last chunkwise piece
|
||||||
|
elif self.causal and self.recurrent_chunk_size > 0:
|
||||||
|
logits = [ logit[-l:] for logit, l in zip(logits, self.recurrent_chunk_size) ]
|
||||||
|
# (AR) return just the last code
|
||||||
|
else:
|
||||||
|
logits = [ logit[-1:] for logit in logits ]
|
||||||
|
|
||||||
|
# perform repetition penalizing
|
||||||
|
logits = [ reptition_penalize(logit, previous=resps[:, 0], factor=sampling_repetition_penalty) for logit, resps in zip( logits, resps_list ) ]
|
||||||
|
|
||||||
|
# perform length penalizing
|
||||||
|
if quant_levels is None and self.causal:
|
||||||
|
logits = [ length_penalize(logit, length=l + 1, factor=sampling_length_penalty, token=self.stop_token) for logit, l in zip( logits, map(len, resps_list) ) ]
|
||||||
|
|
||||||
|
# scale our logits by the temp
|
||||||
|
logits = [ logit / sampling_temperature for logit in logits ]
|
||||||
|
|
||||||
|
# perform top_k/top_p filtering of our logits
|
||||||
|
if sampling_top_k > 0:
|
||||||
|
logits = [ top_k_top_p_filtering(logit, top_k=sampling_top_k, top_p=sampling_top_p) for logit in logits ]
|
||||||
|
|
||||||
|
# and sample
|
||||||
|
# the original implementation used this instead of argmax; it's probably placebo but it performs better than argmax
|
||||||
|
return [ Categorical(logits=logit).sample() for logit in logits ]
|
||||||
|
|
||||||
def example_usage():
|
def example_usage():
|
||||||
from ..config import cfg
|
from ..config import cfg
|
||||||
|
|
|
@ -57,6 +57,10 @@ class NAR(Base):
|
||||||
proms_list: list[Tensor],
|
proms_list: list[Tensor],
|
||||||
resps_list: list[Tensor],
|
resps_list: list[Tensor],
|
||||||
sampling_temperature: float = 0.2,
|
sampling_temperature: float = 0.2,
|
||||||
|
sampling_top_k: int = -100,
|
||||||
|
sampling_top_p: float = 1.0,
|
||||||
|
sampling_repetition_penalty: float = 1.0,
|
||||||
|
sampling_length_penalty: float = 0.0,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
|
@ -112,6 +116,10 @@ class NAR(Base):
|
||||||
prev_list,
|
prev_list,
|
||||||
quant_levels=quant_levels,
|
quant_levels=quant_levels,
|
||||||
sampling_temperature=sampling_temperature,
|
sampling_temperature=sampling_temperature,
|
||||||
|
sampling_top_p=sampling_top_p,
|
||||||
|
sampling_top_k=sampling_top_k,
|
||||||
|
sampling_repetition_penalty=sampling_repetition_penalty,
|
||||||
|
sampling_length_penalty=sampling_length_penalty,
|
||||||
)
|
)
|
||||||
|
|
||||||
prev_list = [
|
prev_list = [
|
||||||
|
|
Loading…
Reference in New Issue
Block a user