some checks

This commit is contained in:
James Betker 2022-06-09 21:46:32 -06:00
parent 34005367fd
commit 07bdd865dc

View File

@ -21,6 +21,7 @@ def masked_channel_balancer(inp, proportion=1):
def channel_restriction(inp, low, high):
assert low > 0 and low < inp.shape[1] and high <= inp.shape[1]
m = torch.zeros_like(inp)
m[:,low:high] = 1
return inp * m