My App
Transformer

多头注意力机制

多头注意力机制

实现一

class Head(nn.Module):
    """Single scaled dot-product attention head.

    The query, key and value inputs are each projected from ``d_model`` to
    ``head_size`` features by their own linear layer. Attention weights are
    the softmax of the scaled query/key dot products, and the output is the
    weighted sum of the projected values.

    Args:
        d_model: Size of the last dimension of the incoming ``q``/``k``/``v``.
        head_size: Size of the last dimension after projection (and of the
            output).
        dropout: Dropout probability applied to the attention weights.
            Defaults to 0.1, the value that used to be hard-coded.
    """

    def __init__(self, d_model, head_size, dropout=0.1):
        super().__init__()
        self.head_size = head_size
        self.query = nn.Linear(d_model, head_size)
        self.key = nn.Linear(d_model, head_size)
        self.value = nn.Linear(d_model, head_size)

        self.dropout = nn.Dropout(dropout)

    def forward(self, q, k, v, mask=None):
        """Attend from ``q`` over ``k``/``v``.

        Args:
            q: Query input whose last dimension is ``d_model``.
            k: Key input whose last dimension is ``d_model``.
            v: Value input whose last dimension is ``d_model``.
            mask: Optional tensor broadcastable to the attention scores;
                positions where ``mask == 0`` are excluded from attention.

        Returns:
            Tensor whose last dimension is ``head_size``.

        Note:
            A query row whose positions are all masked produces NaN, because
            softmax over a row of ``-inf`` is undefined.
        """
        q = self.query(q)
        k = self.key(k)
        # Scale by 1/sqrt(head_size) so the logits do not grow with head size.
        wei = q @ k.transpose(-2, -1) * (self.head_size**-0.5)
        if mask is not None:
            # -inf logits become exactly zero weight after softmax.
            wei = wei.masked_fill(mask == 0, float("-inf"))
        wei = F.softmax(wei, dim=-1)
        wei = self.dropout(wei)

        v = self.value(v)
        return wei @ v
 
 
class MultiHeadAttention(nn.Module):
    """Multi-head attention built from independent ``Head`` modules.

    Every head receives the same ``q``/``k``/``v`` inputs and produces
    ``d_model // n_heads`` features. The head outputs are concatenated along
    the last dimension, mixed by a linear output projection, and finally
    passed through dropout.

    Args:
        d_model: Feature size of the inputs and of the output.
        n_heads: Number of heads; must divide ``d_model`` evenly.
    """

    def __init__(self, d_model, n_heads):
        super().__init__()
        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"

        self.heads = nn.ModuleList(
            [Head(d_model, d_model // n_heads) for _ in range(n_heads)]
        )
        # Output projection applied to the concatenated heads. The attribute
        # name `pro` is kept so existing state dicts keep loading.
        self.pro = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(0.1)

    def forward(self, q, k, v, mask=None):
        """Run every head, concatenate, project, and apply dropout.

        Args:
            q: Query input forwarded unchanged to each head.
            k: Key input forwarded unchanged to each head.
            v: Value input forwarded unchanged to each head.
            mask: Optional attention mask forwarded to each head; ``None``
                disables masking.

        Returns:
            Tensor whose last dimension is ``d_model``.
        """
        out = torch.cat([h(q, k, v, mask) for h in self.heads], dim=-1)
        # Previously dropout was applied twice and `self.pro` was never used,
        # so the heads were never mixed; project first, then drop out once.
        out = self.pro(out)
        return self.dropout(out)

实现二

class Attention(nn.Module):
    """Multi-head self-attention over the spatial positions of a feature map.

    A 1x1 convolution produces queries, keys and values for all heads at
    once; every spatial location attends to every other location, and a
    second 1x1 convolution maps the merged heads back to ``dim`` channels.

    Args:
        dim: Number of input (and output) channels.
        heads: Number of attention heads.
        dim_head: Channels per head.
    """

    def __init__(self, dim, heads=4, dim_head=32):
        super().__init__()
        inner_dim = heads * dim_head
        self.heads = heads
        self.scale = dim_head**-0.5
        # Layer creation order is kept (to_qkv, then to_out) so parameter
        # initialisation and state-dict ordering are unaffected.
        self.to_qkv = nn.Conv2d(dim, inner_dim * 3, 1, bias=False)
        self.to_out = nn.Conv2d(inner_dim, dim, 1)

    def forward(self, x):
        """Apply self-attention to a 4-D input laid out as (batch, channels, height, width).

        Returns a tensor with ``dim`` channels and the same spatial size as ``x``.
        """
        _, _, height, width = x.shape

        def split_heads(t):
            # `rearrange` is presumably einops.rearrange: separates the head
            # axis and flattens the two spatial axes into one.
            return rearrange(t, "b (h c) x y -> b h c (x y)", h=self.heads)

        query, key, value = (
            split_heads(part) for part in self.to_qkv(x).chunk(3, dim=1)
        )
        query = query * self.scale

        # `einsum` is called pattern-first, i.e. presumably torch.einsum.
        scores = einsum("b h d i, b h d j -> b h i j", query, key)
        # Shift each row so its largest logit is 0 before exponentiating;
        # the shift is detached so it does not take part in the gradient.
        row_max = scores.amax(dim=-1, keepdim=True).detach()
        weights = (scores - row_max).softmax(dim=-1)

        mixed = einsum("b h i j, b h d j -> b h i d", weights, value)
        merged = rearrange(mixed, "b h (x y) d -> b (h d) x y", x=height, y=width)
        return self.to_out(merged)

On this page