forked from mrq/DL-Art-School
support projecting to vectors
This commit is contained in:
parent
86694aef4e
commit
6b43915eb8
|
@ -653,11 +653,14 @@ class ContrastiveTrainingWrapper(nn.Module):
|
||||||
}
|
}
|
||||||
return groups
|
return groups
|
||||||
|
|
||||||
def get_codes(self, mel):
|
def get_codes(self, mel, project=False):
|
||||||
proj = self.m2v.input_blocks(mel).permute(0,2,1)
|
proj = self.m2v.input_blocks(mel).permute(0,2,1)
|
||||||
_, proj = self.m2v.projector(proj)
|
_, proj = self.m2v.projector(proj)
|
||||||
codes = self.quantizer.get_codes(proj)
|
if project:
|
||||||
return codes
|
proj, _ = self.quantizer(proj)
|
||||||
|
return proj
|
||||||
|
else:
|
||||||
|
return self.quantizer.get_codes(proj)
|
||||||
|
|
||||||
def reconstruct(self, mel):
|
def reconstruct(self, mel):
|
||||||
proj = self.m2v.input_blocks(mel).permute(0,2,1)
|
proj = self.m2v.input_blocks(mel).permute(0,2,1)
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
import os
|
import os
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
basepath = 'Y:/clips/podcasts-0'
|
basepath = 'Y:\\bigasr_dataset\\hifi_tts'
|
||||||
|
|
||||||
english_file = os.path.join(basepath, 'transcribed-oco-realtext.tsv')
|
english_file = os.path.join(basepath, 'transcribed-oco-realtext.tsv')
|
||||||
if not os.path.exists(english_file):
|
if not os.path.exists(english_file):
|
||||||
|
|
|
@ -330,13 +330,14 @@ class Mel2vecCodesInjector(Injector):
|
||||||
self.m2v = get_music_codegen()
|
self.m2v = get_music_codegen()
|
||||||
del self.m2v.m2v.encoder # This is a big memory sink which will not get used.
|
del self.m2v.m2v.encoder # This is a big memory sink which will not get used.
|
||||||
self.needs_move = True
|
self.needs_move = True
|
||||||
|
self.inj_vector = opt_get(opt, ['vector'], False)
|
||||||
|
|
||||||
def forward(self, state):
|
def forward(self, state):
|
||||||
mels = state[self.input]
|
mels = state[self.input]
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
if self.needs_move:
|
if self.needs_move:
|
||||||
self.m2v = self.m2v.to(mels.device)
|
self.m2v = self.m2v.to(mels.device)
|
||||||
codes = self.m2v.get_codes(mels)
|
codes = self.m2v.get_codes(mels, project=self.inj_vector)
|
||||||
return {self.output: codes}
|
return {self.output: codes}
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user