oops
This commit is contained in:
parent
3019c88799
commit
b39aaacd77
|
@ -313,7 +313,7 @@ class AR_NAR(Base):
|
||||||
scores = [ torch.tensor( [ 1.0 if random.random() < noise_p else 0.0 for _ in range( seq_len ) ], dtype=torch.float32, device=device ) for seq_len in len_list ]
|
scores = [ torch.tensor( [ 1.0 if random.random() < noise_p else 0.0 for _ in range( seq_len ) ], dtype=torch.float32, device=device ) for seq_len in len_list ]
|
||||||
else:
|
else:
|
||||||
# fill with masked tokens (even though they get masked anyways)
|
# fill with masked tokens (even though they get masked anyways)
|
||||||
resps_list = [ torch.ones((seq_len,), dtype=torch.int16, device=device) * self.stop_token for seq_len in len_list ]
|
resps_list = [ torch.ones((seq_len,), dtype=torch.int16, device=device) * self.mask_token for seq_len in len_list ]
|
||||||
# fill scores
|
# fill scores
|
||||||
scores = [ torch.ones((seq_len,), dtype=torch.float32, device=device) for seq_len in len_list ]
|
scores = [ torch.ones((seq_len,), dtype=torch.float32, device=device) for seq_len in len_list ]
|
||||||
|
|
||||||
|
@ -337,16 +337,16 @@ class AR_NAR(Base):
|
||||||
# normal masking
|
# normal masking
|
||||||
if vc_list is None or timestep >= vc_threshold:
|
if vc_list is None or timestep >= vc_threshold:
|
||||||
# mask off inputs
|
# mask off inputs
|
||||||
resps_list = [ resp.scatter(0, indices, self.stop_token) for resp, indices in zip( resps_list, masked_indices ) ]
|
resps_list = [ resp.scatter(0, indices, self.mask_token) for resp, indices in zip( resps_list, masked_indices ) ]
|
||||||
# boolean mask
|
# boolean mask
|
||||||
is_masked = [ resps == self.stop_token for resps in resps_list ]
|
is_masked = [ resps == self.mask_token for resps in resps_list ]
|
||||||
else:
|
else:
|
||||||
# mask off a random portion of the target
|
# mask off a random portion of the target
|
||||||
rand_mask_list = [ torch.rand(mask.shape).to(device=device) < vc_mask_p for mask in vc_list ]
|
rand_mask_list = [ torch.rand(mask.shape).to(device=device) < vc_mask_p for mask in vc_list ]
|
||||||
half_mask_list = [ torch.where( rand_mask, self.stop_token, mask.clone() ) for mask, rand_mask in zip( vc_list, rand_mask_list ) ]
|
half_mask_list = [ torch.where( rand_mask, self.mask_token, mask.clone() ) for mask, rand_mask in zip( vc_list, rand_mask_list ) ]
|
||||||
# always set the last end as masked off because it causes issues
|
# always set the last end as masked off because it causes issues
|
||||||
for i, mask in enumerate(half_mask_list):
|
for i, mask in enumerate(half_mask_list):
|
||||||
half_mask_list[i][-75:] = self.stop_token
|
half_mask_list[i][-75:] = self.mask_token
|
||||||
#
|
#
|
||||||
# mask off inputs per mask
|
# mask off inputs per mask
|
||||||
resps_list = [ resp.scatter(0, indices, mask) for resp, indices, mask in zip( resps_list, masked_indices, half_mask_list ) ]
|
resps_list = [ resp.scatter(0, indices, mask) for resp, indices, mask in zip( resps_list, masked_indices, half_mask_list ) ]
|
||||||
|
@ -503,7 +503,7 @@ class AR_NAR(Base):
|
||||||
|
|
||||||
prefix_context = sampling_kwargs.get("prefix_context", None)
|
prefix_context = sampling_kwargs.get("prefix_context", None)
|
||||||
# fill with masked tokens (even though they get masked anyways)
|
# fill with masked tokens (even though they get masked anyways)
|
||||||
resps_list = [ torch.ones((seq_len, self.n_resp_levels), dtype=torch.int16, device=device) * self.stop_token for seq_len in len_list ]
|
resps_list = [ torch.ones((seq_len, self.n_resp_levels), dtype=torch.int16, device=device) * self.mask_token for seq_len in len_list ]
|
||||||
# fill scores
|
# fill scores
|
||||||
scores = [ torch.ones((seq_len), dtype=torch.float32, device=device) for seq_len in len_list ]
|
scores = [ torch.ones((seq_len), dtype=torch.float32, device=device) for seq_len in len_list ]
|
||||||
|
|
||||||
|
@ -525,9 +525,9 @@ class AR_NAR(Base):
|
||||||
masked_indices = [ score.topk( clamp( int( noise_p * seq_len + remask_p * seq_len ), 1, seq_len), dim=-1 ).indices for score, seq_len in zip(scores, len_list) ]
|
masked_indices = [ score.topk( clamp( int( noise_p * seq_len + remask_p * seq_len ), 1, seq_len), dim=-1 ).indices for score, seq_len in zip(scores, len_list) ]
|
||||||
# normal masking
|
# normal masking
|
||||||
# mask off inputs
|
# mask off inputs
|
||||||
resps_list = [ torch.stack([resp[:, l].scatter(0, indices, self.stop_token) for l in range(self.n_resp_levels)], dim=-1) for resp, indices in zip( resps_list, masked_indices ) ]
|
resps_list = [ torch.stack([resp[:, l].scatter(0, indices, self.mask_token) for l in range(self.n_resp_levels)], dim=-1) for resp, indices in zip( resps_list, masked_indices ) ]
|
||||||
# boolean mask
|
# boolean mask
|
||||||
is_masked = [ resps == self.stop_token for resps in resps_list ]
|
is_masked = [ resps == self.mask_token for resps in resps_list ]
|
||||||
# timestep inputs
|
# timestep inputs
|
||||||
time_list = [ timestep for _ in range(batch_size) ]
|
time_list = [ timestep for _ in range(batch_size) ]
|
||||||
|
|
||||||
|
|
|
@ -546,8 +546,8 @@ class Base(nn.Module):
|
||||||
self.capabilities = self.config.capabilities if self.config else ["ar", "nar"]
|
self.capabilities = self.config.capabilities if self.config else ["ar", "nar"]
|
||||||
self.gradient_checkpointing = self.config.gradient_checkpointing if self.config is not None else True
|
self.gradient_checkpointing = self.config.gradient_checkpointing if self.config is not None else True
|
||||||
|
|
||||||
self.stop_token = self.n_audio_tokens # id 1024
|
self.stop_token = self.n_audio_tokens
|
||||||
self.mask_token = self.n_audio_tokens + 1 # id 1024
|
self.mask_token = self.stop_token
|
||||||
self.causal = "ar" in self.capabilities or "len" in self.capabilities
|
self.causal = "ar" in self.capabilities or "len" in self.capabilities
|
||||||
self.version = self.config.version if self.config is not None else 5
|
self.version = self.config.version if self.config is not None else 5
|
||||||
self.causal_size = self.config.experimental.causal_size if self.config is not None else (1 if self.causal else 0)
|
self.causal_size = self.config.experimental.causal_size if self.config is not None else (1 if self.causal else 0)
|
||||||
|
@ -616,7 +616,7 @@ class Base(nn.Module):
|
||||||
l_embedding_names = ['AR:0:0'] + [f'NAR:{i}:{i+1}' for i in range( self.n_resp_levels - 1 )]
|
l_embedding_names = ['AR:0:0'] + [f'NAR:{i}:{i+1}' for i in range( self.n_resp_levels - 1 )]
|
||||||
l_classifier_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1)
|
l_classifier_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1)
|
||||||
else:
|
else:
|
||||||
n_resp_tokens = n_audio_tokens + 1
|
n_resp_tokens = n_audio_tokens + 2
|
||||||
l_embedding_tokens = [n_resp_tokens] * self.n_resp_levels
|
l_embedding_tokens = [n_resp_tokens] * self.n_resp_levels
|
||||||
l_embedding_names = [] # [f'NAR:{i}' for i in range( self.n_resp_levels )]
|
l_embedding_names = [] # [f'NAR:{i}' for i in range( self.n_resp_levels )]
|
||||||
l_classifier_tokens = [] # [n_audio_tokens] * self.n_resp_levels
|
l_classifier_tokens = [] # [n_audio_tokens] * self.n_resp_levels
|
||||||
|
@ -714,6 +714,7 @@ class Base(nn.Module):
|
||||||
self.raw_text_emb = Embedding(self.n_raw_text_tokens, d_model)
|
self.raw_text_emb = Embedding(self.n_raw_text_tokens, d_model)
|
||||||
|
|
||||||
if self.version >= 7:
|
if self.version >= 7:
|
||||||
|
self.mask_token = self.stop_token + 1
|
||||||
if monolithic_audio_encoder:
|
if monolithic_audio_encoder:
|
||||||
self.audio_emb = AudioEncoder(
|
self.audio_emb = AudioEncoder(
|
||||||
n_tokens=n_audio_tokens + 2, # stop + masked token
|
n_tokens=n_audio_tokens + 2, # stop + masked token
|
||||||
|
@ -735,7 +736,7 @@ class Base(nn.Module):
|
||||||
self.audio_decoder = AudioDecoder(
|
self.audio_decoder = AudioDecoder(
|
||||||
d_model,
|
d_model,
|
||||||
d_model * 2,
|
d_model * 2,
|
||||||
(n_audio_tokens + 1) * self.n_resp_levels,
|
(n_audio_tokens + 2) * self.n_resp_levels,
|
||||||
)
|
)
|
||||||
|
|
||||||
if attention_backend == "auto":
|
if attention_backend == "auto":
|
||||||
|
|
Loading…
Reference in New Issue
Block a user