1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
use tch::Tensor;
/// Shorthand sum reductions over one or several dimensions of a [`Tensor`].
///
/// Methods whose name contains a `k` (`sum_kdim`, `sum_kdims`) keep every
/// reduced dimension in the output as a size-1 axis (`keepdim = true`). The
/// others drop the reduced dimensions from the output shape (`keepdim = false`).
pub trait TensorExt {
    /// Sums over the single dimension `dim` and drops it from the result shape.
    fn sum_dim(&self, dim: i64) -> Tensor;

    /// Sums over the single dimension `dim` and keeps it in the result as a
    /// size-1 axis.
    fn sum_kdim(&self, dim: i64) -> Tensor;

    /// Sums over every dimension listed in `dims` and drops them from the
    /// result shape.
    fn sum_dims<const D: usize>(&self, dims: [i64; D]) -> Tensor;

    /// Sums over every dimension listed in `dims` and keeps each of them in the
    /// result as a size-1 axis.
    fn sum_kdims<const D: usize>(&self, dims: [i64; D]) -> Tensor;
}
impl TensorExt for Tensor {
    // The single-dimension variants forward to the array-based variants with a
    // one-element array. Only two call sites into libtorch exist, one per
    // `keepdim` setting. Fully-qualified syntax guarantees the trait method is
    // chosen even if `Tensor` ever gains an inherent method of the same name.

    /// Sums over `dim` and removes that dimension from the result.
    fn sum_dim(&self, dim: i64) -> Tensor {
        <Self as TensorExt>::sum_dims(self, [dim])
    }

    /// Sums over `dim` and keeps it in the result as a size-1 dimension.
    fn sum_kdim(&self, dim: i64) -> Tensor {
        <Self as TensorExt>::sum_kdims(self, [dim])
    }

    /// Sums over every dimension in `dims` and removes them from the result.
    ///
    /// The input's own `kind()` is passed as the reduction dtype, so the output
    /// keeps the input's element type. NOTE(review): PyTorch's default would
    /// promote integral/bool inputs to int64. Confirm that preserving the dtype
    /// here is intended.
    fn sum_dims<const D: usize>(&self, dims: [i64; D]) -> Tensor {
        self.sum_dim_intlist(Some(dims.as_slice()), false, self.kind())
    }

    /// Sums over every dimension in `dims` and keeps each as a size-1
    /// dimension.
    ///
    /// The input's own `kind()` is passed as the reduction dtype, so the output
    /// keeps the input's element type. The same review note applies as for
    /// `sum_dims`.
    fn sum_kdims<const D: usize>(&self, dims: [i64; D]) -> Tensor {
        self.sum_dim_intlist(Some(dims.as_slice()), true, self.kind())
    }
}