commit b39aaacd77
parent 3019c88799
Author: mrq
Date:   2025-02-23 11:55:43 -06:00

2 changed files with 13 additions and 12 deletions


@@ -313,7 +313,7 @@ class AR_NAR(Base):
             scores = [ torch.tensor( [ 1.0 if random.random() < noise_p else 0.0 for _ in range( seq_len ) ], dtype=torch.float32, device=device ) for seq_len in len_list ]
         else:
             # fill with masked tokens (even though they get masked anyways)
-            resps_list = [ torch.ones((seq_len,), dtype=torch.int16, device=device) * self.stop_token for seq_len in len_list ]
+            resps_list = [ torch.ones((seq_len,), dtype=torch.int16, device=device) * self.mask_token for seq_len in len_list ]
             # fill scores
             scores = [ torch.ones((seq_len,), dtype=torch.float32, device=device) for seq_len in len_list ]
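Note on this hunk: the pre-fill sentinel changes from the stop token to a dedicated mask token. A minimal sketch of the fill, assuming n_audio_tokens = 1024 (so stop = 1024, mask = 1025, matching the Base.__init__ hunks below):

    import torch

    n_audio_tokens = 1024        # per-level codebook size
    stop_token = n_audio_tokens  # id 1024, ends AR decoding
    mask_token = stop_token + 1  # id 1025, marks positions left to infill

    len_list = [3, 5]
    device = "cpu"
    # every position starts masked; the demasking loop fills them in over steps
    resps_list = [ torch.ones((seq_len,), dtype=torch.int16, device=device) * mask_token for seq_len in len_list ]
    scores = [ torch.ones((seq_len,), dtype=torch.float32, device=device) for seq_len in len_list ]
    print(resps_list[0])  # tensor([1025, 1025, 1025], dtype=torch.int16)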
@@ -337,16 +337,16 @@ class AR_NAR(Base):
         # normal masking
         if vc_list is None or timestep >= vc_threshold:
             # mask off inputs
-            resps_list = [ resp.scatter(0, indices, self.stop_token) for resp, indices in zip( resps_list, masked_indices ) ]
+            resps_list = [ resp.scatter(0, indices, self.mask_token) for resp, indices in zip( resps_list, masked_indices ) ]
             # boolean mask
-            is_masked = [ resps == self.stop_token for resps in resps_list ]
+            is_masked = [ resps == self.mask_token for resps in resps_list ]
         else:
             # mask off a random portion of the target
             rand_mask_list = [ torch.rand(mask.shape).to(device=device) < vc_mask_p for mask in vc_list ]
-            half_mask_list = [ torch.where( rand_mask, self.stop_token, mask.clone() ) for mask, rand_mask in zip( vc_list, rand_mask_list ) ]
+            half_mask_list = [ torch.where( rand_mask, self.mask_token, mask.clone() ) for mask, rand_mask in zip( vc_list, rand_mask_list ) ]
             # always set the last end as masked off because it causes issues
             for i, mask in enumerate(half_mask_list):
-                half_mask_list[i][-75:] = self.stop_token
+                half_mask_list[i][-75:] = self.mask_token
             #
             # mask off inputs per mask
             resps_list = [ resp.scatter(0, indices, mask) for resp, indices, mask in zip( resps_list, masked_indices, half_mask_list ) ]
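For reference, scatter with a scalar writes mask_token at the selected indices without touching the rest, and the boolean comparison then recovers exactly those positions. A toy example (hypothetical values, not from the commit):

    import torch

    mask_token = 1025
    resp = torch.tensor([10, 20, 30, 40, 50], dtype=torch.int16)
    indices = torch.tensor([1, 3])                 # positions chosen for re-masking
    masked = resp.scatter(0, indices, mask_token)  # out-of-place; resp is untouched
    print(masked)     # tensor([  10, 1025,   30, 1025,   50], dtype=torch.int16)
    is_masked = masked == mask_token
    print(is_masked)  # tensor([False,  True, False,  True, False])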
@@ -503,7 +503,7 @@ class AR_NAR(Base):
         prefix_context = sampling_kwargs.get("prefix_context", None)
         # fill with masked tokens (even though they get masked anyways)
-        resps_list = [ torch.ones((seq_len, self.n_resp_levels), dtype=torch.int16, device=device) * self.stop_token for seq_len in len_list ]
+        resps_list = [ torch.ones((seq_len, self.n_resp_levels), dtype=torch.int16, device=device) * self.mask_token for seq_len in len_list ]
         # fill scores
         scores = [ torch.ones((seq_len), dtype=torch.float32, device=device) for seq_len in len_list ]
@@ -525,9 +525,9 @@ class AR_NAR(Base):
         masked_indices = [ score.topk( clamp( int( noise_p * seq_len + remask_p * seq_len ), 1, seq_len), dim=-1 ).indices for score, seq_len in zip(scores, len_list) ]
         # normal masking
         # mask off inputs
-        resps_list = [ torch.stack([resp[:, l].scatter(0, indices, self.stop_token) for l in range(self.n_resp_levels)], dim=-1) for resp, indices in zip( resps_list, masked_indices ) ]
+        resps_list = [ torch.stack([resp[:, l].scatter(0, indices, self.mask_token) for l in range(self.n_resp_levels)], dim=-1) for resp, indices in zip( resps_list, masked_indices ) ]
         # boolean mask
-        is_masked = [ resps == self.stop_token for resps in resps_list ]
+        is_masked = [ resps == self.mask_token for resps in resps_list ]
         # timestep inputs
         time_list = [ timestep for _ in range(batch_size) ]
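A sketch of this level-parallel variant (hypothetical shapes): the same top-k indices, taken from the per-position scores, are masked across every one of the n_resp_levels columns:

    import torch

    mask_token = 1025
    n_resp_levels = 2
    seq_len, k = 4, 2
    resp = torch.arange(seq_len * n_resp_levels, dtype=torch.int16).reshape(seq_len, n_resp_levels)
    score = torch.tensor([0.9, 0.1, 0.8, 0.2])  # higher score = re-mask this position
    indices = score.topk(k, dim=-1).indices     # tensor([0, 2])
    # scatter per level, then restack into (seq_len, n_resp_levels)
    masked = torch.stack([resp[:, l].scatter(0, indices, mask_token) for l in range(n_resp_levels)], dim=-1)
    print(masked[:, 0])  # tensor([1025,    2, 1025,    6], dtype=torch.int16)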


@@ -546,8 +546,8 @@ class Base(nn.Module):
         self.capabilities = self.config.capabilities if self.config else ["ar", "nar"]
         self.gradient_checkpointing = self.config.gradient_checkpointing if self.config is not None else True
-        self.stop_token = self.n_audio_tokens # id 1024
-        self.mask_token = self.n_audio_tokens + 1 # id 1024
+        self.stop_token = self.n_audio_tokens
+        self.mask_token = self.stop_token
         self.causal = "ar" in self.capabilities or "len" in self.capabilities
         self.version = self.config.version if self.config is not None else 5
         self.causal_size = self.config.experimental.causal_size if self.config is not None else (1 if self.causal else 0)
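Two things worth noting here: the removed comment mislabeled n_audio_tokens + 1 as id 1024 (it is 1025), and the new default aliases mask_token to stop_token, with version >= 7 models re-pointing it one id higher later in __init__ (see the third hunk below). Roughly, assuming n_audio_tokens = 1024:

    n_audio_tokens = 1024
    stop_token = n_audio_tokens      # id 1024
    mask_token = stop_token          # pre-v7: mask aliases stop
    version = 7                      # hypothetical config value
    if version >= 7:
        mask_token = stop_token + 1  # id 1025: v7+ reserves a distinct mask id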
@@ -616,7 +616,7 @@ class Base(nn.Module):
             l_embedding_names = ['AR:0:0'] + [f'NAR:{i}:{i+1}' for i in range( self.n_resp_levels - 1 )]
             l_classifier_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1)
         else:
-            n_resp_tokens = n_audio_tokens + 1
+            n_resp_tokens = n_audio_tokens + 2
             l_embedding_tokens = [n_resp_tokens] * self.n_resp_levels
             l_embedding_names = [] # [f'NAR:{i}' for i in range( self.n_resp_levels )]
             l_classifier_tokens = [] # [n_audio_tokens] * self.n_resp_levels
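The bump from +1 to +2 sizes the response vocabulary for both sentinels, so an embedding lookup of the mask id stays in range. A minimal sketch (d_model is a hypothetical width):

    import torch.nn as nn

    n_audio_tokens = 1024
    n_resp_tokens = n_audio_tokens + 2  # 1024 codes + stop (1024) + mask (1025)
    d_model = 512
    emb = nn.Embedding(n_resp_tokens, d_model)
    # with the old +1 sizing, looking up id 1025 would index past the table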
@@ -714,6 +714,7 @@ class Base(nn.Module):
             self.raw_text_emb = Embedding(self.n_raw_text_tokens, d_model)
         if self.version >= 7:
+            self.mask_token = self.stop_token + 1
             if monolithic_audio_encoder:
                 self.audio_emb = AudioEncoder(
                     n_tokens=n_audio_tokens + 2, # stop + masked token
@@ -735,7 +736,7 @@ class Base(nn.Module):
             self.audio_decoder = AudioDecoder(
                 d_model,
                 d_model * 2,
-                (n_audio_tokens + 1) * self.n_resp_levels,
+                (n_audio_tokens + 2) * self.n_resp_levels,
             )
         if attention_backend == "auto":
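The decoder's fused projection grows to match: one extra logit per level. A hypothetical reshape of its output (names assumed, not from the diff):

    import torch

    n_audio_tokens, n_resp_levels, seq_len = 1024, 8, 10
    flat = torch.randn(seq_len, (n_audio_tokens + 2) * n_resp_levels)
    # split the fused projection back into per-level logits over 1026 ids
    logits = flat.view(seq_len, n_resp_levels, n_audio_tokens + 2)
    print(logits.shape)  # torch.Size([10, 8, 1026])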