could have sworn this worked before, might have broke it when i decoupled from omegaconf
This commit is contained in:
parent
17094b8002
commit
1d3290b023
|
@ -351,7 +351,7 @@ class Model:
|
|||
name = [ self.name ]
|
||||
|
||||
if isinstance(self.size, dict):
|
||||
if hasattr(self.size, "label") and self.size['label']:
|
||||
if self.size.get('label'):
|
||||
name.append(f"{self.size['label']}")
|
||||
elif isinstance(self.size, str) and self.size not in ["full","extended"]:
|
||||
name.append(self.size)
|
||||
|
@ -374,7 +374,7 @@ class Model:
|
|||
|
||||
@property
|
||||
def audio_tokens(self):
|
||||
if isinstance(self.size, dict) and hasattr(self.size, "audio_tokens"):
|
||||
if isinstance(self.size, dict) and "audio_tokens" in self.size:
|
||||
return self.size['audio_tokens']
|
||||
|
||||
if cfg.audio_backend == "nemo":
|
||||
|
@ -384,19 +384,19 @@ class Model:
|
|||
|
||||
@property
|
||||
def text_tokens(self):
|
||||
if isinstance(self.size, dict) and hasattr(self.size, "text_tokens"):
|
||||
if isinstance(self.size, dict) and "text_tokens" in self.size:
|
||||
return self.size['text_tokens']
|
||||
return 8575
|
||||
|
||||
@property
|
||||
def phoneme_tokens(self):
|
||||
if isinstance(self.size, dict) and hasattr(self.size, "phoneme_tokens"):
|
||||
if isinstance(self.size, dict) and "phoneme_tokens" in self.size:
|
||||
return self.size['phoneme_tokens']
|
||||
return 256
|
||||
|
||||
@property
|
||||
def dim(self):
|
||||
if isinstance(self.size, dict) and hasattr(self.size, "dim"):
|
||||
if isinstance(self.size, dict) and "dim" in self.size:
|
||||
return self.size['dim']
|
||||
|
||||
if isinstance(self.size, float):
|
||||
|
@ -410,7 +410,7 @@ class Model:
|
|||
|
||||
@property
|
||||
def heads(self):
|
||||
if isinstance(self.size, dict) and hasattr(self.size, "heads"):
|
||||
if isinstance(self.size, dict) and "heads" in self.size:
|
||||
return self.size['heads']
|
||||
|
||||
if isinstance(self.size, float):
|
||||
|
@ -424,7 +424,7 @@ class Model:
|
|||
|
||||
@property
|
||||
def layers(self):
|
||||
if isinstance(self.size, dict) and hasattr(self.size, "layers"):
|
||||
if isinstance(self.size, dict) and "layers" in self.size:
|
||||
return self.size['layers']
|
||||
|
||||
if self.size == "double":
|
||||
|
@ -435,7 +435,7 @@ class Model:
|
|||
|
||||
@property
|
||||
def ffn(self):
|
||||
if isinstance(self.size, dict) and hasattr(self.size, "ffn"):
|
||||
if isinstance(self.size, dict) and "ffn" in self.size:
|
||||
return self.size['ffn']
|
||||
|
||||
return 4
|
||||
|
|
|
@ -173,7 +173,7 @@ def unload_model():
|
|||
@torch.inference_mode()
|
||||
def decode(codes: Tensor, device="cuda", dtype=None, metadata=None, window_duration=None):
|
||||
# dirty hack during model training
|
||||
codes = torch.where( codes >= (max_token = 1000 if cfg.audio_backend == "nemo" else 1024 ), 0, codes )
|
||||
codes = torch.where( codes >= ( 1000 if cfg.audio_backend == "nemo" else 1024 ), 0, codes )
|
||||
|
||||
# upcast so it won't whine
|
||||
if codes.dtype in [torch.int8, torch.int16, torch.uint8]:
|
||||
|
|
Loading…
Reference in New Issue
Block a user