Could have sworn this worked before; I might have broken it when I decoupled from OmegaConf.

This commit is contained in:
mrq 2025-03-01 19:30:26 -06:00
parent 17094b8002
commit 1d3290b023
2 changed files with 9 additions and 9 deletions

View File

@ -351,7 +351,7 @@ class Model:
name = [ self.name ] name = [ self.name ]
if isinstance(self.size, dict): if isinstance(self.size, dict):
if hasattr(self.size, "label") and self.size['label']: if self.size.get('label'):
name.append(f"{self.size['label']}") name.append(f"{self.size['label']}")
elif isinstance(self.size, str) and self.size not in ["full","extended"]: elif isinstance(self.size, str) and self.size not in ["full","extended"]:
name.append(self.size) name.append(self.size)
@ -374,7 +374,7 @@ class Model:
@property @property
def audio_tokens(self): def audio_tokens(self):
if isinstance(self.size, dict) and hasattr(self.size, "audio_tokens"): if isinstance(self.size, dict) and "audio_tokens" in self.size:
return self.size['audio_tokens'] return self.size['audio_tokens']
if cfg.audio_backend == "nemo": if cfg.audio_backend == "nemo":
@ -384,19 +384,19 @@ class Model:
@property @property
def text_tokens(self): def text_tokens(self):
if isinstance(self.size, dict) and hasattr(self.size, "text_tokens"): if isinstance(self.size, dict) and "text_tokens" in self.size:
return self.size['text_tokens'] return self.size['text_tokens']
return 8575 return 8575
@property @property
def phoneme_tokens(self): def phoneme_tokens(self):
if isinstance(self.size, dict) and hasattr(self.size, "phoneme_tokens"): if isinstance(self.size, dict) and "phoneme_tokens" in self.size:
return self.size['phoneme_tokens'] return self.size['phoneme_tokens']
return 256 return 256
@property @property
def dim(self): def dim(self):
if isinstance(self.size, dict) and hasattr(self.size, "dim"): if isinstance(self.size, dict) and "dim" in self.size:
return self.size['dim'] return self.size['dim']
if isinstance(self.size, float): if isinstance(self.size, float):
@ -410,7 +410,7 @@ class Model:
@property @property
def heads(self): def heads(self):
if isinstance(self.size, dict) and hasattr(self.size, "heads"): if isinstance(self.size, dict) and "heads" in self.size:
return self.size['heads'] return self.size['heads']
if isinstance(self.size, float): if isinstance(self.size, float):
@ -424,7 +424,7 @@ class Model:
@property @property
def layers(self): def layers(self):
if isinstance(self.size, dict) and hasattr(self.size, "layers"): if isinstance(self.size, dict) and "layers" in self.size:
return self.size['layers'] return self.size['layers']
if self.size == "double": if self.size == "double":
@ -435,7 +435,7 @@ class Model:
@property @property
def ffn(self): def ffn(self):
if isinstance(self.size, dict) and hasattr(self.size, "ffn"): if isinstance(self.size, dict) and "ffn" in self.size:
return self.size['ffn'] return self.size['ffn']
return 4 return 4

View File

@ -173,7 +173,7 @@ def unload_model():
@torch.inference_mode() @torch.inference_mode()
def decode(codes: Tensor, device="cuda", dtype=None, metadata=None, window_duration=None): def decode(codes: Tensor, device="cuda", dtype=None, metadata=None, window_duration=None):
# dirty hack during model training # dirty hack during model training
codes = torch.where( codes >= (max_token = 1000 if cfg.audio_backend == "nemo" else 1024 ), 0, codes ) codes = torch.where( codes >= ( 1000 if cfg.audio_backend == "nemo" else 1024 ), 0, codes )
# upcast so it won't whine # upcast so it won't whine
if codes.dtype in [torch.int8, torch.int16, torch.uint8]: if codes.dtype in [torch.int8, torch.int16, torch.uint8]: