月影

日々の雑感

【PyTorchで実装】TransformerのSelf-AttentionとMulti-Head Attentionを解説

ブログ連載:AIの「文脈を読む力」の秘密 - Attention機構 完全解説

 

 

【第2部】TransformerとAttention:ChatGPTを誕生させた革命的技術

 

 

第4回:【専門編】Transformerの核心:Self-AttentionとMulti-Head Attentionの実装

 

前回の【入門編】では、Transformerの心臓部である「Self-Attention」が、文章を「関係性のネットワーク」として捉える革新的なアイデアであり、QKVモデルやMulti-Head Attentionといった概念で動いていることを解説しました。

最終回となる今回は【専門編】として、これらの概念がどのような数式で定義され、実際のコードでどう実装されているのかを深く掘り下げます。現代のあらゆるLLMの基盤となっている、この強力なメカニズムの核心に迫りましょう。


 

Self-Attention:Query, Key, Value (QKV)パラダイム

 

Self-Attentionは、入力された単語系列内の各要素が、系列内の他の全要素に対してどのように関連しているかを動的に計算する機構です。これはQuery (Q), Key (K), Value (V) という3つのベクトルでモデル化されます。

入力された単語の埋め込みベクトル x は、それぞれ異なる学習可能な重み行列 Wq, Wk, Wv によって線形変換され、Q, K, Vベクトルが生成されます。

アテンションの計算式は**「Scaled Dot-Product Attention」**として、以下のように表されます。

この式の各部分を分解してみましょう。

  1. : ここが「関連度スコア」の計算部分です。ある単語のQuery (Q)ベクトルと、他のすべての単語のKey (K)ベクトルの内積を計算します。これにより、単語ペア間の類似度が高いほど、大きな値を持つスコア行列が生成されます。

  2. : スコアをKeyベクトルの次元数 d_k平方根で割る「スケーリング」処理です。これは、次元数 d_k が大きくなった際に内積の値が過度に大きくなり、後のsoftmax関数の勾配が消失してしまうのを防ぐための、学習を安定させる重要なテクニックです。

  3. softmax: スケーリングされたスコアを、合計が1になる確率分布(attention_weights)に正規化します。これが入門編で解説した「注目度の割合」に相当します。

  4. : 最後に、算出された注目度の重みと、各単語のValue (V)ベクトルとの加重平均をとります。これにより、文脈に応じて動的に重み付けされた、新しい単語表現ベクトルが完成します。

 

Scaled Dot-Product Attentionの実装

 

上記の計算式をPyTorchの関数として実装すると、以下のようになります。

Python
 
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

def scaled_dot_product_attention(query, key, value, mask=None):
    """
    Scaled Dot-Product Attentionを計算する関数
    """
    d_k = query.size(-1)
    # 1. QとK^Tの内積でスコアを計算し、スケーリング
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
    
    # 2. (オプション) マスクを適用(未来の単語を見ないようにするためなど)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)
        
    # 3. softmaxでAttentionの重みを計算
    attention_weights = F.softmax(scores, dim=-1)
    
    # 4. 重みとVの加重平均で最終的な出力を計算
    output = torch.matmul(attention_weights, value)
    
    return output, attention_weights

 

Multi-Head Attention:多様な表現空間の学習

 

入門編で「専門家チーム」に例えたように、Multi-Head Attentionは、単一の視点だけでは捉えきれない、多様な言語的特徴(例:構文関係、意味関係、共参照関係など)を同時に学習するための仕組みです。

これは、Scaled Dot-Product Attentionを複数(論文ではh=8個)並列で実行することで実現されます。

処理フロー:

  1. 入力されたQ, K, Vを、それぞれh個の「ヘッド」に分割して、異なる線形層で変換します。

  2. 各ヘッドで、Scaled Dot-Product Attentionを並列に計算します。各ヘッドは異なる「表現部分空間」で、異なる種類の関係性に注目することを学習します。

  3. 各ヘッドの出力(コンテキストベクトル)を連結(concatenate)します。

  4. 連結したベクトルを、最後の線形層(W_o)を通して最終的な出力とします。

 

Multi-Head Attentionの実装

 

Python
 
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, h):
        super(MultiHeadAttention, self).__init__()
        assert d_model % h == 0
        
        self.d_k = d_model // h
        self.h = h
        
        # Q, K, V と最終出力のための線形層
        self.linear_q = nn.Linear(d_model, d_model)
        self.linear_k = nn.Linear(d_model, d_model)
        self.linear_v = nn.Linear(d_model, d_model)
        self.linear_out = nn.Linear(d_model, d_model)

    def forward(self, query, key, value, mask=None):
        batch_size = query.size(0)
        
        # 1. 線形層を通してQ, K, Vを生成し、ヘッド数hで分割
        # [batch, seq_len, d_model] -> [batch, h, seq_len, d_k]
        q = self.linear_q(query).view(batch_size, -1, self.h, self.d_k).transpose(1, 2)
        k = self.linear_k(key).view(batch_size, -1, self.h, self.d_k).transpose(1, 2)
        v = self.linear_v(value).view(batch_size, -1, self.h, self.d_k).transpose(1, 2)
        
        # 2. 各ヘッドでAttentionを並列計算
        x, self.attention_weights = scaled_dot_product_attention(q, k, v, mask=mask)
        
        # 3. ヘッドの出力を連結し、形状を元に戻す
        # [batch, h, seq_len, d_k] -> [batch, seq_len, d_model]
        x = x.transpose(1, 2).contiguous().view(batch_size, -1, self.h * self.d_k)
        
        # 4. 最後の線形層を通して最終出力を得る
        return self.linear_out(x)

# --- 実行例 ---
d_model = 512 # モデルの次元数
h = 8         # ヘッドの数
batch_size = 64
seq_len = 10

# モデルとダミーデータの準備
multi_head_attn = MultiHeadAttention(d_model, h)
q = torch.randn(batch_size, seq_len, d_model)
k = torch.randn(batch_size, seq_len, d_model)
v = torch.randn(batch_size, seq_len, d_model)

# Multi-Head Attentionの実行
output = multi_head_attn(q, k, v)
print(f"Output Shape: {output.shape}") # -> [64, 10, 512]

 

まとめと次のステップ

 

全4回にわたり、AIの「文脈を読む力」を支えるAttention機構の進化を追ってきました。

RNNの「情報のボトルネック」を解決したCross-Attentionから始まり、現代のLLMの基盤であるTransformerのSelf-Attention、そしてその強力な拡張であるMulti-Head Attentionへと至る道のりを見てきました。

Transformerのアーキテクチャは、系列内の長距離依存関係を効率的に捉え、かつ計算の並列化を可能にすることで、RNNベースのモデルの限界を克服しました。このブレークスルーが、ChatGPTをはじめとする生成AIの爆発的な進化の引き金となったのです。

本シリーズが、現代AIの核心技術を理解するための一助となれば幸いです。最後までお読みいただき、ありがとうございました。


2025年8月19日