some cleanup while I wait for the NAR-len to train to an acceptable state (currently it performs okay, but only on audo after 3 seconds or so)

This commit is contained in:
mrq 2024-11-09 12:12:46 -06:00
parent 69b0b3b854
commit dcd5fecff3
2 changed files with 36 additions and 17 deletions

View File

@ -5,9 +5,11 @@ The underlying model is a robust transformer, where:
* the embedded inputs are then passed through each layer of the transformer (or other model type)
* the last hidden states are then passed through the output head / classifier / projection, resulting in logit probabilities to sample from.
The beauty of a transformer, I feel, is that you can easily define any task at it, and it should follow through with it very well.
The inputs are sequenced in a way that the given task requires automatically, and the outputs are handled as per the class that extends the base model.
While the original paper called for a separate AR model and a NAR model, you can actually train a unified model for effectively free, as the internal states of the two should overlap quite a lot.
While the original paper called for a separate AR model and a NAR model, and by treating the AR and the NAR as unique tasks, you can actually train a unified model for effectively free, as the internal states of the two should overlap quite a lot.
## The AR (Autoregressive) Model
@ -49,6 +51,7 @@ However, having a pure NAR is challenging, as you need to both explicitly provid
* The former problem is easily "solved" by training a `len` inferencing task, where the given input predicts the requested duration for a given utterance autoregressively.
* The latter however proves to be challenging, as generating tokens from nothing in one step is not possible.
* diffusion solves this, but requires additional steps at best and a separate model at worse, just for one RVQ level.
* embedding the current timestep is *required*, despite this technically being encoded in how many masked tokens exist within a sequence.
* the normal NAR (RVQ level 1+) does not face this problem, as it's already given a sufficient initial sequence of tokens to work with, and thus only requires one step.
The implemented solution follows a similar paradigm to diffusion, but with masking instead of noise.
@ -69,6 +72,8 @@ The input text phonemes (or output for STT) are passed through an embedding head
Technically, due to how the audio embeddings are implemented, it's possible to offer "language specific" embeddings, rather than one unified IPA-based embedding + a language embedding (`lang`).
* Such an implementation *could* in fact inference from normal text rather than IPA phonemes.
These embeddings *could* instead be added on top of the input prompt embedding instead of serving as additional tasks (similar to injecting position embeddings), but additional experimentation is required to see if the model both can work under this and/or benefits from this.
#### Language Embedding
This embedding provides the requested language for the model to be aware of.
@ -119,6 +124,12 @@ Finally, the model *may* then sum each embedding level back down to one sequence
Additionally, it's *technically* possible to instead use the embeddings from the core model used to encode the audio, but in theory this may exclude specific features the model has encoded within the embeddings.
#### RVQ Level Embedding
This embedding hints what the target RVQ level of the audio codes is being targetted. This embedding is not required, but seems some architectures (Mamba) requires this.
This *may* replace needing separate embeddings for each RVQ level, but experimentation is required to test this claim.
### Tasks
The base model handles processing inputs into token sequences, per the requested task assigned to each input in a batch.

View File

@ -244,9 +244,11 @@ class NAR(Base):
test_artifact = None
# nasty hardcode to load a reference file and have that as the input target
# to-do: expose a way to provide the initial sequence instead through CLI
"""
if False:
path = "./data/237_134500_000036_000004.enc"
path = "./data/00_part2_success-1.enc"
test_artifact = np.load(path, allow_pickle=True)[()]
text_list = [ torch.tensor( cfg.tokenizer.encode( test_artifact["metadata"]["phonemes"] ) ).to(dtype=torch.uint8, device=device) ]
resps_list = [ torch.from_numpy(test_artifact["codes"].astype(np.int16))[0, :, :].t().to(dtype=torch.int16, device=device) ]
@ -255,7 +257,7 @@ class NAR(Base):
"""
_super = super()
def demask_sampling( seq_len, max_steps=10, temperature=0.3 ):
def demask_sampling( seq_len, max_steps=20, temperature=0.3 ):
starting_temperature = temperature
input_ids = torch.ones((seq_len,), dtype=torch.long, device=device) * self.stop_token
@ -264,23 +266,22 @@ class NAR(Base):
quant_levels = [ level for _ in range(batch_size) ]
prev_list = [ input_ids ]
noise_scale = 1.0
start_noise = 0.0
end_noise = 1.0
"""
# use hardcoded reference file to test inference capabilities
if test_artifact is not None:
# because we "set" it later on, it's not implicitly captured
nonlocal resps_list
input = resps_list[0][:, 0]
noise_scale = 1.0
input_ids = torch.tensor( [ self.stop_token if random.random() < noise_scale else token for _, token in enumerate( input ) ], dtype=torch.int16, device=device )
print( input )
print( input_ids )
"""
start_noise = 0.0
noise_p = math.cos( start_noise * math.pi * 0.5 )
input_ids = torch.tensor( [ self.stop_token if random.random() < noise_p else token for _, token in enumerate( resps_list[0][:, 0] ) ], dtype=torch.int16, device=device )
for timestep, steps_until_x0 in zip(torch.linspace(0, 1, max_steps), reversed(range(max_steps))):
for timestep, steps_until_x0 in zip(torch.linspace(start_noise, end_noise, max_steps), reversed(range(max_steps))):
# anneal temperature
temperature = starting_temperature * (steps_until_x0 / max_steps)
# get noise level, per cosine scheduling
noise_p = math.cos( timestep * math.pi * 0.5 ) * noise_scale
noise_p = math.cos( timestep * math.pi * 0.5 )
# number of tokens to mask off to "noise" the input sequence
masked_tokens_n = max(int( noise_p * seq_len ), 1)
# pick the worst scoring tokens to mask off
@ -289,8 +290,7 @@ class NAR(Base):
input_ids = input_ids.scatter(0, masked_indices, self.stop_token)
# boolean mask
is_masked = input_ids == self.stop_token
# sample
sampling_top_k = math.floor( seq_len * 0.9 )
# setup inputs
resps_list = [ input_ids ]
inputs = _super.inputs(
text_list=text_list,
@ -308,6 +308,7 @@ class NAR(Base):
)
# sample with sampler settings
sampling_top_p = 0.9
filtered_sampled = _super.sample(
logits=output.logits,
prev_list=prev_list,
@ -341,14 +342,21 @@ class NAR(Base):
filtered_scores = filtered_sampled.scores[0]
unfiltered_scores = unfiltered_sampled.scores[0]
# extract sampled tokens
filtered_tokens = filtered_sampled[0][0]
unfiltered_tokens = unfiltered_sampled[0][0]
# sample with gumbelnoise
sampled_ids = gumbel_sample( filtered_logits, temperature=temperature, dim=-1 )
# I actually feel like this doesn't matter? it's hard to judge with a partially trained NAR-len model
#sampled_ids = gumbel_sample( filtered_logits, temperature=temperature, dim=-1 )
sampled_ids = filtered_tokens
# keep unmasked tokens
input_ids = torch.where( is_masked, sampled_ids, input_ids )
# update scores (conjugated to put the worst scores at the top)
scores = 1.0 - torch.tensor([score for score in unfiltered_scores], device=device)
# print( timestep, steps_until_x0, noise_p, masked_tokens_n, input_ids, scores )
# print( timestep, steps_until_x0, noise_p, masked_tokens_n, input_ids, scores )
return input_ids