some cleanup while I wait for the NAR-len to train to an acceptable state (currently it performs okay, but only on audio after 3 seconds or so)
This commit is contained in:
parent 69b0b3b854
commit dcd5fecff3
@ -5,9 +5,11 @@ The underlying model is a robust transformer, where:
* the embedded inputs are then passed through each layer of the transformer (or other model type)
* the last hidden states are then passed through the output head / classifier / projection, resulting in logit probabilities to sample from.
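As a minimal sketch of that flow (embed → layers → output head → logits), assuming a toy PyTorch model; `TinyModel`, its dimensions, and the use of `nn.TransformerEncoder` are illustrative stand-ins, not the actual classes used here:

```python
import torch
import torch.nn as nn

class TinyModel(nn.Module):
    # illustrative only: embed -> transformer layers -> projection head -> logits
    def __init__(self, n_tokens=1025, d_model=512, n_layers=4, n_heads=8):
        super().__init__()
        self.embedding = nn.Embedding(n_tokens, d_model)
        layer = nn.TransformerEncoderLayer(d_model, n_heads, batch_first=True)
        self.layers = nn.TransformerEncoder(layer, n_layers)
        self.classifier = nn.Linear(d_model, n_tokens)  # output head / projection

    def forward(self, input_ids):
        x = self.embedding(input_ids)   # embedded inputs
        x = self.layers(x)              # passed through each layer
        logits = self.classifier(x)     # last hidden states -> logits to sample from
        return logits
```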
The beauty of a transformer, I feel, is that you can easily define any task for it, and it should follow through with it very well.

The inputs are sequenced in a way that the given task requires automatically, and the outputs are handled as per the class that extends the base model.

While the original paper called for a separate AR model and a NAR model, you can actually train a unified model for effectively free, as the internal states of the two should overlap quite a lot.
While the original paper called for a separate AR model and a NAR model, by treating the AR and the NAR as unique tasks you can actually train a unified model effectively for free, as the internal states of the two should overlap quite a lot.

## The AR (Autoregressive) Model
@ -49,6 +51,7 @@ However, having a pure NAR is challenging, as you need to both explicitly provid
* The former problem is easily "solved" by training a `len` inferencing task, where the given input predicts the requested duration for a given utterance autoregressively (see the sketch after this list).
* The latter however proves to be challenging, as generating tokens from nothing in one step is not possible.
* diffusion solves this, but requires additional steps at best and a separate model at worst, just for one RVQ level.
* embedding the current timestep is *required*, despite this technically being encoded in how many masked tokens exist within a sequence.
* the normal NAR (RVQ level 1+) does not face this problem, as it's already given a sufficient initial sequence of tokens to work with, and thus only requires one step.
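As referenced in the first bullet above, a purely illustrative sketch of how such a `len` task could be inferenced; the digit-token representation and the `model.sample_token` helper are assumptions for the sake of the example, not this repo's actual API:

```python
def infer_duration(model, text_ids, max_digits=5):
    # hypothetical: autoregressively sample digit tokens until a stop token,
    # then interpret them as the requested output length (in frames)
    digits = []
    for _ in range(max_digits):
        token = model.sample_token(task="len", text=text_ids, prior=digits)  # hypothetical helper
        if token == model.stop_token:
            break
        digits.append(token)
    return int("".join(str(d) for d in digits)) if digits else 0
```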
The implemented solution follows a similar paradigm to diffusion, but with masking instead of noise.
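Concretely, inference starts from a fully masked sequence and iteratively re-masks and re-predicts the lowest-confidence tokens under a cosine schedule, mirroring the `demask_sampling` routine in the code further down. A condensed sketch, where `predict` is a hypothetical stand-in for the model forward pass plus sampling:

```python
import math
import torch

def demask(predict, seq_len, mask_token, max_steps=20, device="cpu"):
    # start from a fully masked sequence
    input_ids = torch.full((seq_len,), mask_token, dtype=torch.long, device=device)
    scores = torch.zeros(seq_len, device=device)  # higher = worse, gets re-masked first

    for timestep in torch.linspace(0, 1, max_steps):
        # cosine schedule: fraction of tokens to keep masked at this step
        noise_p = math.cos(timestep * math.pi * 0.5)
        masked_n = max(int(noise_p * seq_len), 1)
        # re-mask the worst-scoring positions
        masked_indices = scores.topk(masked_n).indices
        input_ids = input_ids.scatter(0, masked_indices, mask_token)
        is_masked = input_ids == mask_token
        # hypothetical helper: returns sampled tokens + per-token confidence scores
        sampled_ids, confidences = predict(input_ids)
        # only overwrite the masked positions; flip confidences so the worst sort first
        input_ids = torch.where(is_masked, sampled_ids, input_ids)
        scores = 1.0 - confidences
    return input_ids
```

Flipping the confidences (`1.0 - confidences`) is what lets `topk` pick the least confident positions for re-masking on the next step.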
@ -69,6 +72,8 @@ The input text phonemes (or output for STT) are passed through an embedding head
Technically, due to how the audio embeddings are implemented, it's possible to offer "language specific" embeddings, rather than one unified IPA-based embedding + a language embedding (`lang`).
* Such an implementation *could* in fact inference from normal text rather than IPA phonemes.

These embeddings *could* be added on top of the input prompt embedding instead of serving as additional tasks (similar to injecting position embeddings), but additional experimentation is required to see if the model can work under this and/or benefits from it.
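A minimal sketch of what "adding on top of the input prompt embedding" could look like, assuming a learned per-language vector; `lang_emb` and `prompt_emb` are placeholders, not the actual modules:

```python
import torch
import torch.nn as nn

n_langs, d_model = 4, 512
lang_emb = nn.Embedding(n_langs, d_model)   # one learned vector per language
prompt_emb = torch.randn(1, 120, d_model)   # stand-in for the embedded input prompt

lang_id = torch.tensor([0])                 # e.g. 0 = "en"
# broadcast-add the language vector over every position, like a position embedding
conditioned = prompt_emb + lang_emb(lang_id).unsqueeze(1)
```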
#### Language Embedding
This embedding provides the requested language for the model to be aware of.
@ -119,6 +124,12 @@ Finally, the model *may* then sum each embedding level back down to one sequence
Additionally, it's *technically* possible to instead use the embeddings from the core model used to encode the audio, but in theory this may exclude specific features the model has encoded within the embeddings.
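For illustration, a condensed sketch of what summing each embedding level back down to one sequence could look like, assuming one embedding table per RVQ level; the sizes and names are placeholders:

```python
import torch
import torch.nn as nn

n_levels, codebook_size, d_model = 8, 1024, 512
# one embedding table per RVQ level (illustrative)
level_embs = nn.ModuleList([nn.Embedding(codebook_size, d_model) for _ in range(n_levels)])

codes = torch.randint(0, codebook_size, (75, n_levels))   # (timesteps, levels)
# embed each level's codes and sum them into one (timesteps, d_model) sequence
summed = sum(level_embs[l](codes[:, l]) for l in range(n_levels))
```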
#### RVQ Level Embedding
This embedding hints at which RVQ level of the audio codes is being targeted. This embedding is not required, but it seems some architectures (Mamba) require it.

This *may* replace needing separate embeddings for each RVQ level, but experimentation is required to test this claim.
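If that claim holds, the separate per-level tables could collapse into one shared code embedding plus this RVQ level embedding added on top; a rough, purely illustrative sketch of that alternative:

```python
import torch
import torch.nn as nn

codebook_size, n_levels, d_model = 1024, 8, 512
code_emb = nn.Embedding(codebook_size, d_model)   # one shared table for all levels
level_emb = nn.Embedding(n_levels, d_model)       # hints which RVQ level is being targeted

codes = torch.randint(0, codebook_size, (75,))    # codes for the level currently being predicted
quant_level = torch.tensor(2)                     # e.g. targeting RVQ level 2
x = code_emb(codes) + level_emb(quant_level)      # broadcast the level hint over the sequence
```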
### Tasks
The base model handles processing inputs into token sequences, per the requested task assigned to each input in a batch.
@ -244,9 +244,11 @@ class NAR(Base):
test_artifact = None

# nasty hardcode to load a reference file and have that as the input target
# to-do: expose a way to provide the initial sequence instead through CLI
"""
if False:
    path = "./data/237_134500_000036_000004.enc"
    path = "./data/00_part2_success-1.enc"
    test_artifact = np.load(path, allow_pickle=True)[()]
    text_list = [ torch.tensor( cfg.tokenizer.encode( test_artifact["metadata"]["phonemes"] ) ).to(dtype=torch.uint8, device=device) ]
    resps_list = [ torch.from_numpy(test_artifact["codes"].astype(np.int16))[0, :, :].t().to(dtype=torch.int16, device=device) ]
@ -255,7 +257,7 @@ class NAR(Base):
"""

_super = super()
def demask_sampling( seq_len, max_steps=10, temperature=0.3 ):
def demask_sampling( seq_len, max_steps=20, temperature=0.3 ):
    starting_temperature = temperature

    input_ids = torch.ones((seq_len,), dtype=torch.long, device=device) * self.stop_token
@ -264,23 +266,22 @@ class NAR(Base):
    quant_levels = [ level for _ in range(batch_size) ]
    prev_list = [ input_ids ]

    noise_scale = 1.0
    start_noise = 0.0
    end_noise = 1.0

    """
    # use hardcoded reference file to test inference capabilities
    if test_artifact is not None:
        # because we "set" it later on, it's not implicitly captured
        nonlocal resps_list
        input = resps_list[0][:, 0]
        noise_scale = 1.0
        input_ids = torch.tensor( [ self.stop_token if random.random() < noise_scale else token for _, token in enumerate( input ) ], dtype=torch.int16, device=device )
        print( input )
        print( input_ids )
    """
    start_noise = 0.0
    noise_p = math.cos( start_noise * math.pi * 0.5 )
    input_ids = torch.tensor( [ self.stop_token if random.random() < noise_p else token for _, token in enumerate( resps_list[0][:, 0] ) ], dtype=torch.int16, device=device )

    for timestep, steps_until_x0 in zip(torch.linspace(0, 1, max_steps), reversed(range(max_steps))):
    for timestep, steps_until_x0 in zip(torch.linspace(start_noise, end_noise, max_steps), reversed(range(max_steps))):
        # anneal temperature
        temperature = starting_temperature * (steps_until_x0 / max_steps)
        # get noise level, per cosine scheduling
        noise_p = math.cos( timestep * math.pi * 0.5 ) * noise_scale
        noise_p = math.cos( timestep * math.pi * 0.5 )
        # number of tokens to mask off to "noise" the input sequence
        masked_tokens_n = max(int( noise_p * seq_len ), 1)
        # pick the worst scoring tokens to mask off
@ -289,8 +290,7 @@ class NAR(Base):
        input_ids = input_ids.scatter(0, masked_indices, self.stop_token)
        # boolean mask
        is_masked = input_ids == self.stop_token
        # sample
        sampling_top_k = math.floor( seq_len * 0.9 )
        # setup inputs
        resps_list = [ input_ids ]
        inputs = _super.inputs(
            text_list=text_list,
@ -308,6 +308,7 @@ class NAR(Base):
        )

        # sample with sampler settings
        sampling_top_p = 0.9
        filtered_sampled = _super.sample(
            logits=output.logits,
            prev_list=prev_list,
@ -341,14 +342,21 @@ class NAR(Base):
        filtered_scores = filtered_sampled.scores[0]
        unfiltered_scores = unfiltered_sampled.scores[0]

        # extract sampled tokens
        filtered_tokens = filtered_sampled[0][0]
        unfiltered_tokens = unfiltered_sampled[0][0]

        # sample with gumbel noise
        sampled_ids = gumbel_sample( filtered_logits, temperature=temperature, dim=-1 )
        # I actually feel like this doesn't matter? it's hard to judge with a partially trained NAR-len model
        #sampled_ids = gumbel_sample( filtered_logits, temperature=temperature, dim=-1 )
        sampled_ids = filtered_tokens

        # keep unmasked tokens
        input_ids = torch.where( is_masked, sampled_ids, input_ids )
        # update scores (conjugated to put the worst scores at the top)
        scores = 1.0 - torch.tensor([score for score in unfiltered_scores], device=device)

        # print( timestep, steps_until_x0, noise_p, masked_tokens_n, input_ids, scores )
        # print( timestep, steps_until_x0, noise_p, masked_tokens_n, input_ids, scores )

    return input_ids
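For reference, `gumbel_sample` isn't shown in this diff; a typical Gumbel-noise sampling helper looks roughly like the following (an assumption about the helper, not necessarily the exact implementation used here):

```python
import torch

def gumbel_noise(t):
    # sample standard Gumbel noise with the same shape as t
    noise = torch.zeros_like(t).uniform_(0, 1)
    return -torch.log(-torch.log(noise))

def gumbel_sample(logits, temperature=1.0, dim=-1):
    # adding Gumbel noise to temperature-scaled logits and taking the argmax
    # is equivalent to sampling from the softmax distribution
    return ((logits / max(temperature, 1e-10)) + gumbel_noise(logits)).argmax(dim=dim)
```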