Bases: MessagePassing
The relational graph attentional operator from the “Relational Graph Attention Networks” paper.
Here, attention logits \(\mathbf{a}^{(r)}_{i,j}\) are computed for each relation type \(r\) with the help of both query and key kernels, i.e.
\[\mathbf{q}^{(r)}_i = \mathbf{W}_1^{(r)}\mathbf{x}_{i} \cdot \mathbf{Q}^{(r)} \quad \textrm{and} \quad \mathbf{k}^{(r)}_i = \mathbf{W}_1^{(r)}\mathbf{x}_{i} \cdot \mathbf{K}^{(r)}.\]
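A minimal, dependency-free sketch of these projections for a single node \(i\), a single relation \(r\) and a single head. It assumes \(\mathbf{W}_1^{(r)}\) is stored as an out_channels × in_channels matrix and that \(\mathbf{Q}^{(r)}\) and \(\mathbf{K}^{(r)}\) are out_channels × dim matrices. The helper names are hypothetical and plain lists stand in for tensors, so this illustrates the formula rather than the layer's batched implementation.

```python
def matvec(W, x):
    """Return W @ x for a matrix W (list of rows) and a column vector x."""
    return [sum(w * v for w, v in zip(row, x)) for row in W]


def vecmat(h, M):
    """Return h @ M for a row vector h and a matrix M (list of rows)."""
    return [sum(h[k] * M[k][c] for k in range(len(h))) for c in range(len(M[0]))]


def query_key(x_i, W1_r, Q_r, K_r):
    """Compute q_i^(r) and k_i^(r) for one node, one relation and one head."""
    h = matvec(W1_r, x_i)  # W_1^(r) x_i, length out_channels
    return vecmat(h, Q_r), vecmat(h, K_r)  # both of length dim


# in_channels = 2, out_channels = 3, dim = 1 (hypothetical toy values)
x_i = [1.0, 2.0]
W1_r = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
Q_r = [[1.0], [0.0], [0.0]]
K_r = [[0.0], [0.0], [1.0]]
q_i, k_i = query_key(x_i, W1_r, Q_r, K_r)
print(q_i, k_i)  # [1.0] [3.0]
```

Both kernels act on the same projected vector \(\mathbf{W}_1^{(r)}\mathbf{x}_i\), so a node's query and key differ only through \(\mathbf{Q}^{(r)}\) and \(\mathbf{K}^{(r)}\).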
Two schemes have been proposed to compute attention logits \(\mathbf{a}^{(r)}_{i,j}\) for each relation type \(r\):
Additive attention
\[\mathbf{a}^{(r)}_{i,j} = \mathrm{LeakyReLU}(\mathbf{q}^{(r)}_i + \mathbf{k}^{(r)}_j)\]
or multiplicative attention
\[\mathbf{a}^{(r)}_{i,j} = \mathbf{q}^{(r)}_i \cdot \mathbf{k}^{(r)}_j.\]
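The logit schemes can be sketched for one edge, one relation and one head with plain Python lists, including the optional edge term from the edge-feature formulas that follow. The sketch assumes that the products are taken element-wise over the dim components of \(\mathbf{q}^{(r)}_i\) and \(\mathbf{k}^{(r)}_j\), so that every edge gets dim logits (consistent with the dim factor in the multiplicative output sizes listed further down), and that \(\mathbf{W}_2^{(r)}\) is a dim × edge_dim matrix. The function names are hypothetical.

```python
def leaky_relu(v, negative_slope=0.2):
    """Scalar LeakyReLU; 0.2 mirrors the documented negative_slope default."""
    return v if v >= 0 else negative_slope * v


def edge_term(e_ij, W2_r):
    """W_2^(r) e_ij, mapping an edge_dim-sized edge feature to dim components."""
    return [sum(w * v for w, v in zip(row, e_ij)) for row in W2_r]


def attention_logits(q_i, k_j, mode="additive", e_ij=None, W2_r=None,
                     negative_slope=0.2):
    """Logits a_ij^(r) for one edge, one relation and one head (dim components).

    With e_ij and W2_r given, the edge term W_2^(r) e_ij is added (additive)
    or multiplied in (multiplicative), as in the edge-feature formulas.
    """
    ee = edge_term(e_ij, W2_r) if e_ij is not None else None
    if mode == "additive":
        s = [a + b for a, b in zip(q_i, k_j)]
        if ee is not None:
            s = [v + c for v, c in zip(s, ee)]
        return [leaky_relu(v, negative_slope) for v in s]
    if mode == "multiplicative":
        p = [a * b for a, b in zip(q_i, k_j)]  # element-wise product (assumption)
        if ee is not None:
            p = [v * c for v, c in zip(p, ee)]
        return p
    raise ValueError(f"unknown mode: {mode!r}")


q_i, k_j = [1.0, -2.0], [0.5, 1.0]
print(attention_logits(q_i, k_j, "additive"))        # [1.5, -0.2]
print(attention_logits(q_i, k_j, "multiplicative"))  # [0.5, -2.0]
# With a 2-dimensional edge feature and W2_r of shape dim x edge_dim = 2 x 2:
e_ij, W2_r = [1.0, 1.0], [[0.5, -4.0], [1.0, 0.0]]
print(attention_logits(q_i, k_j, "multiplicative", e_ij, W2_r))  # [-1.75, -2.0]
```

Only the additive scheme passes its logits through a LeakyReLU; the multiplicative logits are used as raw products and may therefore be negative without any slope being applied.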
If the graph has multi-dimensional edge features \(\mathbf{e}^{(r)}_{i,j}\), the attention logits \(\mathbf{a}^{(r)}_{i,j}\) for each relation type \(r\) are computed as
\[\mathbf{a}^{(r)}_{i,j} = \mathrm{LeakyReLU}(\mathbf{q}^{(r)}_i + \mathbf{k}^{(r)}_j + \mathbf{W}_2^{(r)}\mathbf{e}^{(r)}_{i,j})\]
or
\[\mathbf{a}^{(r)}_{i,j} = \mathbf{q}^{(r)}_i \cdot \mathbf{k}^{(r)}_j \cdot \mathbf{W}_2^{(r)} \mathbf{e}^{(r)}_{i,j},\]
respectively. The attention coefficients \(\alpha^{(r)}_{i,j}\) for each relation type \(r\) are then obtained via two different attention mechanisms: The within-relation attention mechanism
\[\alpha^{(r)}_{i,j} = \frac{\exp(\mathbf{a}^{(r)}_{i,j})} {\sum_{k \in \mathcal{N}_r(i)} \exp(\mathbf{a}^{(r)}_{i,k})}\]
or the across-relation attention mechanism
\[\alpha^{(r)}_{i,j} = \frac{\exp(\mathbf{a}^{(r)}_{i,j})} {\sum_{r^{\prime} \in \mathcal{R}} \sum_{k \in \mathcal{N}_{r^{\prime}}(i)} \exp(\mathbf{a}^{(r^{\prime})}_{i,k})}\]
where \(\mathcal{R}\) denotes the set of relations, i.e. edge types. The edge type needs to be a one-dimensional torch.long tensor that stores a relation identifier \(\in \{ 0, \ldots, |\mathcal{R}| - 1\}\) for each edge.
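A minimal sketch of the two normalizations for the incoming edges of a single node \(i\), using plain Python lists: logits holds one scalar logit per incoming edge (a single head and dim = 1 are assumed) and relations holds the matching edge type entries. The function names are hypothetical.

```python
import math
from collections import defaultdict


def _softmax(values):
    """Numerically stable softmax over a list of scalars."""
    m = max(values)
    exps = [math.exp(v - m) for v in values]
    total = sum(exps)
    return [e / total for e in exps]


def within_relation_softmax(logits, relations):
    """Normalize the logits of node i's incoming edges separately for each relation."""
    groups = defaultdict(list)  # relation id -> positions of its edges
    for pos, r in enumerate(relations):
        groups[r].append(pos)
    alpha = [0.0] * len(logits)
    for positions in groups.values():
        for pos, a in zip(positions, _softmax([logits[p] for p in positions])):
            alpha[pos] = a
    return alpha


def across_relation_softmax(logits):
    """Normalize the logits of node i's incoming edges jointly over all relations."""
    return _softmax(logits)


# Three edges into node i: two of relation 0, one of relation 1, all with logit 0.
logits, relations = [0.0, 0.0, 0.0], [0, 0, 1]
print(within_relation_softmax(logits, relations))  # [0.5, 0.5, 1.0]
print(across_relation_softmax(logits))             # each entry is 1/3
```

Under the within-relation mechanism the coefficients sum to one per relation (here 0.5 + 0.5 and 1.0), so a lone edge of some relation always receives weight 1; under the across-relation mechanism they sum to one over all incoming edges, letting relations compete with each other.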
To enhance the discriminative power of attention-based GNNs, this layer further implements four different cardinality preservation options as proposed in the “Improving Attention Mechanism in Graph Neural Networks via Cardinality Preservation” paper:
\[ \begin{align}\begin{aligned}\text{additive:}~~~\mathbf{x}^{{\prime}(r)}_i &= \sum_{j \in \mathcal{N}_r(i)} \alpha^{(r)}_{i,j} \mathbf{x}^{(r)}_j + \mathcal{W} \odot \sum_{j \in \mathcal{N}_r(i)} \mathbf{x}^{(r)}_j\\\text{scaled:}~~~\mathbf{x}^{{\prime}(r)}_i &= \psi(|\mathcal{N}_r(i)|) \odot \sum_{j \in \mathcal{N}_r(i)} \alpha^{(r)}_{i,j} \mathbf{x}^{(r)}_j\\\text{f-additive:}~~~\mathbf{x}^{{\prime}(r)}_i &= \sum_{j \in \mathcal{N}_r(i)} (\alpha^{(r)}_{i,j} + 1) \cdot \mathbf{x}^{(r)}_j\\\text{f-scaled:}~~~\mathbf{x}^{{\prime}(r)}_i &= |\mathcal{N}_r(i)| \odot \sum_{j \in \mathcal{N}_r(i)} \alpha^{(r)}_{i,j} \mathbf{x}^{(r)}_j\end{aligned}\end{align} \]
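The four options can be written out for one node \(i\) and one relation \(r\). In this sketch the function aggregate is hypothetical, and \(\mathcal{W}\) is passed as a fixed vector and \(\psi\) as a plain callable returning a vector; these are stand-ins chosen for illustration, not the parameterization the layer uses.

```python
def aggregate(alpha, x_neighbors, mod=None, w=None, psi=None):
    """Aggregate neighbor features x_j^(r) of one node i for one relation r.

    alpha: attention coefficients alpha_ij^(r), one per neighbor
    x_neighbors: list of neighbor feature vectors (equal lengths)
    w: element-wise weight vector for "additive" (the W in the formula)
    psi: callable n -> scaling vector for "scaled" (stands in for psi)
    """
    n = len(x_neighbors)
    channels = range(len(x_neighbors[0]))
    weighted = [sum(a * x[c] for a, x in zip(alpha, x_neighbors)) for c in channels]
    if mod is None:  # plain attention-weighted sum
        return weighted
    if mod == "additive":  # add a W-weighted unnormalized neighbor sum
        plain = [sum(x[c] for x in x_neighbors) for c in channels]
        return [s + wc * p for s, wc, p in zip(weighted, w, plain)]
    if mod == "scaled":  # rescale by a function of the neighborhood size
        return [s * v for s, v in zip(psi(n), weighted)]
    if mod == "f-additive":  # shift every coefficient by 1
        return [sum((a + 1) * x[c] for a, x in zip(alpha, x_neighbors)) for c in channels]
    if mod == "f-scaled":  # multiply by the neighborhood size |N_r(i)|
        return [n * s for s in weighted]
    raise ValueError(f"unknown mod: {mod!r}")


# Node A has one neighbor with feature 1.0; node B has two such neighbors.
# With uniform attention both plain sums equal 1.0, but f-scaled tells them apart.
print(aggregate([1.0], [[1.0]]), aggregate([0.5, 0.5], [[1.0], [1.0]]))  # [1.0] [1.0]
print(aggregate([1.0], [[1.0]], mod="f-scaled"),
      aggregate([0.5, 0.5], [[1.0], [1.0]], mod="f-scaled"))  # [1.0] [2.0]
```

The example shows the motivation: because the softmax-normalized coefficients always sum to one, plain attention cannot distinguish neighborhoods that differ only in size, while every cardinality preservation option reintroduces the neighborhood size in some form.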
If attention_mode="additive-self-attention" and concat=True, the layer outputs heads * out_channels features for each node.
If attention_mode="multiplicative-self-attention" and concat=True, the layer outputs heads * dim * out_channels features for each node.
If attention_mode="additive-self-attention" and concat=False, the layer outputs out_channels features for each node.
If attention_mode="multiplicative-self-attention" and concat=False, the layer outputs dim * out_channels features for each node.
Please make sure to set the in_channels argument of the next layer accordingly if more than one instance of this layer is used.
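The four output-size rules, and the bookkeeping needed when stacking layers, fit in a small helper. The name rgat_output_channels is hypothetical and not part of the library; it only restates the rules listed above.

```python
def rgat_output_channels(out_channels, heads=1, dim=1, concat=True,
                         attention_mode="additive-self-attention"):
    """Per-node output width of the layer, following the four rules above."""
    if attention_mode == "additive-self-attention":
        per_head = out_channels
    elif attention_mode == "multiplicative-self-attention":
        per_head = dim * out_channels
    else:
        raise ValueError(f"unknown attention_mode: {attention_mode!r}")
    # concat=True stacks the heads side by side; concat=False averages them.
    return heads * per_head if concat else per_head


# Stacking two layers: the second layer's in_channels is the first layer's output width.
first_out = rgat_output_channels(16, heads=4, dim=2, concat=True,
                                 attention_mode="multiplicative-self-attention")
print(first_out)  # 128, so the next layer needs in_channels=128
```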
in_channels (int) – Size of each input sample.
out_channels (int) – Size of each output sample.
num_relations (int) – Number of relations.
num_bases (int, optional) – If set, this layer will use the basis-decomposition regularization scheme where num_bases denotes the number of bases to use. (default: None)
num_blocks (int, optional) – If set, this layer will use the block-diagonal-decomposition regularization scheme where num_blocks denotes the number of blocks to use. (default: None)
mod (str, optional) – The cardinality preservation option to use ("additive", "scaled", "f-additive", "f-scaled", or None). (default: None)
attention_mechanism (str, optional) – The attention mechanism to use ("within-relation" or "across-relation"). (default: "across-relation")
attention_mode (str, optional) – The mode used to calculate attention logits ("additive-self-attention" or "multiplicative-self-attention"). (default: "additive-self-attention")
heads (int, optional) – Number of attention heads. (default: 1)
dim (int, optional) – Number of dimensions for the query and key kernels. (default: 1)
concat (bool, optional) – If set to False, the multi-head attentions are averaged instead of concatenated. (default: True)
negative_slope (float, optional) – LeakyReLU angle of the negative slope. (default: 0.2)
dropout (float, optional) – Dropout probability of the normalized attention coefficients, which exposes each node to a stochastically sampled neighborhood during training. (default: 0)
edge_dim (int, optional) – Edge feature dimensionality (in case there are any). (default: None)
bias (bool, optional) – If set to False, the layer will not learn an additive bias. (default: True)
**kwargs (optional) – Additional arguments of torch_geometric.nn.conv.MessagePassing.
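Both regularization schemes constrain the per-relation weight matrices \(\mathbf{W}^{(r)}\). A minimal sketch of their standard definitions as used in relational GCNs: basis decomposition builds each \(\mathbf{W}^{(r)}\) as a relation-specific linear combination of num_bases shared matrices, while block-diagonal decomposition assembles it from num_blocks smaller blocks placed on the diagonal. The helpers are hypothetical and say nothing about how this layer stores, shapes, or initializes these parameters.

```python
def basis_decomposition(coeffs_r, bases):
    """W^(r) = sum_b a_{r,b} V_b, where the bases V_b are shared by all relations."""
    rows, cols = len(bases[0]), len(bases[0][0])
    return [[sum(a * V[i][j] for a, V in zip(coeffs_r, bases)) for j in range(cols)]
            for i in range(rows)]


def block_diagonal(blocks):
    """W^(r) = diag(B_1^(r), ..., B_B^(r)): blocks on the diagonal, zeros elsewhere."""
    n_rows = sum(len(b) for b in blocks)
    n_cols = sum(len(b[0]) for b in blocks)
    W = [[0.0] * n_cols for _ in range(n_rows)]
    r0 = c0 = 0  # top-left corner of the next block
    for b in blocks:
        for i, row in enumerate(b):
            for j, v in enumerate(row):
                W[r0 + i][c0 + j] = v
        r0 += len(b)
        c0 += len(b[0])
    return W


# num_bases = 2: each relation only learns 2 coefficients on top of the shared bases.
V = [[[1.0, 0.0], [0.0, 1.0]], [[0.0, 1.0], [1.0, 0.0]]]
print(basis_decomposition([2.0, 3.0], V))  # [[2.0, 3.0], [3.0, 2.0]]
# num_blocks = 2 for a 2 x 4 weight: two 1 x 2 blocks.
print(block_diagonal([[[1.0, 2.0]], [[3.0, 4.0]]]))
# [[1.0, 2.0, 0.0, 0.0], [0.0, 0.0, 3.0, 4.0]]
```

In both cases the number of free parameters per relation shrinks, which is the point of using them when num_relations is large.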
Runs the forward pass of the module.
x (torch.Tensor) – The input node features. Can be either a [num_nodes, in_channels] node feature matrix, or an optional one-dimensional node index tensor (in which case input features are treated as trainable node embeddings).
edge_index (torch.Tensor or SparseTensor) – The edge indices.
edge_type (torch.Tensor, optional) – The one-dimensional relation type/index for each edge in edge_index. Should only be None in case edge_index is of type torch_sparse.SparseTensor or torch.sparse.Tensor. (default: None)
edge_attr (torch.Tensor, optional) – The edge features. (default: None)
size ((int, int), optional) – The shape of the adjacency matrix. (default: None)
return_attention_weights (bool, optional) – If set to True, will additionally return the tuple (edge_index, attention_weights), holding the computed attention weights for each edge. (default: None)
Resets all learnable parameters of the module.