LViT 论文精读:当语言遇上视觉 Transformer,医学图像分割的多模态之路

深入解析 LViT(Language meets Vision Transformer)架构:U-Net + 多尺度 ViT + BERT 文本融合,含完整数据流、关键代码走读、已知问题与改进方向。

写在前面

最近在看医学图像分割的多模态方法,读到 IEEE TMI 2023 的 LViT(Language meets Vision Transformer in Medical Image Segmentation),觉得它的设计思路很有意思——把放射科医生的文本描述(比如"边界清晰的圆形病灶")直接用 BERT 编码后注入分割网络。

这篇博客是我对 LViT 的完整走读,包含架构拆解、数据流追踪、关键代码分析和一些发现的问题。论文链接在 arXiv:2206.14718

一、动机:文本能带来什么?

纯视觉的医学图像分割有个痛点:病灶形态千变万化,单靠像素级标注训练出来的模型泛化能力有限。LViT 的核心假设是:

放射科医生的文本描述包含了高级语义先验(位置、形状、大小、边界特征),这些信息可以作为视觉特征的辅助,引导模型关注正确的区域。

比如 “a small round ground-glass opacity in the left upper lobe” 这句话,同时告诉了模型:位置(左上叶)、大小(small)、形状(round)、类型(GGO)。这比单纯给一张 mask 要丰富得多。

二、整体架构

LViT 不是简单地把 ViT 接在 U-Net 后面。它的设计可以拆成三条并行的信息流:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
输入: 图像 [B,3,224,224]  +  文本 [B,10,768]BERT 预编码)

     ┌───── 图像编码器 ─────┐    ┌── 文本降维管道 ──┐
      inc  x1 [64,224]        Conv1d 768512    text4
      down1  x2 [128,112]     Conv1d 512256    text3
      down2  x3 [256,56]      Conv1d 256128    text2
      down3  x4 [512,28]      Conv1d 12864     text1
      down4  x5 [512,14]     └─────────────────┘
     └──────────────────────┘           
                                        
                  ┌────────── ViT 多尺度分支 ──────────┐
                   downVit(x1, text1)  y1 [196,64]    文本只在这一层融合!
                   downVit1(x2, y1)   y2 [196,128]  
                   downVit2(x3, y2)   y3 [196,256]  
                   downVit3(x4, y3)   y4 [196,512]  
                                                      
                   upVit3~upVit: reconstruct 路径     
                  └────────────────────────────────────┘
                                        
                                        
     ┌─────────── reconstruct ──────────────┐
      token  空间特征图,增强 skip        
      x1~x4  ViT 分支增强后进入解码器    
     └──────────────────────────────────────┘
              
              
     ┌─────────── U-Net 解码器 ────────────┐
      up4(x5,x4)  up3  up2  up1        
      UpblockAttention: PixLevel 注意力   
       outc  Sigmoid  [B,1,224,224]    
     └──────────────────────────────────────┘

一句话总结:U-Net 做骨架,4 个尺度的 ViT 并行处理 token 并用文本增强,ViT 输出再反哺回 U-Net 的 skip connection。

三、关键模块拆解

3.1 图像编码器(标准 U-Net Encoder)

和经典 U-Net 一样,4 级下采样,每一级是 MaxPool2d + 2×(Conv3×3 + BN + ReLU)

1
2
3
4
5
6
x [3,224,224]
  → inc (ConvBNReLU, 3→64)         → x1 [64,224,224]
  → down1 (MaxPool + 2×ConvBNReLU)  → x2 [128,112,112]
  → down2                          → x3 [256,56,56]
  → down3                          → x4 [512,28,28]
  → down4                          → x5 [512,14,14]

3.2 ViT 多尺度分支

这是 LViT 的核心。4 个 VisionTransformer 实例分别处理不同分辨率:

层级 patch_size 输入分辨率 embed_dim token 数
downVit 16 224×224 64 196
downVit1 8 112×112 128 196
downVit2 4 56×56 256 196
downVit3 2 28×28 512 196

每一层保证了相同的 token 数(196),但 embedding 维度随通道数翻倍。这样一来,浅层捕获细粒度纹理,深层捕获高级语义。

每个 ViT 实例包含:

  • Patch Embedding:用 stride=kernel_size=patch_size 的 Conv2d 做 patch 投影
  • Position Embedding:可学习的 1×196×embed_dim 参数
  • 1 层 Transformer Block:LayerNorm → Multi-Head Attention(4 heads) → DropPath → LayerNorm → MLP(GELU, expand_ratio=4)

3.3 文本降维管道

文本先走 BERT-base-uncased 拿到 [B, 10, 768] 的 embedding,然后通过 4 个 Conv1d 逐级降维:

1
2
3
4
self.text_module4 = Conv1d(in=768, out=512, kernel=3, padding=1)
self.text_module3 = Conv1d(in=512, out=256, kernel=3, padding=1)
self.text_module2 = Conv1d(in=256, out=128, kernel=3, padding=1)
self.text_module1 = Conv1d(in=128, out=64,  kernel=3, padding=1)

得到 text4→text1 四个多尺度文本特征,分别对应四个 ViT 层级。本意是浅层 ViT 用低维文本特征(粗粒度语义),深层 ViT 用高维文本特征(细粒度语义)。

3.4 文本-视觉融合

这是重点。在 VisionTransformer.forward() 的前向路径中:

1
2
3
4
5
6
def forward(self, x, skip_x, text, reconstruct=False):
    if not reconstruct:                    # down 路径
        x = self.embeddings(x)            # 图像 patch embedding
        if self.dim == 64:                # ⚠️ 仅第一层生效!
            x = x + self.CTBN3(text)      # CTBN3: Conv1d(10→196)
        x = self.Encoder_blocks(x)        # Transformer

CTBN3 是一个 Conv1d,把 10 个文本 token 投影到 196 个 patch token 空间,然后直接相加。这是一种最朴素的融合方式,相当于告诉 ViT:“在每个 patch 的视觉特征上,叠加一句全局文本描述的投影”。

发现的问题:条件 self.dim == 64 意味着只在 downVit(第一层)完成融合,downVit1/2/3 虽然传入了 text2/3/4,但实际上从未使用。多尺度文本语义没有被充分利用。

3.5 PixLevelModule(像素级注意力)

这个模块放在 U-Net 的 skip connection 上,给跳跃连接的特征图做空间注意力加权:

1
2
3
4
5
6
7
8
x [B, C, H, W]
  ├→ conv_avg(1×1) → ReLU → channel_mean  → x_avg [B, 1, H, W]
  ├→ conv_max(1×1) → ReLU → channel_max   → x_max [B, 1, H, W]
  └→ x_out = x_avg + x_max

concat(x_avg, x_max, x_out) → [B, 3, H, W]
  → transpose → Linear(3→6→1) → Sigmoid  → attention_map [B, 1, H, W]
  → y = attention_map × x

本质是一个轻量级的可学习空间注意力,用 avg/max/combined 三种统计量来自动学习哪些空间位置更重要。它被用在两个地方:

  1. UpblockAttention:上采样后对 skip 特征做注意力加权再拼接
  2. Skip connection 增强:ViT 输出的 token 经过 reconstruct 恢复成空间图后,也经过 PixLevelModule 再加回原始 skip

3.6 Reconstruct:Token 回到空间

ViT 工作在 token 空间 [B, N, C],但 U-Net 解码器需要空间特征图。Reconstruct 模块完成这个转换:

1
2
[B, N, C] → permute → [B, C, N] → reshape → [B, C, sqrt(N), sqrt(N)]
  → Upsample(scale_factor) → Conv1×1 + BN + ReLU

四、损失函数

训练使用 WeightedDiceBCE,将 Dice Loss 和 BCE Loss 线性组合:

$$ \mathcal{L} = 0.5 \times \mathcal{L}_{\text{Dice}} + 0.5 \times \mathcal{L}_{\text{BCE}} $$

WeightedDiceLoss:每个样本独立计算 Dice,正负样本按 [0.5, 0.5] 加权:

$$ \mathcal{L}_{\text{Dice}} = 1 - \frac{2\sum (w \cdot p \cdot t) + \epsilon}{\sum (w \cdot p^2) + \sum (w \cdot t^2) + \epsilon} $$

其中 $w = t \cdot (w_{\text{pos}} - w_{\text{neg}}) + w_{\text{neg}}$,即正样本像素权重为 0.5,负样本也为 0.5。

WeightedBCE:正负样本分别归一化:

$$ \mathcal{L}_{\text{BCE}} = w_{\text{pos}} \cdot \frac{\sum_{\text{pos}} \text{BCE}}{N_{\text{pos}}} + w_{\text{neg}} \cdot \frac{\sum_{\text{neg}} \text{BCE}}{N_{\text{neg}}} $$

权重为 [0.4, 0.6],对负样本的 BCE 给略高的权重——这是为了应对医学图像中正样本(病灶)通常远少于负样本(背景)的类不平衡问题。

五、训练策略

配置项 设定
优化器 Adam, lr=3e-4 (Covid19) / 1e-3 (MoNuSeg)
LR 调度 CosineAnnealingWarmRestarts (T_0=10, eta_min=1e-4)
Batch Size 2(论文建议 2 优于 4)
Epochs 2000
Early Stopping patience=50, 监控 val Dice
输入尺寸 224×224
数据增强 随机 90° 旋转 + 随机翻转 + 随机旋转 ±20°

还支持 LViT_pretrain 模式:先在 MoNuSeg 上预训练 U-Net 的卷积部分,再迁移到目标数据集。

六、实验结果

数据集 U-Net LViT-T 提升
QaTa-COV19 79.02% 83.66% +4.64
MosMedData+ 64.60% 74.57% +9.97
MoNuSeg 76.45% 80.36% +3.91
BKAI-Poly - 92.07% -
ESO-CT - 68.27% -

在三个主要数据集上稳定超越 U-Net baseline 3-10 个百分点。MosMedData+ 数据集上提升最大,可能是因为该数据集病灶形态多变,文本先验的辅助作用更明显。

七、代码走读中的发现

通读 LViT 源码 时,我发现了几个值得注意的地方:

7.1 text2/3/4 是死代码

VisionTransformer.forward() 中文本融合的条件是 if self.dim == 64,只有第一层 ViT 满足。text2(128维)、text3(256维)、text4(512维) 虽然传入但从未使用。一个潜在的改进是为各层加上对应的 CTBN 投影——让文本的多尺度语义真正发挥价值。

7.2 LR Scheduler 从未生效

train_model.py 第 160 行将 lr_scheduler 参数传为 None

1
2
train_one_epoch(train_loader, model, criterion, optimizer,
                writer, epoch, None, model_type, logger)  # None!

而真正传入 scheduler 的是验证阶段(第 167 行)。查看 Train_one_epoch.py 第 129 行:

1
2
if lr_scheduler is not None:
    lr_scheduler.step()

scheduler 在验证时才 step,而且 CosineAnnealingWarmRestarts 的标准用法是 per-batch step(见其 docstring),per-epoch step 没有发挥 warm restart 的效果。

7.3 BERT 编码无缓存

TextEmbedder.__call__() 每次 __getitem__ 都跑一次完整的 BERT 前向传播。同一个 epoch 内同一张图的文本被反复编码,这是不必要的开销。预计算并缓存可以显著加速数据加载。

7.4 文本融合方式较朴素

当前是 x = x + CTBN3(text)(token 空间直接相加),没有学习文本 token 和图像 token 之间的对应关系。升级为 cross-attention 让模型学会"根据文本语义选择性地关注图像区域"是一个自然的改进方向。

八、值得思考的问题

  1. 文本信息到底起了多大作用? 论文没有提供 ablation study 去掉文本只用 ViT 分支的结果。如果文本嵌入本身就接近随机初始化,模型可能只是从 ViT 多尺度结构中受益。

  2. 10 个 token 是否足够? BERT 的 max_length 设为 10,而放射科报告通常更长。截断可能导致关键信息丢失。

  3. 文本标注成本:给每张训练图像配一句英文描述,不是所有医学影像数据集都能做到。这个依赖性限制了 LViT 的应用范围。

总结

LViT 是一个设计精巧的多模态医学分割架构,核心贡献在于:

  • 提出了 U-Net + 多尺度 ViT + 文本增强的三流架构
  • 设计了 PixLevelModule 做像素级注意力来增强 skip connection
  • 实验证明文本先验对分割性能有稳定提升

也有一些可以改进的地方:多尺度文本融合未完全实现、LR 调度器有 bug、文本融合方式可以更精细。如果你在做医学图像分割相关的研究,LViT 是一个值得精读和魔改的 baseline。


参考文献:Li Z, Li Y, Li Q, et al. LViT: Language meets Vision Transformer in Medical Image Segmentation. IEEE TMI, 2023.

使用 Hugo 构建
主题 StackJimmy 设计