从矩阵乘法到 Transformer:一种全新视角

前言

最近在看 Transformer 论文,里面大量用到矩阵乘法,突然觉得以前学的线性代数有了全新的意义。 这篇博客记录一下我对矩阵乘法的新理解——从最基本的定义,一路到 Transformer 中的 Attention 机制。

矩阵乘法到底是什么?

最基本的形式:$A$ 是 $m \times n$ 矩阵,$B$ 是 $n \times p$ 矩阵,乘积 $C = AB$ 是 $m \times p$ 矩阵:

$$C_{ij} = \sum_{k=1}^{n} A_{ik} B_{kj}$$

但光是公式不够直观,来看三种理解方式。

1. 点积视角

$C$ 的第 $i$ 行第 $j$ 列 = $A$ 的第 $i$ 行向量与 $B$ 的第 $j$ 列向量的点积:

$$C_{ij} = \vec{A}_{i,:} \cdot \vec{B}_{:,j}$$

这是课本最常用的定义,适合手算,但不太能揭示"为什么需要矩阵乘法"。

2. 线性变换视角

这是最重要的理解

把矩阵 $A$ 看作一个函数:它接受一个 $n$ 维向量,输出一个 $m$ 维向量。

$$y = A x$$
  • 矩阵乘法 = 对向量进行线性变换(旋转、缩放、投影)
  • 两个矩阵相乘 $AB$ = 先做 $B$ 变换,再做 $A$ 变换

例子: 旋转矩阵 $R(\theta) = \begin{bmatrix} \cos\theta & -\sin\theta \ \sin\theta & \cos\theta \end{bmatrix}$ 作用于向量 $v = \begin{bmatrix}1 \ 0\end{bmatrix}$:

$$R(90^\circ) v = \begin{bmatrix} 0 & -1 \\ 1 & 0 \end{bmatrix} \begin{bmatrix} 1 \\ 0 \end{bmatrix} = \begin{bmatrix} 0 \\ 1 \end{bmatrix}$$

向量从 $(1,0)$ 被旋转到了 $(0,1)$ — 这就是线性变换的直观含义。

3. 列向量视角

$C = AB$ 也可以看成:$C$ 的第 $j$ 列 = $A$ 乘以 $B$ 的第 $j$ 列:

$$C_{:,j} = A \cdot B_{:,j}$$

换句话说,$B$ 的每一列都被 $A$ 线性变换成了 $C$ 的对应列。 当 $B$ 的列是数据样本时,这就是批量线性变换——一次性把整个数据集从 $n$ 维空间映射到 $m$ 维空间。

这跟神经网络的全连接层 $y = Wx + b$ 完全对应:$W$ 就是权重矩阵,$x$ 是输入向量。


从矩阵乘法到 Transformer

Transformer 的核心是 Self-Attention,关键运算就是矩阵乘法。

Attention 的矩阵形式

给定 Query、Key、Value 矩阵 $Q, K, V \in \mathbb{R}^{n \times d}$($n$ 为序列长度,$d$ 为维度):

$$\text{Attention}(Q, K, V) = \text{Softmax}\left(\frac{QK^\top}{\sqrt{d}}\right)V$$

这里每一步都是矩阵乘法:

  1. $QK^\top$:$n \times d$ 乘以 $d \times n$,得到 $n \times n$ 的注意力分数矩阵
  2. $\times V$:Softmax 后的 $n \times n$ 矩阵乘以 $n \times d$ 的 $V$,得到最终的加权表示

具体算例(序列长度=3,维度=2):

$$Q = \begin{bmatrix} 1 & 0 \\ 0 & 1 \\ 1 & 1 \end{bmatrix},\quad K^\top = \begin{bmatrix} 1 & 0 & 1 \\ 0 & 1 & 1 \end{bmatrix}$$$$QK^\top = \begin{bmatrix} 1 & 0 & 1 \\ 0 & 1 & 1 \\ 1 & 1 & 2 \end{bmatrix}$$

Softmax 后每一行代表一个 token 对其他 token 的注意力权重,再乘以 $V$ 得到融合了全局信息的输出。

为什么矩阵乘法在这里如此高效?

GPU 的 Tensor Core 本质就是大规模矩阵乘法引擎。把 Attention 写成矩阵乘法,意味着:

  • 可以并行计算所有 token 之间的注意力
  • 充分利用 GPU 的并行能力
  • 训练和推理速度大幅提升

这正是"软件设计匹配硬件架构"的典范——把算法表达成矩阵乘法,让硬件暴力求解


总结

视角 核心思想 对应应用
点积视角 行与列对应位置相乘再求和 手算推导
线性变换视角 矩阵 = 函数,乘法 = 复合变换 全连接层 $Wx$
列向量视角 批量变换数据样本 Batch Inference
Attention 视角 $QK^\top$ 计算相似度矩阵 Transformer

矩阵乘法不仅仅是线性代数的一个运算——它是连接传统机器学习(线性回归、PCA)和现代深度学习(Transformer、CNN)的共同语言。理解了矩阵乘法,就拿到了读懂深度学习论文的钥匙。

使用 Hugo 构建
主题 StackJimmy 设计