diff --git a/vall_e/config.py b/vall_e/config.py index 6248cf8..4066677 100755 --- a/vall_e/config.py +++ b/vall_e/config.py @@ -351,7 +351,7 @@ class Model: name = [ self.name ] if isinstance(self.size, dict): - if hasattr(self.size, "label") and self.size['label']: + if self.size.get('label'): name.append(f"{self.size['label']}") elif isinstance(self.size, str) and self.size not in ["full","extended"]: name.append(self.size) @@ -374,7 +374,7 @@ class Model: @property def audio_tokens(self): - if isinstance(self.size, dict) and hasattr(self.size, "audio_tokens"): + if isinstance(self.size, dict) and "audio_tokens" in self.size: return self.size['audio_tokens'] if cfg.audio_backend == "nemo": @@ -384,19 +384,19 @@ class Model: @property def text_tokens(self): - if isinstance(self.size, dict) and hasattr(self.size, "text_tokens"): + if isinstance(self.size, dict) and "text_tokens" in self.size: return self.size['text_tokens'] return 8575 @property def phoneme_tokens(self): - if isinstance(self.size, dict) and hasattr(self.size, "phoneme_tokens"): + if isinstance(self.size, dict) and "phoneme_tokens" in self.size: return self.size['phoneme_tokens'] return 256 @property def dim(self): - if isinstance(self.size, dict) and hasattr(self.size, "dim"): + if isinstance(self.size, dict) and "dim" in self.size: return self.size['dim'] if isinstance(self.size, float): @@ -410,7 +410,7 @@ class Model: @property def heads(self): - if isinstance(self.size, dict) and hasattr(self.size, "heads"): + if isinstance(self.size, dict) and "heads" in self.size: return self.size['heads'] if isinstance(self.size, float): @@ -424,7 +424,7 @@ class Model: @property def layers(self): - if isinstance(self.size, dict) and hasattr(self.size, "layers"): + if isinstance(self.size, dict) and "layers" in self.size: return self.size['layers'] if self.size == "double": @@ -435,7 +435,7 @@ class Model: @property def ffn(self): - if isinstance(self.size, dict) and hasattr(self.size, "ffn"): + if isinstance(self.size, dict) and "ffn" in self.size: return self.size['ffn'] return 4 diff --git a/vall_e/emb/qnt.py b/vall_e/emb/qnt.py index 734373d..458d332 100755 --- a/vall_e/emb/qnt.py +++ b/vall_e/emb/qnt.py @@ -173,7 +173,7 @@ def unload_model(): @torch.inference_mode() def decode(codes: Tensor, device="cuda", dtype=None, metadata=None, window_duration=None): # dirty hack during model training - codes = torch.where( codes >= (max_token = 1000 if cfg.audio_backend == "nemo" else 1024 ), 0, codes ) + codes = torch.where( codes >= ( 1000 if cfg.audio_backend == "nemo" else 1024 ), 0, codes ) # upcast so it won't whine if codes.dtype in [torch.int8, torch.int16, torch.uint8]: