1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
use crate::Tensor;
use std::borrow::Borrow;
/// Configuration for a [`GroupNorm`] layer.
///
/// The values produced by `GroupNormConfig::default()` are: `cudnn_enabled = true`,
/// `eps = 1e-5`, `affine = true`, weights initialized to `Const(1.)` and biases to
/// `Const(0.)`.
#[derive(Debug, Clone, Copy)]
pub struct GroupNormConfig {
/// Forwarded unchanged as the last argument of `Tensor::group_norm` on every forward pass.
pub cudnn_enabled: bool,
/// Value added inside the normalization for numerical stability; forwarded to
/// `Tensor::group_norm` as its `eps` argument.
pub eps: f64,
/// When `true`, learnable per-channel `weight` and `bias` variables are created;
/// when `false`, normalization is applied without them.
pub affine: bool,
/// Initializer for the `weight` variable (only used when `affine` is `true`).
pub ws_init: super::Init,
/// Initializer for the `bias` variable (only used when `affine` is `true`).
pub bs_init: super::Init,
}
impl Default for GroupNormConfig {
    /// Builds the standard configuration.
    ///
    /// - `eps` is `1e-5`.
    /// - The layer is affine, with weights starting at one and biases at zero.
    /// - `cudnn_enabled` is set to `true`.
    fn default() -> Self {
        Self {
            eps: 1e-5,
            affine: true,
            ws_init: super::Init::Const(1.0),
            bs_init: super::Init::Const(0.0),
            cudnn_enabled: true,
        }
    }
}
/// A group-normalization layer, built with [`group_norm`].
///
/// Applying it (via `Module::forward`) calls `Tensor::group_norm` with the stored
/// group count, optional affine parameters and configuration.
#[derive(Debug)]
pub struct GroupNorm {
/// Configuration captured at construction time (`eps`, `cudnn_enabled`, ...).
config: GroupNormConfig,
/// Per-channel scale, a variable named `weight` of shape `[num_channels]`;
/// `Some` exactly when `config.affine` was `true` at construction.
pub ws: Option<Tensor>,
/// Per-channel shift, a variable named `bias` of shape `[num_channels]`;
/// `Some` exactly when `config.affine` was `true` at construction.
pub bs: Option<Tensor>,
/// Number of groups the channels are split into.
pub num_groups: i64,
/// Number of channels; also the length of `ws` / `bs` when present.
pub num_channels: i64,
}
/// Creates a new group-normalization layer.
///
/// When `config.affine` is `true`, two variables of shape `[num_channels]` are
/// registered under `vs`:
/// - `weight`, initialized with `config.ws_init`;
/// - `bias`, initialized with `config.bs_init`.
///
/// Otherwise the layer has no learnable parameters.
///
/// # Arguments
/// * `vs` - variable-store path under which the parameters are created.
/// * `num_groups` - number of groups to separate the channels into.
/// * `num_channels` - number of channels expected in the input.
/// * `config` - normalization options, see [`GroupNormConfig`].
///
/// # Panics
/// Panics if `num_groups` is not strictly positive, or if `num_channels` is not
/// divisible by `num_groups`.
pub fn group_norm<'a, T: Borrow<super::Path<'a>>>(
    vs: T,
    num_groups: i64,
    num_channels: i64,
    config: GroupNormConfig,
) -> GroupNorm {
    // Validate up front (as PyTorch's `nn.GroupNorm` does) rather than deferring the
    // failure to the native kernel on the first forward pass, where a zero group
    // count would be used as a divisor.
    assert!(num_groups > 0, "group_norm: num_groups must be positive, got {}", num_groups);
    assert!(
        num_channels % num_groups == 0,
        "group_norm: num_channels ({}) must be divisible by num_groups ({})",
        num_channels,
        num_groups
    );
    let vs = vs.borrow();
    let (ws, bs) = if config.affine {
        let ws = vs.var("weight", &[num_channels], config.ws_init);
        let bs = vs.var("bias", &[num_channels], config.bs_init);
        (Some(ws), Some(bs))
    } else {
        (None, None)
    };
    GroupNorm { config, ws, bs, num_groups, num_channels }
}
impl super::module::Module for GroupNorm {
    /// Applies group normalization to `xs`.
    ///
    /// The channels are split into `self.num_groups` groups. The optional per-channel
    /// weight and bias are applied when the layer is affine. The configured `eps` and
    /// `cudnn_enabled` values are passed through.
    ///
    /// NOTE(review): presumably `xs` is laid out as (N, C, ...) with
    /// C == `self.num_channels`, per libtorch's group_norm. Confirm against callers.
    fn forward(&self, xs: &Tensor) -> Tensor {
        let weight = self.ws.as_ref();
        let bias = self.bs.as_ref();
        let eps = self.config.eps;
        let cudnn_enabled = self.config.cudnn_enabled;
        xs.group_norm(self.num_groups, weight, bias, eps, cudnn_enabled)
    }
}