《Unbiased Scene Graph Generation from Biased Training》论文笔记
任务定义
场景图生成: 描述目标检测的物体之间,具有怎样的关系
之前算法存在的问题
- 数据集中关系词存在严重的偏见,原因有以下几点:
- 标注时,倾向于简单的关系
- 日常生活中确实有些事物的关联性比较多
- 语法习惯的问题
- 往往通过union box和两个物体的类别就预测了两个物体的关系,几乎没有使用visual feature,也就预测不出有意义的关系
Methods
- 为了让模型预测更有意义的关系,用了一个causal inference中的概念,即Total Direct Effect(TDE)来取代单纯的网络log-likelihood。在预测关系时更关注visual feature.
- 提出了新的评测方法mR@K:把所有谓语类别的Recall单独计算,然后求均值,这样所有类别就一样重要了
Biased Training Models in Causal Graph
节点I: Input Image&Backbone.Backbone部分使用Faster rcnn预训练好的模型,并frozen bockbone的参数。输出检测目标的bounding boxes和图像的特征图.
Link I->X:目标的特征,通过ROI Align提取目标对应的特征,获取目标粗略的分类结果.和MOTIFS和VCTree一样,使用以下方式,编码视觉上下文特征.
MOTIFS中使用双向LSTM,VCTree中使用双向TreeLSTM,早期工作如VTransE中使用全连接层
节点X:目标特征。获取一组目标的特征
Link X-Z: 获取对应目标fine-tuned的类别,从解码:
节点Z:目标类别,one-hot的向量
Link X->Y: SGG的目标特征输入
Link:Z->Y:SGG的目标类别输入
Link:I->Y:SGG的视觉特征输入
节点Y:输出关系词汇
Training loss:使用交叉熵损失预测label,为了避免预测Y只使用单一输入的信息,尤其是只使用Z的信息,进一步 使用auxiliary cross-entropy losses, 让每一个分支分别预测y
Unbiased Prediction by Causal Effects
机器学习中常见的解决长尾问题的方法:
- 数据增强/重新采样
- 对数据平衡改进的loss
- 从无偏见中分离出有偏见的部分
与上述方法的区别是不需要额外训练或层来建模偏差,通过构建两种因果图将原有模型和偏差分离开。
Origin&Intervention&Counterfactual
- ntervention:清除因果图中某个节点的输入,并将其置为某个值,公式为:$do(X= \tilde x)$.某节点被干预后,需要该节点输入的其他节点也会受影响
- Counterfactual:让某个节点被干预后,其他需要输入的节点还假设该节点未被干预
总结: Counterfactual图实际上抹除了因果图像中object feature。只用image+object label预测两个目标间的关系。
Total Direct Effect (TDE)
根据两个因果图:
- 原始因果图
- Counterfactual因果图(可以认为是偏见)
消除偏见:
Experiments
Ablation Studies
对比了几种常用的优化长尾问题的方法
- Focal
- Reweight
- Resample
- X2Y: 直接通过X的输出预测Y
- X2Y-Tr:切断其他分支的联系,只使用X预测Y
- TE:
- NIE:
- TDE
欢迎关注我的公众号!
-------------本文结束感谢您的阅读-------------