oops
This commit is contained in:
parent
3019c88799
commit
b39aaacd77
|
@ -313,7 +313,7 @@ class AR_NAR(Base):
|
|||
scores = [ torch.tensor( [ 1.0 if random.random() < noise_p else 0.0 for _ in range( seq_len ) ], dtype=torch.float32, device=device ) for seq_len in len_list ]
|
||||
else:
|
||||
# fill with masked tokens (even though they get masked anyway)
|
||||
resps_list = [ torch.ones((seq_len,), dtype=torch.int16, device=device) * self.stop_token for seq_len in len_list ]
|
||||
resps_list = [ torch.ones((seq_len,), dtype=torch.int16, device=device) * self.mask_token for seq_len in len_list ]
|
||||
# fill scores
|
||||
scores = [ torch.ones((seq_len,), dtype=torch.float32, device=device) for seq_len in len_list ]
|
||||
|
||||
|
@ -337,16 +337,16 @@ class AR_NAR(Base):
|
|||
# normal masking
|
||||
if vc_list is None or timestep >= vc_threshold:
|
||||
# mask off inputs
|
||||
resps_list = [ resp.scatter(0, indices, self.stop_token) for resp, indices in zip( resps_list, masked_indices ) ]
|
||||
resps_list = [ resp.scatter(0, indices, self.mask_token) for resp, indices in zip( resps_list, masked_indices ) ]
|
||||
# boolean mask
|
||||
is_masked = [ resps == self.stop_token for resps in resps_list ]
|
||||
is_masked = [ resps == self.mask_token for resps in resps_list ]
|
||||
else:
|
||||
# mask off a random portion of the target
|
||||
rand_mask_list = [ torch.rand(mask.shape).to(device=device) < vc_mask_p for mask in vc_list ]
|
||||
half_mask_list = [ torch.where( rand_mask, self.stop_token, mask.clone() ) for mask, rand_mask in zip( vc_list, rand_mask_list ) ]
|
||||
half_mask_list = [ torch.where( rand_mask, self.mask_token, mask.clone() ) for mask, rand_mask in zip( vc_list, rand_mask_list ) ]
|
||||
# always mask off the tail end (the last 75 positions), because leaving it unmasked causes issues
|
||||
for i, mask in enumerate(half_mask_list):
|
||||
half_mask_list[i][-75:] = self.stop_token
|
||||
half_mask_list[i][-75:] = self.mask_token
|
||||
#
|
||||
# mask off inputs per mask
|
||||
resps_list = [ resp.scatter(0, indices, mask) for resp, indices, mask in zip( resps_list, masked_indices, half_mask_list ) ]
|
||||
|
@ -503,7 +503,7 @@ class AR_NAR(Base):
|
|||
|
||||
prefix_context = sampling_kwargs.get("prefix_context", None)
|
||||
# fill with masked tokens (even though they get masked anyway)
|
||||
resps_list = [ torch.ones((seq_len, self.n_resp_levels), dtype=torch.int16, device=device) * self.stop_token for seq_len in len_list ]
|
||||
resps_list = [ torch.ones((seq_len, self.n_resp_levels), dtype=torch.int16, device=device) * self.mask_token for seq_len in len_list ]
|
||||
# fill scores
|
||||
scores = [ torch.ones((seq_len), dtype=torch.float32, device=device) for seq_len in len_list ]
|
||||
|
||||
|
@ -525,9 +525,9 @@ class AR_NAR(Base):
|
|||
masked_indices = [ score.topk( clamp( int( noise_p * seq_len + remask_p * seq_len ), 1, seq_len), dim=-1 ).indices for score, seq_len in zip(scores, len_list) ]
|
||||
# normal masking
|
||||
# mask off inputs
|
||||
resps_list = [ torch.stack([resp[:, l].scatter(0, indices, self.stop_token) for l in range(self.n_resp_levels)], dim=-1) for resp, indices in zip( resps_list, masked_indices ) ]
|
||||
resps_list = [ torch.stack([resp[:, l].scatter(0, indices, self.mask_token) for l in range(self.n_resp_levels)], dim=-1) for resp, indices in zip( resps_list, masked_indices ) ]
|
||||
# boolean mask
|
||||
is_masked = [ resps == self.stop_token for resps in resps_list ]
|
||||
is_masked = [ resps == self.mask_token for resps in resps_list ]
|
||||
# timestep inputs
|
||||
time_list = [ timestep for _ in range(batch_size) ]
|
||||
|
||||
|
|
|
@ -546,8 +546,8 @@ class Base(nn.Module):
|
|||
self.capabilities = self.config.capabilities if self.config else ["ar", "nar"]
|
||||
self.gradient_checkpointing = self.config.gradient_checkpointing if self.config is not None else True
|
||||
|
||||
self.stop_token = self.n_audio_tokens # id 1024
|
||||
self.mask_token = self.n_audio_tokens + 1 # id 1025
|
||||
self.stop_token = self.n_audio_tokens
|
||||
self.mask_token = self.stop_token
|
||||
self.causal = "ar" in self.capabilities or "len" in self.capabilities
|
||||
self.version = self.config.version if self.config is not None else 5
|
||||
self.causal_size = self.config.experimental.causal_size if self.config is not None else (1 if self.causal else 0)
|
||||
|
@ -616,7 +616,7 @@ class Base(nn.Module):
|
|||
l_embedding_names = ['AR:0:0'] + [f'NAR:{i}:{i+1}' for i in range( self.n_resp_levels - 1 )]
|
||||
l_classifier_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1)
|
||||
else:
|
||||
n_resp_tokens = n_audio_tokens + 1
|
||||
n_resp_tokens = n_audio_tokens + 2
|
||||
l_embedding_tokens = [n_resp_tokens] * self.n_resp_levels
|
||||
l_embedding_names = [] # [f'NAR:{i}' for i in range( self.n_resp_levels )]
|
||||
l_classifier_tokens = [] # [n_audio_tokens] * self.n_resp_levels
|
||||
|
@ -714,6 +714,7 @@ class Base(nn.Module):
|
|||
self.raw_text_emb = Embedding(self.n_raw_text_tokens, d_model)
|
||||
|
||||
if self.version >= 7:
|
||||
self.mask_token = self.stop_token + 1
|
||||
if monolithic_audio_encoder:
|
||||
self.audio_emb = AudioEncoder(
|
||||
n_tokens=n_audio_tokens + 2, # stop + masked token
|
||||
|
@ -735,7 +736,7 @@ class Base(nn.Module):
|
|||
self.audio_decoder = AudioDecoder(
|
||||
d_model,
|
||||
d_model * 2,
|
||||
(n_audio_tokens + 1) * self.n_resp_levels,
|
||||
(n_audio_tokens + 2) * self.n_resp_levels,
|
||||
)
|
||||
|
||||
if attention_backend == "auto":
|
||||
|
|
Loading…
Reference in New Issue
Block a user