Transformer
Multi-Head Attention
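Both implementations below compute scaled dot-product attention in parallel over several heads, concatenate the per-head results, and project back to the model dimension:

$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}}\right)V$$

where $d_k$ is the per-head dimension (head_size in the first implementation, dim_head in the second).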
Implementation 1

Each head is a standalone module with its own q/k/v projections; the multi-head wrapper runs the heads in a loop and concatenates their outputs.
import torch
import torch.nn as nn
import torch.nn.functional as F


class Head(nn.Module):
    """A single attention head computing scaled dot-product attention."""

    def __init__(self, d_model, head_size):
        super().__init__()
        self.head_size = head_size
        self.query = nn.Linear(d_model, head_size)
        self.key = nn.Linear(d_model, head_size)
        self.value = nn.Linear(d_model, head_size)
        self.dropout = nn.Dropout(0.1)

    def forward(self, q, k, v, mask=None):
        q = self.query(q)
        k = self.key(k)
        # Attention logits, scaled by 1/sqrt(head_size) to keep their variance stable
        wei = q @ k.transpose(-2, -1) * (self.head_size**-0.5)
        if mask is not None:
            # Positions where mask == 0 get -inf and vanish after the softmax
            wei = wei.masked_fill(mask == 0, float("-inf"))
        wei = F.softmax(wei, dim=-1)
        wei = self.dropout(wei)
        v = self.value(v)
        return wei @ v
class MultiHeadAttention(nn.Module):
    """Runs n_heads heads in parallel and projects the concatenated result."""

    def __init__(self, d_model, n_heads):
        super().__init__()
        assert d_model % n_heads == 0
        self.heads = nn.ModuleList(
            [Head(d_model, d_model // n_heads) for _ in range(n_heads)]
        )
        self.proj = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(0.1)

    def forward(self, q, k, v, mask=None):
        # Concatenating the per-head outputs along the last dimension restores d_model
        out = torch.cat([h(q, k, v, mask) for h in self.heads], dim=-1)
        # Mix information across heads with the output projection, then apply dropout
        out = self.proj(out)
        return self.dropout(out)
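A minimal smoke test of the module above; the shapes are illustrative (any d_model divisible by n_heads works):

mha = MultiHeadAttention(d_model=64, n_heads=8)
x = torch.randn(2, 10, 64)     # (batch, seq_len, d_model)
out = mha(x, x, x)             # self-attention: q = k = v = x
print(out.shape)               # torch.Size([2, 10, 64])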
Implementation 2

This variant operates on 2D feature maps of shape (b, c, h, w): a single 1x1 convolution produces q, k, and v for all heads at once, and einsum/rearrange replace the per-head modules.
import torch
from torch import einsum, nn
from einops import rearrange


class Attention(nn.Module):
    """Multi-head self-attention over 2D feature maps."""

    def __init__(self, dim, heads=4, dim_head=32):
        super().__init__()
        self.scale = dim_head**-0.5
        self.heads = heads
        hidden_dim = dim_head * heads
        # One 1x1 convolution emits q, k and v for every head in a single pass
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
        self.to_out = nn.Conv2d(hidden_dim, dim, 1)

    def forward(self, x):
        b, c, h, w = x.shape
        qkv = self.to_qkv(x).chunk(3, dim=1)
        # Split channels into heads and flatten the spatial grid into one axis
        q, k, v = map(
            lambda t: rearrange(t, "b (h c) x y -> b h c (x y)", h=self.heads), qkv
        )
        q = q * self.scale
        sim = einsum("b h d i, b h d j -> b h i j", q, k)
        # Subtracting the row-wise maximum makes the largest exponent 0,
        # keeping the softmax numerically stable
        sim = sim - sim.amax(dim=-1, keepdim=True).detach()
        attn = sim.softmax(dim=-1)
        out = einsum("b h i j, b h d j -> b h i d", attn, v)
        # Restore the spatial layout and fold the heads back into channels
        out = rearrange(out, "b h (x y) d -> b (h d) x y", x=h, y=w)
        return self.to_out(out)
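A quick shape check, assuming an illustrative 16x16 feature map:

attn = Attention(dim=64, heads=4, dim_head=32)
x = torch.randn(2, 64, 16, 16)   # (batch, channels, height, width)
out = attn(x)
print(out.shape)                 # torch.Size([2, 64, 16, 16])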