1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
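/*!
Gabor filters: generation of a single filter (`gabor_filter`) and application of a
bank of filters at several orientations and wavelengths to single-channel image
tensors (`apply_gabor_filter`).
 */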

use tch::{Tensor, Kind};
use crate::tensor_ext::TensorExt;
use crate::tensor_init;
use crate::ops_2d;

/**
Generates a single Gabor filter.

The filter is sampled on a `size x size` coordinate grid produced by
`tensor_init::position_tensor_2d`. The grid is rotated by `theta`, coordinate
channel 0 is scaled by `gamma`, and the filter is
`exp(-(x'^2 + y'^2) / (2 * sigma^2)) * cos(2π * y' / lambda + psi)`,
where `x'` / `y'` are channels 0 / 1 of the rotated and scaled coordinates.
# Arguments
- size: usize - The width and height of the filter, in samples
- theta: f64 - The orientation of the filter, passed unchanged to `ops_2d::rotate_2d`.
  Callers in this module pass radians (NOTE(review): confirm the unit `rotate_2d` expects)
- sigma: f64 - The standard deviation of the Gaussian envelope, in grid-coordinate units
- lambda: f64 - The wavelength of the sinusoidal factor, in grid-coordinate units
- psi: f64 - The phase offset of the sinusoidal factor, in radians
- gamma: f64 - The spatial aspect ratio (scale applied to coordinate channel 0)
- device: Device - The device to store the tensor on
# Returns
- [size, size] float tensor - The Gabor filter. Its magnitude is at most 1, since the
  envelope lies in (0, 1] and the carrier in [-1, 1].
 */
pub fn gabor_filter(
    size: usize,
    theta: f64,
    sigma: f64,
    lambda: f64,
    psi: f64,
    gamma: f64,
    device: tch::Device,
) -> Tensor {
    let xy = tensor_init::position_tensor_2d((size, size), 1, (Kind::Float, device));
    // Rotate the sampling grid so the filter is oriented along `theta`.
    let xy_rot = ops_2d::rotate_2d(&xy, theta);
    // Apply the aspect ratio to coordinate channel 0 only, then drop singleton dims so
    // the coordinate channel sits at dim -3.
    // NOTE(review): `squeeze()` removes *every* size-1 dim, so `size == 1` would also
    // collapse the spatial dims and break the `-3` indexing below.
    let xy_rot_scaled = ops_2d::scale_2d(&xy_rot, &[gamma, 1.0]).squeeze();
    // Gaussian envelope: exp(-(x'^2 + y'^2) / (2 sigma^2)).
    let gauss_env = (xy_rot_scaled.square().sum_dim(-3) / (2.0 * sigma.powi(2)))
        .neg()
        .exp();
    // Sinusoidal carrier along coordinate channel 1: cos(2π y' / lambda + psi).
    let phase = (2.0 * std::f64::consts::PI / lambda) * xy_rot_scaled.select(-3, 1) + psi;
    let carrier = phase.cos();
    gauss_env * carrier
}

/**
Applies a bank of Gabor filters to an input tensor.
> Note:
> The filters are sampled over a -1.0 to 1.0 coordinate range regardless of the input
> size or `filter_size`, so the filters look the same at any size; `filter_size` only
> changes how finely they are sampled.
# Arguments
- input: [N, 1, H, W] - The single-channel input tensor. The filters are cast to its kind.
- angle_count: usize - The number of orientations, evenly spaced over [0, π) radians
  (`i * π / angle_count` for `i` in `0..angle_count`)
- filter_size: usize - The width and height of each filter, in samples
- frequencies: &[f64] - One value per filter scale. Each value is passed to
  `gabor_filter` as its `lambda` (wavelength) argument, so despite the name these act
  as wavelengths in grid-coordinate units.
- sigma: f64 - The standard deviation of the Gaussian envelope
# Returns
- Tensor [N, angle_count * frequencies.len(), H, W] - The output tensor (stride 1, "same" padding).
    The C dimension corresponds to the filters at each angle and frequency.
    The filters are in the following order:
        - angle 0, frequency 0
        - angle 0, frequency 1
        - ...
        - angle 1, frequency 0
        - ...
# Panics
- If `input` is not 4-D with exactly one channel.
- If `angle_count` is 0 or `frequencies` is empty.
 */
pub fn apply_gabor_filter(
    input: &Tensor,
    angle_count: usize,
    filter_size: usize,
    frequencies: &[f64],
    sigma: f64,
) -> Tensor {
    let input_size = input.size();
    assert_eq!(input_size.len(), 4, "expected input of shape [N, 1, H, W], got {:?}", input_size);
    assert_eq!(input_size[1], 1, "expected a single input channel, got shape {:?}", input_size);
    // An empty filter bank would otherwise fail inside `Tensor::stack` with an opaque error.
    assert!(
        angle_count > 0 && !frequencies.is_empty(),
        "apply_gabor_filter needs at least one angle and one frequency"
    );
    let device = input.device();
    let filters: Vec<Tensor> = (0..angle_count)
        .flat_map(|i| {
            // Orientations are evenly spaced over [0, π).
            let theta = i as f64 * std::f64::consts::PI / angle_count as f64;
            frequencies.iter().map(move |&lambda| {
                gabor_filter(filter_size, theta, sigma, lambda, 0.0, 1.0, device)
            })
        })
        .collect();
    // [K, S, S] -> [K, 1, S, S] conv weight (K output channels, 1 input channel), cast to
    // the input's kind so that non-Float inputs can be convolved as well.
    let filters = Tensor::stack(&filters, 0).unsqueeze(1).to_kind(input.kind());
    input.conv2d_padding(&filters, None::<&Tensor>, &[1, 1], "same", &[1, 1], 1)
}

#[cfg(test)]
mod test {
    use super::*;
    use tch::index::*;

    #[test]
    fn test_gabor() {
        let gabor = gabor_filter(
            60,
            std::f64::consts::PI / 2.0,
            0.30,
            1.0,
            0.0,
            1.0,
            tch::Device::Cpu,
        );
        assert_eq!(gabor.size(), [60, 60]);
        // The envelope lies in (0, 1] and the carrier in [-1, 1], so |gabor| <= 1.
        assert!(gabor.abs().max().double_value(&[]) <= 1.0 + 1e-6);
        // Visual check: render the filter as grayscale blocks on a truecolor terminal.
        for i in 0..gabor.size()[0] {
            for j in 0..gabor.size()[1] {
                let value = f32::from(gabor.i((i, j)));
                // Map [-1, 1] onto the full [0, 255] range. The previous
                // `value * 255.0 + 127.0` saturated everything outside roughly [-0.5, 0.5].
                let level = ((value + 1.0) * 127.5) as u8;
                print!("\x1b[48;2;{};{};{}m  ", level, level, level);
            }
            println!("\x1b[0m");
        }
    }

    #[test]
    fn test_apply_gabor_filter() {
        let input = Tensor::randn(&[2, 1, 60, 60], (Kind::Float, tch::Device::Cpu));
        let output = apply_gabor_filter(&input, 8, 31, &[0.06, 0.12, 0.24, 0.48], 0.50);
        // 8 angles x 4 wavelengths = 32 output channels; "same" padding keeps H and W.
        assert_eq!(output.size(), [2, 32, 60, 60]);
    }
}