import torch


def _normalize_dims(tensor, dim):
    """Return the sorted, de-duplicated, non-negative form of ``dim``.

    Args:
        tensor: tensor whose rank is used to resolve negative indices.
        dim: a single int or an iterable of ints; negative values count
            from the last dimension.

    Returns:
        A sorted list of unique indices in ``[0, rank)``.

    Raises:
        IndexError: if any entry is outside ``[-rank, rank - 1]``.
    """
    if isinstance(dim, int):
        dim = [dim]
    # A 0-dim tensor still accepts dim 0 / -1 in torch, so treat its rank as 1.
    rank = max(tensor.dim(), 1)
    normalized = set()
    for d in dim:
        if not -rank <= d < rank:
            raise IndexError(
                "Dimension out of range (expected to be in range of "
                "[{}, {}], but got {})".format(-rank, rank - 1, d)
            )
        normalized.add(d % rank)
    return sorted(normalized)


def sum(tensor, dim=None, keepdim=False):
    """Sum ``tensor`` over one or several dimensions.

    NOTE: the name shadows the builtin ``sum`` inside this module; it is kept
    because it is the public interface.

    Args:
        tensor: input tensor.
        dim: ``None`` to sum every element, an int, or an iterable of ints
            (negative indices allowed).
        keepdim: if true, reduced dimensions are kept with size 1.

    Returns:
        The reduced tensor. With ``dim=None`` this is ``torch.sum(tensor)``;
        with an empty ``dim`` the input tensor is returned as is.

    Raises:
        IndexError: if an entry of ``dim`` is out of range for ``tensor``.
    """
    if dim is None:
        # sum up all dim
        return torch.sum(tensor)
    keep = bool(keepdim)
    for offset, d in enumerate(_normalize_dims(tensor, dim)):
        # Dims are reduced in ascending order; when reduced dims are dropped,
        # every later index shifts left by the number already removed.
        tensor = tensor.sum(dim=d if keep else d - offset, keepdim=keep)
    return tensor


def mean(tensor, dim=None, keepdim=False):
    """Average ``tensor`` over one or several dimensions.

    Args:
        tensor: input tensor.
        dim: ``None`` to average every element, an int, or an iterable of
            ints (negative indices allowed).
        keepdim: if true, reduced dimensions are kept with size 1.

    Returns:
        The reduced tensor. With ``dim=None`` this is ``torch.mean(tensor)``;
        with an empty ``dim`` the input tensor is returned as is.

    Raises:
        IndexError: if an entry of ``dim`` is out of range for ``tensor``.
    """
    if dim is None:
        # mean all dim
        return torch.mean(tensor)
    keep = bool(keepdim)
    for offset, d in enumerate(_normalize_dims(tensor, dim)):
        # Same index bookkeeping as in ``sum``: ascending order, shifting by
        # the number of dims already dropped when keepdim is off.
        tensor = tensor.mean(dim=d if keep else d - offset, keepdim=keep)
    return tensor


def split_feature(tensor, type="split"):
    """Split ``tensor`` into two parts along dimension 1.

    Args:
        tensor: input tensor with at least two dimensions.
        type: ``"split"`` returns the first ``C // 2`` entries of dim 1 and
            the remaining ones; ``"cross"`` returns the even-indexed and the
            odd-indexed entries of dim 1. (The parameter name shadows the
            builtin ``type``; it is kept for caller compatibility.)

    Returns:
        A tuple ``(first, second)`` of slices of ``tensor``.

    Raises:
        ValueError: if ``type`` is neither ``"split"`` nor ``"cross"``.
    """
    C = tensor.size(1)
    if type == "split":
        return tensor[:, :C // 2, ...], tensor[:, C // 2:, ...]
    if type == "cross":
        return tensor[:, 0::2, ...], tensor[:, 1::2, ...]
    # Previously fell through and returned None, which only surfaced later as
    # an unpacking error at the call site.
    raise ValueError(
        "type must be 'split' or 'cross', got {!r}".format(type)
    )


def cat_feature(tensor_a, tensor_b):
    """Concatenate ``tensor_a`` and ``tensor_b`` along dimension 1."""
    return torch.cat((tensor_a, tensor_b), dim=1)


def pixels(tensor):
    """Return ``tensor.size(2) * tensor.size(3)`` as an int.

    Presumably the spatial size (H * W) of an (N, C, H, W) tensor; the layout
    is not established here, so confirm against callers.
    """
    return int(tensor.size(2) * tensor.size(3))