class AttentionHead(nn.Module):
    """
    A single attention head.
    This module is used in the MultiHeadAttention module.
    """
    def __init__(self, hidden_size, attention_head_size, dropout, bias=True):
        super().__init__()
        self.hidden_size = hidden_size
        self.attention_head_size = attention_head_size
        # Create the query, key, and value projection layers
        self.query = nn.Linear(hidden_size, attention_head_size, bias=bias)
        self.key = nn.Linear(hidden_size, attention_head_size, bias=bias)
        self.value = nn.Linear(hidden_size, attention_head_size, bias=bias)

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # Project the input into query, key, and value
        # The same input is used to generate the query, key, and value,
        # so it's usually called self-attention.
        # (batch_size, sequence_length, hidden_size) -> (batch_size, sequence_length, attention_head_size)
        query = self.query(x)
        key = self.key(x)
        value = self.value(x)
        # Calculate the attention scores
        # softmax(Q*K.T/sqrt(head_size))*V
        attention_scores = torch.matmul(query, key.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        attention_probs = nn.functional.softmax(attention_scores, dim=-1)
        attention_probs = self.dropout(attention_probs)
        # Calculate the attention output
        attention_output = torch.matmul(attention_probs, value)
        return (attention_output, attention_probs)
class MultiHeadAttention(nn.Module):
    """
    Multi-head attention module.
    This module is used in the TransformerEncoder module.
    """
    def __init__(self, config):
        super().__init__()
        self.hidden_size = config["hidden_size"]
        self.num_attention_heads = config["num_attention_heads"]
        # Each head gets an equal slice of the hidden size
        self.attention_head_size = self.hidden_size // self.num_attention_heads
        self.all_head_size = self.num_attention_heads * self.attention_head_size
        # Whether to use bias in the query, key, and value projections
        self.qkv_bias = config["qkv_bias"]
        # Create the list of attention heads
        self.heads = nn.ModuleList([])
        for _ in range(self.num_attention_heads):
            head = AttentionHead(
                self.hidden_size,
                self.attention_head_size,
                config["attention_probs_dropout_prob"],
                self.qkv_bias
            )
            self.heads.append(head)
        # Project the concatenated head outputs back to the hidden size
        self.output_projection = nn.Linear(self.all_head_size, self.hidden_size)
        self.output_dropout = nn.Dropout(config["hidden_dropout_prob"])

    def forward(self, x, output_attentions=False):
        # Run each attention head on the same input
        attention_outputs = [head(x) for head in self.heads]
        # Concatenate the outputs of all heads along the last dimension
        attention_output = torch.cat([attention_output for attention_output, _ in attention_outputs], dim=-1)
        # Project the concatenated output back to the hidden size
        attention_output = self.output_projection(attention_output)
        attention_output = self.output_dropout(attention_output)
        # Optionally return the per-head attention probabilities as well
        if not output_attentions:
            return (attention_output, None)
        else:
            attention_probs = torch.stack([attention_probs for _, attention_probs in attention_outputs], dim=1)
            return (attention_output, attention_probs)
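For reference, the module can be exercised with a small configuration dictionary. The key names follow the config lookups in __init__ above, but the specific values below are illustrative assumptions, not settings prescribed by the text; this is a sketch, assuming the two classes above are defined with their imports in place.

# Illustrative usage sketch for MultiHeadAttention; the config values are assumptions.
import torch

config = {
    "hidden_size": 64,
    "num_attention_heads": 4,
    "qkv_bias": True,
    "attention_probs_dropout_prob": 0.0,
    "hidden_dropout_prob": 0.0,
}
mha = MultiHeadAttention(config)
x = torch.randn(2, 10, 64)  # (batch_size, sequence_length, hidden_size)
output, attention_probs = mha(x, output_attentions=True)
print(output.shape)           # torch.Size([2, 10, 64]) -> same shape as the input
print(attention_probs.shape)  # torch.Size([2, 4, 10, 10]) -> (batch, num_heads, seq_len, seq_len)

Because each of the 4 heads produces a (batch, seq_len, 16) output, concatenating them restores the hidden size of 64 before the output projection, so the module maps its input to an output of the same shape.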