写在前面
最近在看医学图像分割的多模态方法,读到 IEEE TMI 2023 的 LViT(Language meets Vision Transformer in Medical Image Segmentation),觉得它的设计思路很有意思——把放射科医生的文本描述(比如"边界清晰的圆形病灶")直接用 BERT 编码后注入分割网络。
这篇博客是我对 LViT 的完整走读,包含架构拆解、数据流追踪、关键代码分析和一些发现的问题。论文链接在 arXiv:2206.14718。
一、动机:文本能带来什么?
纯视觉的医学图像分割有个痛点:病灶形态千变万化,单靠像素级标注训练出来的模型泛化能力有限。LViT 的核心假设是:
放射科医生的文本描述包含了高级语义先验(位置、形状、大小、边界特征),这些信息可以作为视觉特征的辅助,引导模型关注正确的区域。
比如 “a small round ground-glass opacity in the left upper lobe” 这句话,同时告诉了模型:位置(左上叶)、大小(small)、形状(round)、类型(GGO)。这比单纯给一张 mask 要丰富得多。
二、整体架构
LViT 不是简单地把 ViT 接在 U-Net 后面。它的设计可以拆成三条并行的信息流:
|
|
一句话总结:U-Net 做骨架,4 个尺度的 ViT 并行处理 token 并用文本增强,ViT 输出再反哺回 U-Net 的 skip connection。
三、关键模块拆解
3.1 图像编码器(标准 U-Net Encoder)
和经典 U-Net 一样,4 级下采样,每一级是 MaxPool2d + 2×(Conv3×3 + BN + ReLU):
|
|
3.2 ViT 多尺度分支
这是 LViT 的核心。4 个 VisionTransformer 实例分别处理不同分辨率:
| 层级 | patch_size | 输入分辨率 | embed_dim | token 数 |
|---|---|---|---|---|
| downVit | 16 | 224×224 | 64 | 196 |
| downVit1 | 8 | 112×112 | 128 | 196 |
| downVit2 | 4 | 56×56 | 256 | 196 |
| downVit3 | 2 | 28×28 | 512 | 196 |
每一层保证了相同的 token 数(196),但 embedding 维度随通道数翻倍。这样一来,浅层捕获细粒度纹理,深层捕获高级语义。
每个 ViT 实例包含:
- Patch Embedding:用 stride=kernel_size=patch_size 的 Conv2d 做 patch 投影
- Position Embedding:可学习的 1×196×embed_dim 参数
- 1 层 Transformer Block:LayerNorm → Multi-Head Attention(4 heads) → DropPath → LayerNorm → MLP(GELU, expand_ratio=4)
3.3 文本降维管道
文本先走 BERT-base-uncased 拿到 [B, 10, 768] 的 embedding,然后通过 4 个 Conv1d 逐级降维:
|
|
得到 text4→text1 四个多尺度文本特征,分别对应四个 ViT 层级。本意是浅层 ViT 用低维文本特征(粗粒度语义),深层 ViT 用高维文本特征(细粒度语义)。
3.4 文本-视觉融合
这是重点。在 VisionTransformer.forward() 的前向路径中:
|
|
CTBN3 是一个 Conv1d,把 10 个文本 token 投影到 196 个 patch token 空间,然后直接相加。这是一种最朴素的融合方式,相当于告诉 ViT:“在每个 patch 的视觉特征上,叠加一句全局文本描述的投影”。
发现的问题:条件
self.dim == 64意味着只在 downVit(第一层)完成融合,downVit1/2/3 虽然传入了 text2/3/4,但实际上从未使用。多尺度文本语义没有被充分利用。
3.5 PixLevelModule(像素级注意力)
这个模块放在 U-Net 的 skip connection 上,给跳跃连接的特征图做空间注意力加权:
|
|
本质是一个轻量级的可学习空间注意力,用 avg/max/combined 三种统计量来自动学习哪些空间位置更重要。它被用在两个地方:
- UpblockAttention:上采样后对 skip 特征做注意力加权再拼接
- Skip connection 增强:ViT 输出的 token 经过 reconstruct 恢复成空间图后,也经过 PixLevelModule 再加回原始 skip
3.6 Reconstruct:Token 回到空间
ViT 工作在 token 空间 [B, N, C],但 U-Net 解码器需要空间特征图。Reconstruct 模块完成这个转换:
|
|
四、损失函数
训练使用 WeightedDiceBCE,将 Dice Loss 和 BCE Loss 线性组合:
$$ \mathcal{L} = 0.5 \times \mathcal{L}_{\text{Dice}} + 0.5 \times \mathcal{L}_{\text{BCE}} $$WeightedDiceLoss:每个样本独立计算 Dice,正负样本按 [0.5, 0.5] 加权:
其中 $w = t \cdot (w_{\text{pos}} - w_{\text{neg}}) + w_{\text{neg}}$,即正样本像素权重为 0.5,负样本也为 0.5。
WeightedBCE:正负样本分别归一化:
$$ \mathcal{L}_{\text{BCE}} = w_{\text{pos}} \cdot \frac{\sum_{\text{pos}} \text{BCE}}{N_{\text{pos}}} + w_{\text{neg}} \cdot \frac{\sum_{\text{neg}} \text{BCE}}{N_{\text{neg}}} $$权重为 [0.4, 0.6],对负样本的 BCE 给略高的权重——这是为了应对医学图像中正样本(病灶)通常远少于负样本(背景)的类不平衡问题。
五、训练策略
| 配置项 | 设定 |
|---|---|
| 优化器 | Adam, lr=3e-4 (Covid19) / 1e-3 (MoNuSeg) |
| LR 调度 | CosineAnnealingWarmRestarts (T_0=10, eta_min=1e-4) |
| Batch Size | 2(论文建议 2 优于 4) |
| Epochs | 2000 |
| Early Stopping | patience=50, 监控 val Dice |
| 输入尺寸 | 224×224 |
| 数据增强 | 随机 90° 旋转 + 随机翻转 + 随机旋转 ±20° |
还支持 LViT_pretrain 模式:先在 MoNuSeg 上预训练 U-Net 的卷积部分,再迁移到目标数据集。
六、实验结果
| 数据集 | U-Net | LViT-T | 提升 |
|---|---|---|---|
| QaTa-COV19 | 79.02% | 83.66% | +4.64 |
| MosMedData+ | 64.60% | 74.57% | +9.97 |
| MoNuSeg | 76.45% | 80.36% | +3.91 |
| BKAI-Poly | - | 92.07% | - |
| ESO-CT | - | 68.27% | - |
在三个主要数据集上稳定超越 U-Net baseline 3-10 个百分点。MosMedData+ 数据集上提升最大,可能是因为该数据集病灶形态多变,文本先验的辅助作用更明显。
七、代码走读中的发现
通读 LViT 源码 时,我发现了几个值得注意的地方:
7.1 text2/3/4 是死代码
VisionTransformer.forward() 中文本融合的条件是 if self.dim == 64,只有第一层 ViT 满足。text2(128维)、text3(256维)、text4(512维) 虽然传入但从未使用。一个潜在的改进是为各层加上对应的 CTBN 投影——让文本的多尺度语义真正发挥价值。
7.2 LR Scheduler 从未生效
train_model.py 第 160 行将 lr_scheduler 参数传为 None:
|
|
而真正传入 scheduler 的是验证阶段(第 167 行)。查看 Train_one_epoch.py 第 129 行:
|
|
scheduler 在验证时才 step,而且 CosineAnnealingWarmRestarts 的标准用法是 per-batch step(见其 docstring),per-epoch step 没有发挥 warm restart 的效果。
7.3 BERT 编码无缓存
TextEmbedder.__call__() 每次 __getitem__ 都跑一次完整的 BERT 前向传播。同一个 epoch 内同一张图的文本被反复编码,这是不必要的开销。预计算并缓存可以显著加速数据加载。
7.4 文本融合方式较朴素
当前是 x = x + CTBN3(text)(token 空间直接相加),没有学习文本 token 和图像 token 之间的对应关系。升级为 cross-attention 让模型学会"根据文本语义选择性地关注图像区域"是一个自然的改进方向。
八、值得思考的问题
-
文本信息到底起了多大作用? 论文没有提供 ablation study 去掉文本只用 ViT 分支的结果。如果文本嵌入本身就接近随机初始化,模型可能只是从 ViT 多尺度结构中受益。
-
10 个 token 是否足够? BERT 的 max_length 设为 10,而放射科报告通常更长。截断可能导致关键信息丢失。
-
文本标注成本:给每张训练图像配一句英文描述,不是所有医学影像数据集都能做到。这个依赖性限制了 LViT 的应用范围。
总结
LViT 是一个设计精巧的多模态医学分割架构,核心贡献在于:
- 提出了 U-Net + 多尺度 ViT + 文本增强的三流架构
- 设计了 PixLevelModule 做像素级注意力来增强 skip connection
- 实验证明文本先验对分割性能有稳定提升
也有一些可以改进的地方:多尺度文本融合未完全实现、LR 调度器有 bug、文本融合方式可以更精细。如果你在做医学图像分割相关的研究,LViT 是一个值得精读和魔改的 baseline。
参考文献:Li Z, Li Y, Li Q, et al. LViT: Language meets Vision Transformer in Medical Image Segmentation. IEEE TMI, 2023.