added option to split between text loss and audio loss (to-do: document this better), because it may or may not be a problem with LLaMA-backed models because my loss hovers around 3.9 / 56% accuracy despite sounding decent at the moment
This commit is contained in:
parent
74e531d391
commit
458b95d196
|
@ -6,41 +6,38 @@ models:
|
|||
tasks: 8
|
||||
langs: 2
|
||||
tones: 1
|
||||
arch_type: "retnet"
|
||||
arch_type: llama
|
||||
training: True
|
||||
version: 3
|
||||
version: 4
|
||||
attention: flash_attention_2
|
||||
dropout: 0.1
|
||||
|
||||
loss_factors:
|
||||
text: 0.1
|
||||
resp: 1.0
|
||||
|
||||
hyperparameters:
|
||||
batch_size: 4
|
||||
autotune: False
|
||||
autotune_params:
|
||||
start_profile_step: 1
|
||||
end_profile_step: 50
|
||||
num_tuning_micro_batch_sizes: 8
|
||||
|
||||
batch_size: 16
|
||||
gradient_accumulation_steps: 4
|
||||
gradient_clipping: 10
|
||||
|
||||
optimizer: Adagrad
|
||||
gradient_clipping: 1.0
|
||||
warmup_steps: 100
|
||||
|
||||
optimizer: Prodigy
|
||||
learning_rate: 1.0
|
||||
torch_optimizer: True
|
||||
learning_rate: 1.0e-2
|
||||
|
||||
scheduler_type: ""
|
||||
#scheduler_type: OneCycle
|
||||
#scheduler_params:
|
||||
# cycle_first_step_size: 10_000
|
||||
# cycle_first_stair_count: 10_000
|
||||
|
||||
# cycle_second_step_size: 15_000
|
||||
# cycle_second_stair_count: 15_000
|
||||
|
||||
# decay_step_size: 5_000
|
||||
|
||||
# cycle_min_lr: 2.5e-4 # 1.0e-5
|
||||
# cycle_max_lr: 2.5e-4 # 1.0e-4
|
||||
# decay_lr_rate: 0.0
|
||||
|
||||
# cycle_min_mom: 0.90
|
||||
# cycle_max_mom: 0.99
|
||||
# decay_mom_rate: 0.0
|
||||
scheduler: "" # ScheduleFree
|
||||
torch_scheduler: True
|
||||
|
||||
evaluation:
|
||||
batch_size: 8
|
||||
frequency: 10000
|
||||
frequency: 5000
|
||||
size: 8
|
||||
|
||||
steps: 500
|
||||
|
@ -49,8 +46,9 @@ evaluation:
|
|||
load_disabled_engines: True
|
||||
|
||||
trainer:
|
||||
no_logger: True
|
||||
|
||||
#no_logger: True
|
||||
ddp: False
|
||||
#check_for_oom: False
|
||||
iterations: 1_000_000
|
||||
|
||||
save_tag: step
|
||||
|
@ -72,7 +70,7 @@ trainer:
|
|||
|
||||
gc_mode: None # "global_step"
|
||||
|
||||
weight_dtype: float32
|
||||
weight_dtype: float32 # float16 or bfloat16
|
||||
amp: False
|
||||
|
||||
backend: deepspeed
|
||||
|
@ -81,34 +79,34 @@ trainer:
|
|||
zero_optimization_level: 0
|
||||
use_compression_training: False
|
||||
|
||||
amp: False
|
||||
|
||||
activation_checkpointing: True
|
||||
|
||||
load_webui: True
|
||||
load_webui: False
|
||||
|
||||
inference:
|
||||
backend: deepspeed
|
||||
audio_backend: "dac"
|
||||
normalize: False
|
||||
|
||||
weight_dtype: float32
|
||||
weight_dtype: float32 # float16 or bfloat16
|
||||
amp: False
|
||||
|
||||
bitsandbytes:
|
||||
enabled: False
|
||||
|
||||
optimizations:
|
||||
injects: False
|
||||
replace: False
|
||||
replace: True
|
||||
|
||||
linear: False
|
||||
embedding: False
|
||||
|
||||
bitnet: False
|
||||
optimizers: True
|
||||
|
||||
fp8:
|
||||
enabled: False
|
||||
backend: "te"
|
||||
|
||||
experimental: True
|
||||
bitsandbytes: False
|
||||
dadaptation: False
|
||||
bitnet: False
|
||||
fp8: False
|
||||
|
||||
experimental: True # practically required now it seems
|
||||
|
||||
dataset:
|
||||
speaker_name_getter: "lambda p: f'{p.parts[-3]}_{p.parts[-2]}'"
|
||||
|
@ -121,23 +119,19 @@ dataset:
|
|||
hdf5_flag: r
|
||||
validate: True
|
||||
|
||||
workers: 8
|
||||
workers: 2
|
||||
cache: True
|
||||
|
||||
#phones_range: [4, 512]
|
||||
#duration_range: [1.0, 32.0]
|
||||
|
||||
phones_range: [0, 512]
|
||||
duration_range: [0.0, 64.0]
|
||||
duration_range: [3.0, 5.0]
|
||||
|
||||
random_utterance: 1.0
|
||||
max_prompts: 3
|
||||
prompt_duration: 6.0
|
||||
max_prompts: 1
|
||||
prompt_duration: 3.0
|
||||
|
||||
max_resps: 1
|
||||
p_resp_append: 0.25
|
||||
|
||||
sample_type: speaker
|
||||
sample_type: path # speaker
|
||||
|
||||
tasks_list: [ "tts" ] # , [ "tts", "tts-c", "ns", "sr", "tse", "cse", "nse", "tts"]
|
||||
|
||||
|
|
|
@ -213,9 +213,13 @@ class Model:
|
|||
attention: str = "auto"
|
||||
audio_embedding_sums: bool = True
|
||||
dropout: float = 0.1 # adjustable dropout value
|
||||
loss_factors: dict = field(default_factory=lambda: {})
|
||||
|
||||
def get(self, name=None):
|
||||
return [ self ] if not name or self.name == name else []
|
||||
|
||||
def loss_factor(self, k):
|
||||
return self.loss_factors[k] if k in self.loss_factors else 1.0
|
||||
|
||||
@property
|
||||
def max_levels(self):
|
||||
|
@ -491,6 +495,11 @@ class DeepSpeed:
|
|||
|
||||
return ds_cfg
|
||||
|
||||
@dataclass()
|
||||
class LossFactor:
|
||||
text: float = 1.0
|
||||
resp: float = 1.0
|
||||
|
||||
@dataclass()
|
||||
class Trainer:
|
||||
iterations: int = 100_000
|
||||
|
|
|
@ -348,12 +348,15 @@ def example_usage():
|
|||
|
||||
text_list = [
|
||||
tokenize("ˈaɪ wɪl nˌɑːt ˈæsk ɐ sˈɛkənd tˈaɪm").to(device),
|
||||
#tokenize("ˈaɪ wɪl nˌɑːt ˈæsk").to(device),
|
||||
]
|
||||
proms_list = [
|
||||
qnt[:cfg.dataset.frames_per_second, :].to(device),
|
||||
#qnt[:cfg.dataset.frames_per_second, :].to(device),
|
||||
]
|
||||
resps_list = [
|
||||
qnt.to(device),
|
||||
qnt[:, :].to(device),
|
||||
#qnt[:cfg.dataset.frames_per_second, :].to(device),
|
||||
]
|
||||
|
||||
text_list = text_list[:1]
|
||||
|
|
|
@ -426,6 +426,11 @@ class Base(nn.Module):
|
|||
def ignore_index(self):
|
||||
return -100
|
||||
|
||||
def loss_factor(self, k):
|
||||
if self.config is None:
|
||||
return 1.0
|
||||
return self.config.loss_factors[k] if k in self.config.loss_factors else 1.0
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
n_tokens: int = 1024,
|
||||
|
@ -880,6 +885,31 @@ class Base(nn.Module):
|
|||
|
||||
return x_list
|
||||
|
||||
def training_targets_split(
|
||||
self,
|
||||
inputs: list,
|
||||
):
|
||||
text_lists = []
|
||||
resp_lists = []
|
||||
|
||||
for bi in range(len(inputs)):
|
||||
text_batch = []
|
||||
resp_batch = []
|
||||
|
||||
for i in range(len(inputs[bi])):
|
||||
name, input = inputs[bi][i]
|
||||
device = input.device
|
||||
|
||||
if name in ["text", "lang" ]:
|
||||
text_batch.append( input )
|
||||
elif name == "targ":
|
||||
resp_batch.append( input )
|
||||
|
||||
text_lists.append( _join( text_batch, torch.tensor(self.ignore_index, device=device) ) )
|
||||
resp_lists.append( _join( resp_batch, torch.tensor(self.ignore_index, device=device) ) )
|
||||
|
||||
return text_lists, resp_lists
|
||||
|
||||
def forward(
|
||||
self,
|
||||
inputs: list,
|
||||
|
@ -929,30 +959,80 @@ class Base(nn.Module):
|
|||
|
||||
# compute loss if the target is given
|
||||
if training:
|
||||
target_list = self.training_targets( inputs )
|
||||
|
||||
# modify only for the AR so it can properly behave like a transformer
|
||||
for i in range(len(target_list)):
|
||||
if quant_levels is not None and quant_levels[i] > 0:
|
||||
continue
|
||||
if not self.config.loss_factors:
|
||||
target_list = self.training_targets( inputs )
|
||||
|
||||
# modify only for the AR so it can properly behave like a transformer
|
||||
for i in range(len(target_list)):
|
||||
if quant_levels is not None and quant_levels[i] > 0:
|
||||
continue
|
||||
|
||||
logits[i] = logits[i][..., :-1, :] # shift the target so that token n...
|
||||
target_list[i] = target_list[i][..., 1:] # predicts token n + 1
|
||||
logits[i] = logits[i][..., :-1, :] # shift the target so that token n...
|
||||
target_list[i] = target_list[i][..., 1:] # predicts token n + 1
|
||||
|
||||
target = torch.cat( target_list )
|
||||
inputs = torch.cat( logits )
|
||||
target = torch.cat( target_list )
|
||||
inputs = torch.cat( logits )
|
||||
|
||||
self.loss = dict(
|
||||
# "nll" was in the original implementation and should actually just be called something else
|
||||
nll = F.cross_entropy( inputs, target, ignore_index=self.ignore_index )
|
||||
)
|
||||
self.stats = dict(
|
||||
acc = self.accuracy_metric( inputs, target ),
|
||||
# precision = self.precision_metric( inputs, target ),
|
||||
)
|
||||
self.loss = dict(
|
||||
# "nll" was in the original implementation and should actually just be called something else
|
||||
nll = F.cross_entropy( inputs, target, ignore_index=self.ignore_index )
|
||||
)
|
||||
self.stats = dict(
|
||||
acc = self.accuracy_metric( inputs, target ),
|
||||
# precision = self.precision_metric( inputs, target ),
|
||||
)
|
||||
# split our loss
|
||||
else:
|
||||
target_text_list, target_resp_list = self.training_targets_split( inputs )
|
||||
|
||||
# grab respective slice of logits
|
||||
logits_text = [ hi[:li.shape[0]] for hi, li in zip(logits, target_text_list) ]
|
||||
logits_resp = [ hi[-li.shape[0]:] for hi, li in zip(logits, target_resp_list) ]
|
||||
|
||||
# modify only for the AR so it can properly behave like a transformer
|
||||
for i in range(len(target_text_list)):
|
||||
if quant_levels is not None and quant_levels[i] > 0:
|
||||
continue
|
||||
|
||||
# shift the target so that token n...
|
||||
logits_text[i] = logits_text[i][..., :-1, :]
|
||||
logits_resp[i] = logits_resp[i][..., :-1, :]
|
||||
|
||||
# predicts token n + 1
|
||||
target_text_list[i] = target_text_list[i][..., 1:]
|
||||
target_resp_list[i] = target_resp_list[i][..., 1:]
|
||||
|
||||
target_text = torch.cat( target_text_list ).long()
|
||||
target_resp = torch.cat( target_resp_list ).long()
|
||||
|
||||
inputs_text = torch.cat( logits_text )
|
||||
inputs_resp = torch.cat( logits_resp )
|
||||
|
||||
self.loss = dict(
|
||||
text = F.cross_entropy( inputs_text, target_text, ignore_index=self.ignore_index ),
|
||||
resp = F.cross_entropy( inputs_resp, target_resp, ignore_index=self.ignore_index ),
|
||||
)
|
||||
|
||||
for k in self.loss:
|
||||
self.loss[k] *= self.loss_factor(k)
|
||||
|
||||
# to-do: compute loss per individual batch to scale per RVQ level
|
||||
"""
|
||||
rvq_loss_factor = self.loss_factor("quant")
|
||||
if isinstance( rvq_loss_factor, list ):
|
||||
...
|
||||
"""
|
||||
|
||||
self.stats = dict(
|
||||
acc = dict(
|
||||
text = self.accuracy_metric( inputs_text, target_text ),
|
||||
resp = self.accuracy_metric( inputs_resp, target_resp ),
|
||||
),
|
||||
)
|
||||
|
||||
# include any additional losses (for example: MoE router)
|
||||
if aux_loss is not None:
|
||||
self.loss["nll"] += aux_loss
|
||||
self.loss["aux_loss"] = aux_loss
|
||||
|
||||
return (logits, state) if state is not None else logits
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user