Support fp16

parent 4cb958f7ff
commit be4619cddf
@@ -44,16 +44,27 @@ class Config(ConfigBase):
     min_phones: int = 10
     max_phones: int = 50
 
+    use_fp16: bool = True
+
     @cached_property
     def get_spkr(self):
         return eval(self.spkr_name_getter)
 
+    @property
+    def fp16_cfg(self):
+        return {
+            "enabled": self.use_fp16,
+        }
+
     @property
     def ds_cfg(self):
         return {
             "train_micro_batch_size_per_gpu": self.batch_size,
             "gradient_accumulation_steps": 1,
-            "optimizer": {"type": "Adam"},
+            "optimizer": {
+                "type": "Adam",
+                "lr": self.warmup_min_lr,
+            },
             "scheduler": {
                 "type": "WarmupDecayLR",
                 "params": {
@@ -65,6 +76,7 @@ class Config(ConfigBase):
                 },
             },
             "gradient_clipping": self.gradient_clipping,
+            "fp16": self.fp16_cfg,
         }
 
     @property
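The net effect of the Config changes above: ds_cfg now carries an "fp16" block driven by the new use_fp16 flag. Below is a minimal sketch of how such a config dict is typically handed to DeepSpeed; the call site is not part of this diff, the model and every numeric value are placeholders, and the script is assumed to run under the deepspeed launcher.

import deepspeed
import torch.nn as nn

model = nn.Linear(16, 16)          # placeholder model
ds_cfg = {                         # mirrors the shape of Config().ds_cfg; numbers are placeholders
    "train_micro_batch_size_per_gpu": 8,
    "gradient_accumulation_steps": 1,
    "optimizer": {"type": "Adam", "lr": 1e-5},
    "scheduler": {
        "type": "WarmupDecayLR",
        "params": {                # placeholder values; the real ones come from Config
            "warmup_min_lr": 1e-5,
            "warmup_max_lr": 2e-4,
            "warmup_num_steps": 1000,
            "total_num_steps": 100_000,
        },
    },
    "gradient_clipping": 1.0,
    "fp16": {"enabled": True},     # the new fp16_cfg block
}
engine, optimizer, _, scheduler = deepspeed.initialize(
    model=model, model_parameters=model.parameters(), config=ds_cfg
)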
@@ -1 +1 @@
-Subproject commit 0dbc980be3a4cb26ad5e3b16643a70a47623358a
+Subproject commit 78d70f0331f844e9fe9ed253175b9e0b28082153
@@ -31,6 +31,7 @@ def list_to_tensor(x_list: list[Tensor], pattern="t b c -> b t c"):
     m = _create_mask(l, x_list[0].device)
     m = m.t().unsqueeze(-1)  # (t b 1)
     m = rearrange(m, pattern)
+    m = m.to(x)
     return x, m
 
 
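The one-line addition above (m = m.to(x)) keeps the mask in step with fp16 activations: _create_mask builds the mask on x's device but in the default dtype, and if that is float32 while x is half precision, multiplying them silently promotes the product back to float32. A small sketch of that promotion, with arbitrary shapes:

import torch

x = torch.randn(2, 3, 4, dtype=torch.float16)   # padded batch, as under fp16 training
m = torch.ones(2, 3, 1)                         # mask in the default float32 dtype
print((x * m).dtype)                            # torch.float32 -- silent upcast
m = m.to(x)                                     # adopt x's dtype (and device)
print((x * m).dtype)                            # torch.float16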
@@ -256,7 +257,7 @@ class AdditiveMultiEmbedding(nn.Embedding):
         x = torch.cat(x_list)
         assert x.shape[1] == self.n_levels
         w = rearrange(self.weight, "(q k) d -> q k d", q=self.n_levels)
-        x = F.one_hot(x, num_classes=self.n_tokens).float()  # n q -> n q k
+        x = F.one_hot(x, num_classes=self.n_tokens).to(w)  # n q -> n q k
         x = einsum("q k d, n q k -> n d", w, x)
         x_list = x.split([*map(len, x_list)])
         return x_list
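The .float() to .to(w) swap above follows the same pattern: F.one_hot returns int64, and casting it to float32 would leave it mismatched with a half-precision embedding weight inside the einsum, whereas .to(w) follows whatever dtype and device the weight currently has. A sketch with made-up sizes:

import torch
import torch.nn.functional as F

n_levels, n_tokens, d = 4, 10, 8                             # made-up sizes
w = torch.randn(n_levels, n_tokens, d, dtype=torch.float16)  # (q k d) weight under fp16
x = torch.randint(0, n_tokens, (5, n_levels))                # (n q) token indices
oh_old = F.one_hot(x, num_classes=n_tokens).float()          # always float32
oh_new = F.one_hot(x, num_classes=n_tokens).to(w)            # follows w: float16 here
print(oh_old.dtype, oh_new.dtype)                            # torch.float32 torch.float16
# with matching dtypes, the "q k d, n q k -> n d" einsum runs entirely in fp16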
@@ -285,7 +286,7 @@ class SelectiveMultiEmbedding(nn.Embedding):
         w = repeat(self.weight[0], "d -> b d", b=len(x))
 
         w = rearrange(w, "b (k d) -> b k d", k=self.n_tokens_per_level)
-        x = F.one_hot(x, num_classes=self.n_tokens_per_level).float()  # b t k
+        x = F.one_hot(x, num_classes=self.n_tokens_per_level).to(w)  # b t k
         x = einsum("b k d, b t k -> b t d", w, x)
 
         x_list = [xi[:li] for xi, li in zip(x, map(len, x_list))]
@@ -1,5 +1,3 @@
-import random
-
 import torch
 from einops import rearrange
 from torch import Tensor