Function tch_utils::tensor_init::position_tensor_2d
source · pub fn position_tensor_2d(
size: (usize, usize),
n: usize,
options: (Kind, Device)
) -> Tensor
Generate a tensor that holds, at each spatial position, a vec2 coordinate ranging from (-1, -1) to (1, 1) according to that position.
Arguments
- size: (usize, usize) - The spatial size of the tensor
- n: usize - The number of samples
- options: (tch::Kind, tch::Device) - The kind and device of the tensor
Returns
Tensor - The tensor containing the positions, with shape [N, 2, H, W]. The 2nd dimension contains the x and y positions, with y being the first element.
Example
let pos = position_tensor_2d((3, 3), 1, (tch::Kind::Float, tch::Device::Cpu));
let expected = Tensor::of_slice(&[
-1.0, -1.0, -1.0,
0.0, 0.0, 0.0,
1.0, 1.0, 1.0,
-1.0, 0.0, 1.0,
-1.0, 0.0, 1.0,
-1.0, 0.0, 1.0,
]).view([1, 2, 3, 3]);
eprintln!("pos: {:?}", Vec::<f64>::from(&pos));
assert!(f64::from((pos - expected).abs().sum(tch::Kind::Float)) < 1e-6);