diff --git a/vall_e/models/ar_nar.py b/vall_e/models/ar_nar.py
index 9d6f76e..75156f9 100644
--- a/vall_e/models/ar_nar.py
+++ b/vall_e/models/ar_nar.py
@@ -313,7 +313,7 @@ class AR_NAR(Base):
             scores = [ torch.tensor( [ 1.0 if random.random() < noise_p else 0.0 for _ in range( seq_len ) ], dtype=torch.float32, device=device ) for seq_len in len_list ]
         else:
             # fill with masked tokens (even though they get masked anyways)
-            resps_list = [ torch.ones((seq_len,), dtype=torch.int16, device=device) * self.stop_token for seq_len in len_list ]
+            resps_list = [ torch.ones((seq_len,), dtype=torch.int16, device=device) * self.mask_token for seq_len in len_list ]
             # fill scores
             scores = [ torch.ones((seq_len,), dtype=torch.float32, device=device) for seq_len in len_list ]
 
@@ -337,16 +337,16 @@ class AR_NAR(Base):
             # normal masking
             if vc_list is None or timestep >= vc_threshold:
                 # mask off inputs
-                resps_list = [ resp.scatter(0, indices, self.stop_token) for resp, indices in zip( resps_list, masked_indices ) ]
+                resps_list = [ resp.scatter(0, indices, self.mask_token) for resp, indices in zip( resps_list, masked_indices ) ]
                 # boolean mask
-                is_masked = [ resps == self.stop_token for resps in resps_list ]
+                is_masked = [ resps == self.mask_token for resps in resps_list ]
             else:
                 # mask off a random portion of the target
                 rand_mask_list = [ torch.rand(mask.shape).to(device=device) < vc_mask_p for mask in vc_list ]
-                half_mask_list = [ torch.where( rand_mask, self.stop_token, mask.clone() ) for mask, rand_mask in zip( vc_list, rand_mask_list ) ]
+                half_mask_list = [ torch.where( rand_mask, self.mask_token, mask.clone() ) for mask, rand_mask in zip( vc_list, rand_mask_list ) ]
                 # always set the last end as masked off because it causes issues
                 for i, mask in enumerate(half_mask_list):
-                    half_mask_list[i][-75:] = self.stop_token
+                    half_mask_list[i][-75:] = self.mask_token
                 #
                 # mask off inputs per mask
                 resps_list = [ resp.scatter(0, indices, mask) for resp, indices, mask in zip( resps_list, masked_indices, half_mask_list ) ]
@@ -503,7 +503,7 @@ class AR_NAR(Base):
         prefix_context = sampling_kwargs.get("prefix_context", None)
 
         # fill with masked tokens (even though they get masked anyways)
-        resps_list = [ torch.ones((seq_len, self.n_resp_levels), dtype=torch.int16, device=device) * self.stop_token for seq_len in len_list ]
+        resps_list = [ torch.ones((seq_len, self.n_resp_levels), dtype=torch.int16, device=device) * self.mask_token for seq_len in len_list ]
         # fill scores
         scores = [ torch.ones((seq_len), dtype=torch.float32, device=device) for seq_len in len_list ]
 
@@ -525,9 +525,9 @@ class AR_NAR(Base):
             masked_indices = [ score.topk( clamp( int( noise_p * seq_len + remask_p * seq_len ), 1, seq_len), dim=-1 ).indices for score, seq_len in zip(scores, len_list) ]
             # normal masking
             # mask off inputs
-            resps_list = [ torch.stack([resp[:, l].scatter(0, indices, self.stop_token) for l in range(self.n_resp_levels)], dim=-1) for resp, indices in zip( resps_list, masked_indices ) ]
+            resps_list = [ torch.stack([resp[:, l].scatter(0, indices, self.mask_token) for l in range(self.n_resp_levels)], dim=-1) for resp, indices in zip( resps_list, masked_indices ) ]
             # boolean mask
-            is_masked = [ resps == self.stop_token for resps in resps_list ]
+            is_masked = [ resps == self.mask_token for resps in resps_list ]
 
             # timestep inputs
             time_list = [ timestep for _ in range(batch_size) ]
diff --git a/vall_e/models/base.py b/vall_e/models/base.py
index c4ec0c5..4c8217e 100755
--- a/vall_e/models/base.py
+++ b/vall_e/models/base.py
@@ -546,8 +546,8 @@ class Base(nn.Module):
         self.capabilities = self.config.capabilities if self.config else ["ar", "nar"]
         self.gradient_checkpointing = self.config.gradient_checkpointing if self.config is not None else True
 
-        self.stop_token = self.n_audio_tokens # id 1024
-        self.mask_token = self.n_audio_tokens + 1 # id 1024
+        self.stop_token = self.n_audio_tokens
+        self.mask_token = self.stop_token
         self.causal = "ar" in self.capabilities or "len" in self.capabilities
         self.version = self.config.version if self.config is not None else 5
         self.causal_size = self.config.experimental.causal_size if self.config is not None else (1 if self.causal else 0)
@@ -616,7 +616,7 @@ class Base(nn.Module):
             l_embedding_names = ['AR:0:0'] + [f'NAR:{i}:{i+1}' for i in range( self.n_resp_levels - 1 )]
             l_classifier_tokens = [n_resp_tokens] + [n_resp_tokens - 1] * (self.n_resp_levels - 1)
         else:
-            n_resp_tokens = n_audio_tokens + 1
+            n_resp_tokens = n_audio_tokens + 2
             l_embedding_tokens = [n_resp_tokens] * self.n_resp_levels
             l_embedding_names = [] # [f'NAR:{i}' for i in range( self.n_resp_levels )]
             l_classifier_tokens = [] # [n_audio_tokens] * self.n_resp_levels
@@ -714,6 +714,7 @@ class Base(nn.Module):
             self.raw_text_emb = Embedding(self.n_raw_text_tokens, d_model)
 
         if self.version >= 7:
+            self.mask_token = self.stop_token + 1
             if monolithic_audio_encoder:
                 self.audio_emb = AudioEncoder(
                     n_tokens=n_audio_tokens + 2, # stop + masked token
@@ -735,7 +736,7 @@ class Base(nn.Module):
             self.audio_decoder = AudioDecoder(
                 d_model,
                 d_model * 2,
-                (n_audio_tokens + 1) * self.n_resp_levels,
+                (n_audio_tokens + 2) * self.n_resp_levels,
             )
 
         if attention_backend == "auto":
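
For reference, a minimal sketch of the special-token layout this patch settles on. This is an illustration distilled from the diff, not code from the repository; the concrete id values assume the 1024-entry audio codebook implied by the removed "# id 1024" comment, and the version check mirrors the "if self.version >= 7" branch above.

# Illustrative sketch only: how stop/mask token ids relate after this patch.
# Assumes n_audio_tokens = 1024; "version" stands in for self.version.
n_audio_tokens = 1024
version = 7  # hypothetical model version for the example

stop_token = n_audio_tokens  # id 1024 for every model version
mask_token = stop_token      # pre-v7: mask aliases stop, so switching the
                             # masking code to mask_token is behavior-preserving

if version >= 7:
    # v7 gives the mask token its own id, which is why the embedding and
    # classifier widths grow from n_audio_tokens + 1 to n_audio_tokens + 2
    mask_token = stop_token + 1  # id 1025
    n_resp_tokens = n_audio_tokens + 2

print(stop_token, mask_token)  # 1024 1025 on a v7 model, 1024 1024 before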