1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
use crate::Tensor;
use std::borrow::Borrow;
/// Configuration options for an [`Embedding`] layer.
///
/// The weight initialiser is used once, when the layer is built; the
/// remaining fields are passed unchanged to `Tensor::embedding` on every
/// forward call.
#[derive(Clone, Copy, Debug)]
pub struct EmbeddingConfig {
    /// Forwarded as the `sparse` argument of `Tensor::embedding`.
    // NOTE(review): presumably selects sparse gradients for the weight
    // table, as in PyTorch's `nn.Embedding` — confirm.
    pub sparse: bool,
    /// Forwarded as the `scale_grad_by_freq` argument of `Tensor::embedding`.
    pub scale_grad_by_freq: bool,
    /// Initialisation scheme applied to the weight table when it is created.
    pub ws_init: super::Init,
    /// Forwarded as the `padding_idx` argument of `Tensor::embedding`.
    // NOTE(review): a negative value presumably means "no padding row" —
    // verify against the libtorch `embedding` documentation.
    pub padding_idx: i64,
}
impl Default for EmbeddingConfig {
    /// Returns the default configuration: both boolean flags are `false`,
    /// weights are drawn from `Init::Randn` with mean `0` and standard
    /// deviation `1`, and `padding_idx` is `-1`.
    fn default() -> Self {
        Self {
            sparse: false,
            scale_grad_by_freq: false,
            ws_init: super::Init::Randn { mean: 0.0, stdev: 1.0 },
            padding_idx: -1,
        }
    }
}
/// An embedding layer: a weight table plus the options used when looking
/// values up in it.
#[derive(Debug)]
pub struct Embedding {
    /// The weight table, created with shape
    /// `[num_embeddings, embedding_dim]` by [`embedding`].
    pub ws: Tensor,
    /// Options forwarded to `Tensor::embedding` in the forward pass.
    config: EmbeddingConfig,
}
/// Builds an [`Embedding`] layer.
///
/// The weight tensor is created by calling `var` on the borrowed path with
/// the name `"weight"`, the shape `[num_embeddings, embedding_dim]` and the
/// initialiser `config.ws_init`. `config` is stored in the returned layer.
///
/// * `vs` - anything that borrows as a `Path`, used to create the variable.
/// * `num_embeddings` - first dimension of the weight table.
/// * `embedding_dim` - second dimension of the weight table.
/// * `config` - layer options, see [`EmbeddingConfig`].
pub fn embedding<'a, T: Borrow<super::Path<'a>>>(
    vs: T,
    num_embeddings: i64,
    embedding_dim: i64,
    config: EmbeddingConfig,
) -> Embedding {
    let path = vs.borrow();
    let shape = [num_embeddings, embedding_dim];
    let ws = path.var("weight", &shape, config.ws_init);
    Embedding { ws, config }
}
impl super::module::Module for Embedding {
    /// Applies the layer to `xs` by delegating to `Tensor::embedding`.
    ///
    /// The stored weight table and `xs` are passed as the first two
    /// arguments, followed by `padding_idx`, `scale_grad_by_freq` and
    /// `sparse` from the stored configuration.
    // NOTE(review): `xs` is presumably an integer index tensor — confirm
    // against the `Tensor::embedding` documentation.
    fn forward(&self, xs: &Tensor) -> Tensor {
        let cfg = &self.config;
        Tensor::embedding(&self.ws, xs, cfg.padding_idx, cfg.scale_grad_by_freq, cfg.sparse)
    }
}