52 lines
1.3 KiB
Python
52 lines
1.3 KiB
Python
|
import torch
|
||
|
|
||
|
|
||
|
def sum(tensor, dim=None, keepdim=False):
    """Sum ``tensor`` over one or several dimensions.

    Args:
        tensor: The tensor to reduce.
        dim: ``None`` to reduce over every dimension, a single ``int``, or
            an iterable of ints. Negative indices count from the last
            dimension.
        keepdim: If true, every reduced dimension is kept with size 1.
            Ignored when ``dim`` is ``None``: that path always returns
            the result of ``torch.sum(tensor)``.

    Returns:
        The reduced tensor.
    """
    if dim is None:
        # Sum over all dimensions.
        return torch.sum(tensor)

    if isinstance(dim, int):
        dim = [dim]
    dim = sorted(dim)

    # Reduce one dimension at a time with keepdim=True so the indices of
    # the remaining dimensions stay valid throughout the loop.
    for d in dim:
        tensor = tensor.sum(dim=d, keepdim=True)

    if not keepdim:
        # The rank is unchanged so far, so negative indices can be mapped to
        # their non-negative equivalents. Squeezing from the highest index
        # down means removing one axis never shifts the ones still pending.
        # max(..., 1) keeps the modulo defined for 0-dim tensors.
        rank = max(tensor.dim(), 1)
        for d in sorted({d % rank for d in dim}, reverse=True):
            tensor = tensor.squeeze(d)
    return tensor
|
||
|
|
||
|
|
||
|
def mean(tensor, dim=None, keepdim=False):
    """Average ``tensor`` over one or several dimensions.

    Args:
        tensor: The tensor to reduce.
        dim: ``None`` to reduce over every dimension, a single ``int``, or
            an iterable of ints. Negative indices count from the last
            dimension.
        keepdim: If true, every reduced dimension is kept with size 1.
            Ignored when ``dim`` is ``None``: that path always returns
            the result of ``torch.mean(tensor)``.

    Returns:
        The reduced tensor.
    """
    if dim is None:
        # Mean over all dimensions.
        return torch.mean(tensor)

    if isinstance(dim, int):
        dim = [dim]
    dim = sorted(dim)

    # Reduce one dimension at a time with keepdim=True so the indices of
    # the remaining dimensions stay valid throughout the loop.
    for d in dim:
        tensor = tensor.mean(dim=d, keepdim=True)

    if not keepdim:
        # The rank is unchanged so far, so negative indices can be mapped to
        # their non-negative equivalents. Squeezing from the highest index
        # down means removing one axis never shifts the ones still pending.
        # max(..., 1) keeps the modulo defined for 0-dim tensors.
        rank = max(tensor.dim(), 1)
        for d in sorted({d % rank for d in dim}, reverse=True):
            tensor = tensor.squeeze(d)
    return tensor
|
||
|
|
||
|
|
||
|
def split_feature(tensor, type="split"):
    """Split ``tensor`` into two parts along dimension 1.

    Args:
        tensor: Tensor with at least two dimensions.
        type: Selects how dimension 1 is divided, with ``C`` being
            ``tensor.size(1)``:

            * ``"split"`` -- the first ``C // 2`` entries and the remaining
              entries.
            * ``"cross"`` -- the even-indexed entries and the odd-indexed
              entries.

    Returns:
        A tuple ``(first, second)`` of tensors sliced from ``tensor``.

    Raises:
        ValueError: If ``type`` is neither ``"split"`` nor ``"cross"``.
    """
    C = tensor.size(1)
    if type == "split":
        return tensor[:, :C // 2, ...], tensor[:, C // 2:, ...]
    if type == "cross":
        return tensor[:, 0::2, ...], tensor[:, 1::2, ...]
    # Fail loudly instead of implicitly returning None, which would only
    # surface later as an unpacking error at the call site.
    raise ValueError(
        'type must be "split" or "cross", got {!r}'.format(type)
    )
|
||
|
|
||
|
|
||
|
def cat_feature(tensor_a, tensor_b):
    """Concatenate ``tensor_a`` and ``tensor_b`` along dimension 1.

    Args:
        tensor_a: Tensor placed first in the result.
        tensor_b: Tensor placed after ``tensor_a``.

    Returns:
        The result of ``torch.cat`` over the two tensors with ``dim=1``.
    """
    pieces = (tensor_a, tensor_b)
    joined = torch.cat(pieces, dim=1)
    return joined
|
||
|
|
||
|
|
||
|
def pixels(tensor):
    """Return the product of the sizes of dimensions 2 and 3 as an ``int``.

    NOTE(review): presumably the spatial area (H * W) of a 4-D
    (N, C, H, W) tensor -- confirm against callers.
    """
    extent_2 = tensor.size(2)
    extent_3 = tensor.size(3)
    area = extent_2 * extent_3
    return int(area)
|