1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
use image;
use tch::{Tensor, Kind};
/// Conversion between tensors and [`image::DynamicImage`] values.
///
/// Tensors are laid out channels-first, i.e. with shape `[C, H, W]`.
pub trait ImageTensorExt {
    /// Renders `self`, a channels-first `[C, H, W]` tensor, as an image.
    ///
    /// # Panics
    ///
    /// Implementations may panic when `self` cannot be represented as an image
    /// (wrong rank, unsupported channel count or unsupported element kind).
    fn to_image(&self) -> image::DynamicImage;

    /// Builds a channels-first `[C, H, W]` value from `image`.
    fn from_image(image: image::DynamicImage) -> Self;
}
impl ImageTensorExt for Tensor {
    /// Converts a `[C, H, W]` tensor into an RGB or RGBA image.
    ///
    /// Channel handling:
    /// * 4 channels -> RGBA, 3 channels -> RGB;
    /// * 2 channels -> RGB, with an all-zero channel prepended as the first (red) channel;
    /// * 1 channel  -> RGB, with the single channel repeated three times.
    ///
    /// `Uint8` tensors yield 8-bit images (`ImageRgb8` / `ImageRgba8`); every other
    /// non-complex kind is cast to `f32` and yields `ImageRgb32F` / `ImageRgba32F`.
    /// Values are copied as-is: no rescaling or clamping is applied.
    ///
    /// # Panics
    ///
    /// Panics if the tensor is not 3-dimensional, has a channel count outside `1..=4`,
    /// or has a complex element kind.
    fn to_image(&self) -> image::DynamicImage {
        let size = self.size();
        let kind = self.kind();
        let [channels, height, width] = size[..] else {
            panic!("Tensor must be of shape [C, H, W] (got {:?})", size);
        };
        assert!(
            (1..=4).contains(&channels),
            "Tensor must have 4, 3, 2 or 1 channels (got {:?})",
            channels
        );
        assert!(
            !matches!(kind, Kind::ComplexHalf | Kind::ComplexFloat | Kind::ComplexDouble),
            "Tensor must be non complex (got {:?})",
            kind
        );

        // Bring every supported channel count to 3 (RGB) or 4 (RGBA) channels.
        let tensor = match channels {
            3 | 4 => self.shallow_clone(),
            2 => {
                // Same kind/device as the input so `cat` never has to mix dtypes.
                let zeros = Tensor::zeros(&[1, height, width], (kind, self.device()));
                Tensor::cat(&[&zeros, self], 0)
            }
            1 => self.repeat(&[3, 1, 1]),
            _ => unreachable!("channel count validated above"),
        };
        // `image` buffers are row-major with interleaved channels, i.e. [H, W, C].
        let tensor = tensor.permute(&[1, 2, 0]);
        // Both dimensions come from a tensor shape, so they are non-negative.
        let (w, h) = (width as u32, height as u32);
        const LEN_OK: &str = "buffer length equals height * width * channels";
        match (channels, kind) {
            (4, Kind::Uint8) => image::DynamicImage::ImageRgba8(
                image::ImageBuffer::from_raw(w, h, Vec::<u8>::from(tensor)).expect(LEN_OK),
            ),
            (_, Kind::Uint8) => image::DynamicImage::ImageRgb8(
                image::ImageBuffer::from_raw(w, h, Vec::<u8>::from(tensor)).expect(LEN_OK),
            ),
            (4, _) => image::DynamicImage::ImageRgba32F(
                image::ImageBuffer::from_raw(w, h, Vec::<f32>::from(tensor.to_kind(Kind::Float)))
                    .expect(LEN_OK),
            ),
            (_, _) => image::DynamicImage::ImageRgb32F(
                image::ImageBuffer::from_raw(w, h, Vec::<f32>::from(tensor.to_kind(Kind::Float)))
                    .expect(LEN_OK),
            ),
        }
    }

    /// Converts an image into an `f32` tensor of shape `[4, H, W]` (RGBA, channels first).
    ///
    /// Every input format goes through `DynamicImage::to_rgba32f`, so the result always
    /// has four channels; value normalisation is whatever `to_rgba32f` applies.
    fn from_image(image: image::DynamicImage) -> Self {
        let (width, height) = (i64::from(image.width()), i64::from(image.height()));
        let data = image.to_rgba32f().into_vec();
        // The pixel buffer is row-major [H, W, 4]; move channels first to get [4, H, W].
        Tensor::of_slice(&data)
            .reshape(&[height, width, 4])
            .permute(&[2, 0, 1])
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::utils::assert_eq_tensor;
    use tch::Tensor;

    /// Asserts that image -> tensor -> image -> tensor reproduces the first tensor.
    fn assert_roundtrip(path: &str) {
        let image = image::open(path).unwrap_or_else(|e| panic!("cannot open {path}: {e}"));
        let tensor = Tensor::from_image(image);
        let tensor2 = Tensor::from_image(tensor.to_image());
        assert_eq_tensor(&tensor, &tensor2);
    }

    #[test]
    fn test_image_tensor() {
        assert_roundtrip("test-assets/convert/basic.png");
        assert_roundtrip("test-assets/convert/cat.jpg");
    }

    #[test]
    fn test_from_image_layout() {
        // Non-square, so any height/width mix-up changes the shape or the pixel order.
        let (width, height) = (3u32, 2u32);
        let rgba = image::RgbaImage::from_fn(width, height, |x, y| {
            image::Rgba([(10 * x + y) as u8, (20 * y) as u8, 200, 255])
        });
        let dynamic = image::DynamicImage::ImageRgba8(rgba);
        let reference = dynamic.to_rgba32f();
        let tensor = Tensor::from_image(dynamic);
        assert_eq!(tensor.size(), vec![4, i64::from(height), i64::from(width)]);
        for y in 0..height {
            for x in 0..width {
                let pixel = reference.get_pixel(x, y).0;
                for (c, &expected) in pixel.iter().enumerate() {
                    let got = tensor.double_value(&[c as i64, i64::from(y), i64::from(x)]);
                    assert_eq!(got, f64::from(expected), "channel {c} at ({x}, {y})");
                }
            }
        }
    }

    #[test]
    fn test_to_image_layout() {
        let (channels, height, width) = (3i64, 2i64, 3i64);
        let data: Vec<u8> = (0u8..18).collect();
        let tensor = Tensor::of_slice(&data).reshape(&[channels, height, width]);
        let rgb = tensor.to_image().to_rgb8();
        assert_eq!((rgb.width(), rgb.height()), (width as u32, height as u32));
        for y in 0..height {
            for x in 0..width {
                let pixel = rgb.get_pixel(x as u32, y as u32).0;
                for c in 0..channels {
                    assert_eq!(
                        i64::from(pixel[c as usize]),
                        tensor.int64_value(&[c, y, x]),
                        "channel {c} at ({x}, {y})"
                    );
                }
            }
        }
    }
}