1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
use tch::Kind;

/// Conversions between the implementing type and a dynamically-dimensioned
/// `f32` array from the `ndarray` crate.
pub trait NDATensorExt {
    /// Returns the contents of `self` as an `ndarray::ArrayD<f32>`.
    ///
    /// `self` is only borrowed, so the returned array is an owned value.
    fn to_ndarray(&self) -> ndarray::ArrayD<f32>;

    /// Builds a value of the implementing type from an `ndarray::ArrayD<f32>`.
    ///
    /// The array is taken by value and consumed by the conversion.
    fn from_ndarray(array: ndarray::ArrayD<f32>) -> Self;
}

impl NDATensorExt for tch::Tensor {
    /// Copies this tensor into a dynamically-dimensioned `f32` array with the
    /// same shape.
    ///
    /// The tensor is cast to `Kind::Float` before its elements are copied out,
    /// so non-`f32` tensors are converted rather than rejected. Tensors of any
    /// rank are supported, including rank 0 (a single scalar element).
    ///
    /// # Panics
    ///
    /// Panics if the number of elements copied out of the tensor does not match
    /// the product of its dimensions; that would indicate a broken invariant in
    /// the tensor itself rather than a caller error.
    fn to_ndarray(&self) -> ndarray::ArrayD<f32> {
        // Tensor dimensions are assumed to be non-negative, so the cast to
        // `usize` is lossless.
        let shape: Vec<usize> = self.size().iter().map(|&dim| dim as usize).collect();
        let data = Vec::<f32>::from(&self.to_kind(Kind::Float));

        // A dynamic shape handles every rank uniformly, so there is no need to
        // dispatch on the number of dimensions (and no upper limit on it).
        ndarray::ArrayD::from_shape_vec(ndarray::IxDyn(&shape), data)
            .expect("tensor element count must match the product of its dimensions")
    }

    /// Builds a float tensor with the same shape and logical element order as
    /// `array`.
    ///
    /// Works for any memory layout of `array`: standard (row-major) arrays are
    /// read directly from their backing buffer, while column-major, sliced or
    /// axis-permuted arrays are first gathered in logical (row-major) order.
    fn from_ndarray(array: ndarray::ArrayD<f32>) -> Self {
        let shape: Vec<i64> = array.shape().iter().map(|&dim| dim as i64).collect();

        // `as_slice` only succeeds for contiguous, standard-layout data. For
        // every other layout, `iter` walks the elements in logical order, which
        // is the order `reshape` below expects. Reinterpreting the raw buffer
        // instead would scramble column-major arrays and fail on strided ones.
        let flat = match array.as_slice() {
            Some(elements) => tch::Tensor::of_slice(elements),
            None => {
                let gathered: Vec<f32> = array.iter().copied().collect();
                tch::Tensor::of_slice(&gathered)
            }
        };

        flat.reshape(shape.as_slice())
    }
}