1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
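/*!
Helpers for tensor tests: approximate-equality assertions and loading of expected-value
assets (`.pt` tensors, `png`/`jpeg`/`jpg` images, and `.npy` arrays with the `ndarray` feature).
 */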

use std::{path::{ PathBuf}, str::FromStr};

use tch::Tensor;

#[cfg(feature = "ndarray")]
use crate::ndarray::NDATensorExt;

/// Asserts that `a` and `b` are equal within a default tolerance of `1e-5`.
///
/// Convenience wrapper around [`assert_eq_tensor_d`].
///
/// # Panics
///
/// Panics under the same conditions as [`assert_eq_tensor_d`] called with `max_delta = 1e-5`.
pub fn assert_eq_tensor(a: &Tensor, b: &Tensor) {
    const DEFAULT_MAX_DELTA: f64 = 1e-5;
    assert_eq_tensor_d(a, b, DEFAULT_MAX_DELTA)
}

/// Asserts that `a` and `b` have the same shape and that every pair of corresponding
/// elements differs by at most `max_delta`.
///
/// Both tensors are converted to `f64` before comparing. This means integer tensors (e.g. `u8`
/// images) cannot wrap around on subtraction, and `bool` tensors can be compared too.
/// Empty tensors of equal shape always compare equal. A NaN in either tensor makes the
/// assertion fail, because `NaN <= max_delta` is false.
///
/// # Panics
///
/// Panics if the shapes differ, or if the largest absolute element-wise difference exceeds
/// `max_delta`.
pub fn assert_eq_tensor_d(a: &Tensor, b: &Tensor, max_delta: f64) {
    assert_eq!(a.size(), b.size(), "Tensors must have the same shape");

    // Nothing to compare. Also, `max()` panics on a tensor with zero elements.
    if a.numel() == 0 {
        return;
    }

    // Use the maximum *absolute* difference. Summing signed differences would let positive
    // and negative errors cancel out and hide real mismatches.
    let diff = a.to_kind(tch::Kind::Double) - b.to_kind(tch::Kind::Double);
    let delta = f64::from(diff.abs().max());

    // `<=` so that `max_delta == 0.0` means "exactly equal" instead of "always fails".
    assert!(
        delta <= max_delta,
        "Tensors must be equal (delta: {}, max_delta: {})",
        delta,
        max_delta
    );
}

/// Loads a test asset from `path` into a [`Tensor`], choosing the loader by file extension.
///
/// Supported extensions (matched case-insensitively):
/// - `pt`: loaded with [`Tensor::load`];
/// - `png`, `jpeg`, `jpg`: loaded with `tch::vision::image::load`;
/// - `npy`: read with `ndarray_npy::read_npy` and converted with `Tensor::from_ndarray`.
///   This format is only available with the `ndarray` feature.
///
/// This is a "dirty" helper meant for tests: every failure is reported by panicking
/// rather than by returning an error.
///
/// # Panics
///
/// Panics if:
/// - the file does not exist;
/// - its extension is missing or unsupported;
/// - the underlying loader fails (the message includes the path and the loader's error);
/// - an `npy` file is requested without the `ndarray` feature.
pub fn dirty_load(path: &str) -> Tensor {
    // `PathBuf::from_str` is infallible (its error type is `Infallible`), so this cannot fire.
    let path = PathBuf::from_str(path).unwrap();
    assert!(path.exists(), "Asset not found: {:?}", path);

    // Normalise the extension so that e.g. `IMG.PNG` is handled like `img.png`.
    // Non-UTF-8 extensions become `None` and fall through to the "unsupported" arm.
    let ext = path
        .extension()
        .and_then(|ext| ext.to_str())
        .map(str::to_ascii_lowercase);

    match ext.as_deref() {
        Some("pt") => Tensor::load(&path)
            .unwrap_or_else(|err| panic!("Failed to load asset {:?}: {}", path, err)),
        Some("png") | Some("jpeg") | Some("jpg") => tch::vision::image::load(&path)
            .unwrap_or_else(|err| panic!("Failed to load asset {:?}: {}", path, err)),
        Some("npy") => {
            #[cfg(feature = "ndarray")]
            {
                let array = ndarray_npy::read_npy(&path)
                    .unwrap_or_else(|err| panic!("Failed to load asset {:?}: {}", path, err));
                Tensor::from_ndarray(array)
            }
            #[cfg(not(feature = "ndarray"))]
            {
                panic!("loading npy files requires ndarray feature")
            }
        }
        _ => panic!("Asset file unsupported: {:?}", path),
    }
}

/// Loads the asset stored at `asset` with [`dirty_load`] and asserts that `tensor` matches it,
/// using the default tolerance of [`assert_eq_tensor`].
///
/// # Panics
///
/// Panics if the asset cannot be loaded or if the tensors are not equal.
pub fn assert_tensor_asset(tensor: &Tensor, asset: &str) {
    let expected = dirty_load(asset);
    assert_eq_tensor(tensor, &expected)
}

/// Loads the asset stored at `asset` with [`dirty_load`] and asserts that `tensor` matches it
/// within `max_delta`, as checked by [`assert_eq_tensor_d`].
///
/// # Panics
///
/// Panics if the asset cannot be loaded or if the tensors are not equal within `max_delta`.
pub fn assert_tensor_asset_d(tensor: &Tensor, asset: &str, max_delta: f64) {
    assert_eq_tensor_d(tensor, &dirty_load(asset), max_delta);
}