目次を表示する

LLM完全解説:スクラッチから理解する大規模言語モデル

推論の高速化 ── KV Cache・量子化・投機的デコーディング

推論の高速化 ── KV Cache・量子化・投機的デコーディング

LLM との関係:訓練が「コスト」の問題なら、推論は「レイテンシ」の問題。ユーザーが体感する「速さ」を決めるのは推論最適化。


推論最適化 — KV Cache・量子化・投機的デコーディング

この章で何ができるようになるか:LLM の推論がなぜ遅いのかをボトルネック単位で説明でき、KV Cache・量子化・投機的デコーディングの仕組みを理解できる。


なぜ LLM の推論は遅いのか

LLM の生成は自己回帰(auto-regressive):1トークンずつ順番に生成する。

「東京の名所は」→ モデル → 「浅」
「東京の名所は浅」→ モデル → 「草」
「東京の名所は浅草」→ モデル → 「寺」
「東京の名所は浅草寺」→ モデル → 「、」
...

100トークンの応答を生成するには、モデルの順伝播を100回実行する必要がある
→ 各回で全パラメータを読み込む
→ 70B パラメータ × 100回 = 7兆回のパラメータ読み込み

ボトルネック:生成フェーズはメモリ帯域幅律速(Memory Bandwidth Bound)。GPU の計算能力ではなく、メモリからパラメータを読み出す速度が制限要因。

H100 の性能:
  計算能力: 990 TFLOPS (FP16)
  メモリ帯域: 3.35 TB/s

70B パラメータ (FP16 = 140GB) の1回の推論:
  必要な帯域: 140 GB / (3.35 TB/s) ≈ 42ms
  → 秒間 ~24 トークン(帯域幅だけで制限される)

KV Cache:同じ計算を繰り返さない

自己回帰生成では、各ステップで過去の全トークンの Attention を再計算する…が、過去のトークンの K, V は変わらない

ステップ1: 入力 "東京" → Q1, K1, V1 を計算 → Attention(Q1, [K1], [V1])
ステップ2: 入力 "の"   → Q2, K2, V2 を計算 → Attention(Q2, [K1,K2], [V1,V2])
ステップ3: 入力 "名所" → Q3, K3, V3 を計算 → Attention(Q3, [K1,K2,K3], [V1,V2,V3])

K1, V1 はステップ1で計算済み → キャッシュして再利用
K2, V2 はステップ2で計算済み → キャッシュして再利用

ステップ3 では Q3 の計算と K3, V3 の計算だけ行い、
K1, K2 と V1, V2 はキャッシュから読む
class CachedAttention:
    def __init__(self):
        self.k_cache = None  # (batch, num_heads, cached_len, d_k)
        self.v_cache = None

    def forward(self, q, k, v, use_cache=True):
        if use_cache and self.k_cache is not None:
            # 新しい K, V をキャッシュに追加
            k = torch.cat([self.k_cache, k], dim=2)
            v = torch.cat([self.v_cache, v], dim=2)

        if use_cache:
            self.k_cache = k
            self.v_cache = v

        # Attention 計算(q は最新トークン1つだけ)
        return scaled_dot_product_attention(q, k, v)

KV Cache のメモリコスト

GPT-3 (175B) の KV Cache:
  1層あたり: 2 × d_model × seq_len × sizeof(FP16)
  = 2 × 12288 × 2048 × 2 bytes = 100 MB

  96層: 100 MB × 96 = 9.6 GB / リクエスト

  バッチサイズ 32: 9.6 GB × 32 = 307 GB
  → GPU メモリの大部分を KV Cache が占める

量子化(Quantization):パラメータを小さくする

FP16(16bit)のパラメータを INT8(8bit)や INT4(4bit)に圧縮する。

FP16: 70B パラメータ × 2 bytes = 140 GB(H100 2台必要)
INT8: 70B パラメータ × 1 byte  = 70 GB(H100 1台に収まる)
INT4: 70B パラメータ × 0.5 byte = 35 GB(A100 1台に収まる)

Post-Training Quantization(PTQ)

def quantize_to_int8(tensor: torch.Tensor) -> tuple:
    """FP16 テンソルを INT8 に変換"""
    # スケールファクターの計算
    abs_max = tensor.abs().max()
    scale = abs_max / 127.0  # INT8 の範囲: -128〜127

    # 量子化: 浮動小数点 → 整数
    quantized = torch.round(tensor / scale).clamp(-128, 127).to(torch.int8)

    return quantized, scale

def dequantize(quantized: torch.Tensor, scale: float) -> torch.Tensor:
    """INT8 を FP16 に復元"""
    return quantized.float() * scale

# 行列演算は INT8 のまま実行(GPU が INT8 演算をネイティブサポート)
# → メモリ帯域幅が半減 → 推論速度が約2倍

GPTQ / AWQ(高精度な4bit量子化)

GPTQ (2023):
  層ごとに量子化誤差を最小化するキャリブレーション
  → 4bit でも FP16 に近い精度を維持

AWQ (Activation-aware Weight Quantization):
  「重要な重み」(活性化値が大きい重みに対応する)を高精度で保持
  → 全体を均一に量子化するより精度が高い

投機的デコーディング(Speculative Decoding)

「小さなモデルで下書き → 大きなモデルで検証」。

通常の推論:
  大きなモデル(70B)で1トークンずつ生成 → 遅い

投機的デコーディング:
  1. 小さなモデル(7B)で K トークンを高速に下書き生成
  2. 大きなモデル(70B)で K トークンを一括検証(1回の順伝播)
  3. 検証に通ったトークンはそのまま採用、通らなかったら再生成

なぜ速くなるか:
  小さなモデルの推論は速い(パラメータが1/10)
  大きなモデルの検証は1回の順伝播(K トークン分を並列処理)
  多くのトークンは小さなモデルと大きなモデルで一致する
  → 実質的に K トークンを大きなモデル1回の推論で生成
def speculative_decode(draft_model, target_model, prompt_ids, K=5):
    """投機的デコーディング(簡略化)"""
    generated = prompt_ids.clone()

    while not is_finished(generated):
        # 1. Draft model で K トークンを推測
        draft_tokens = []
        draft_probs = []
        temp_ids = generated.clone()
        for _ in range(K):
            logits = draft_model(temp_ids)
            probs = torch.softmax(logits[:, -1], dim=-1)
            token = torch.multinomial(probs, 1)
            draft_tokens.append(token)
            draft_probs.append(probs)
            temp_ids = torch.cat([temp_ids, token], dim=1)

        # 2. Target model で K トークンを一括検証
        all_ids = torch.cat([generated] + draft_tokens, dim=1)
        target_logits = target_model(all_ids)  # 1回の順伝播で K トークン分

        # 3. 各トークンを受理/棄却(確率的に判定)
        accepted = 0
        for i in range(K):
            target_prob = torch.softmax(target_logits[:, -(K-i+1)], dim=-1)
            draft_prob = draft_probs[i]
            token = draft_tokens[i]

            # 受理確率 = min(1, target_prob / draft_prob)
            acceptance = min(1.0, target_prob[0, token] / draft_prob[0, token])
            if torch.rand(1) < acceptance:
                generated = torch.cat([generated, token], dim=1)
                accepted += 1
            else:
                # 棄却 → target model の分布からサンプリングし直す
                token = torch.multinomial(target_prob, 1)
                generated = torch.cat([generated, token], dim=1)
                break

    return generated

高速化の効果:ドラフトモデルの精度が高い場合、2〜3倍の高速化が得られる。


その他の最適化手法

Flash Attention

Attention の計算をメモリ効率よく行う。

通常の Attention:
  Q×K^T → (seq_len × seq_len) の行列をメモリに保持 → O(n²) メモリ

Flash Attention:
  Q, K, V をタイルに分割し、各タイルを GPU の SRAM(高速メモリ)で処理
  → 中間の (seq_len × seq_len) 行列を明示的にメモリに保持しない
  → メモリ O(n) に削減、速度も 2〜4倍向上

Grouped Query Attention(GQA)

通常の Multi-Head Attention:
  Q: 32ヘッド, K: 32ヘッド, V: 32ヘッド → KV Cache が大きい

GQA (LLaMA 2 / 3 が使用):
  Q: 32ヘッド, K: 8グループ, V: 8グループ
  → 4つの Q ヘッドが 1つの K/V グループを共有
  → KV Cache が 1/4 に削減

まとめ

最適化手法効果トレードオフ
KV CacheAttention の再計算を省略メモリ消費が増加
INT8 量子化メモリ半減、速度2倍わずかな精度低下
INT4 量子化メモリ1/4、速度〜3倍精度低下がやや大きい
投機的デコーディング2〜3倍高速化ドラフトモデルの追加メモリ
Flash Attentionメモリ O(n)、速度2〜4倍実装が複雑
GQAKV Cache 削減わずかな精度低下

次章(エピローグ)でシリーズ全体を振り返る。