一篇文章搞懂Segment Anything(SAM)


本文首发于公众号“CVTALK”。

论文链接:https://arxiv.org/abs/2304.02643

代码链接:https://github.com/lllyasviel/ControlNet

Demo链接:https://segment-anything.com/demo

SAM从任务、模型、数据三部分展开写作,和模型的创新比较起来,任务定义和数据的工作更加出彩,官网也给出了demo,能直观感受SAM的效果,这篇blog也会围绕这几部分展开。

demo

demo中有开放point, box, everything三种方式。由于text prompt效果不太稳定,demo和代码中都没有该部分。

  • 鼠标悬停: 显示的是悬停位置的分割结果,例如下图中将鼠标放到手的位置.
  • 点击: 分割包含该点的物体,会按最小分割的结果展示出来,如果想分割的物体大于展示的结果,可以在物体的其他部分也点击下。

  • box: 框定一个box,分割box中的物体
  • everything: 将图片中所有物体的分割都展示出来

任务

任务的设计灵感来自于NLP领域,例如NLP中可以通过预测next token作为预训练任务,而在下游任务中可以使用prompt engineering做应用。因此,为了建立分割的基础模型,任务的设计目标是也需要具有类似的能力。
这里作者扩展了下NLP里prompt在图像分割里的用法, prompt可以是以下几种类型:

  • point
  • box
  • mask
  • 任意格式的文本

为了支持以下的几种输入prompt格式,要求模型能够区分具有混淆意义的prompt,例如下图中,一个point的prompt可能有多种分割方式.这多种分割方式对于模型来说都是有效的。

预训练: 将上面提到的多种sequence的prompt告诉模型,训练目标是让模型输出对应promt的分割结果,并且期望模型输出的结果和GT尽可能一致。区别于之前的交互式分割算法,SAM基本能治通过一次交互就能得到很合理的分割结果。要达到这个目的,需要设计非常独特的模型结构和loss。

zero-shot transfer:需要模型对任何prompt,得到合适的分割结果。例如,如果要做实例分割,可以把检测得到的box作为prompt,SAM就能去做实例分割

related tasks: 分割里有很多子任务,例如边缘分割,语义分割等,SAM能完成所有已知的分割任务和还没有作为一个方向的分割任务。之前已经有类似的可以做多种分割的模型(solo), 但是这些模型有多个子子输出,然后做排列组合可以得到多种分割结果。而SAM通过prompt将多个分割任务合并在一起。

总而言之,作者是希望SAM能够分割一切,并且能相CLIP一样,能应用到最开始没有想到的领域。

模型

模型的结构如上图所示. prompt会经过prompt encoder, 图像会经过image encoder。然后将两部分embedding经过一个轻量化的mask decoder得到融合后的特征。encoder部分使用的都是已有模型,decoder使用transformer。这部分论文中介绍的相对比较少,下面会结合代码一起梳理下:

  • image encoder: 使用的是用ViT走位backbone的MAE模型。在交互式分割的展示中,image encoder只会运行一次。在实验中,分别有用到ViT-H, ViT-L, ViT-B三种大小的模型作为image encoder。代码如下,build_sam#L47

    1
    2
    3
    4
    5
    6
    sam_model_registry = {
    "default": build_sam_vit_h,
    "vit_h": build_sam_vit_h,
    "vit_l": build_sam_vit_l,
    "vit_b": build_sam_vit_b,
    }
  • prompt encoder: prompt总共有point,box, mask, text四种,会将其分为三类。pint和box可以作为一类使用position encodings, text可以使用CLIP作为encoder, 而mask是一种密集型的prompt,可以使用卷积作为encoder.prompt_encoder.py#LL128C5-L128C5 prompt_encoder的代码如下所示,其中用position embedding分别实现了point和box query两种稀疏embedding,用卷积实现了mask query密集embedding.,

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    41
    def forward(
    self,
    points: Optional[Tuple[torch.Tensor, torch.Tensor]],
    boxes: Optional[torch.Tensor],
    masks: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Embeds different types of prompts, returning both sparse and dense
    embeddings.

    Arguments:
    points (tuple(torch.Tensor, torch.Tensor) or none): point coordinates
    and labels to embed.
    boxes (torch.Tensor or none): boxes to embed
    masks (torch.Tensor or none): masks to embed

    Returns:
    torch.Tensor: sparse embeddings for the points and boxes, with shape
    BxNx(embed_dim), where N is determined by the number of input points
    and boxes.
    torch.Tensor: dense embeddings for the masks, in the shape
    Bx(embed_dim)x(embed_H)x(embed_W)
    """
    bs = self._get_batch_size(points, boxes, masks)
    sparse_embeddings = torch.empty((bs, 0, self.embed_dim), device=self._get_device())
    if points is not None:
    coords, labels = points
    point_embeddings = self._embed_points(coords, labels, pad=(boxes is None)) # position embedding
    sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1)
    if boxes is not None:
    box_embeddings = self._embed_boxes(boxes) # position embedding
    sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1)

    if masks is not None:
    dense_embeddings = self._embed_masks(masks) # conv embedding
    else:
    dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(
    bs, -1, self.image_embedding_size[0], self.image_embedding_size[1]
    )

    return sparse_embeddings, dense_embeddings
  • mask decoder:: 使用一个transformer将image embedding和prompt embedding做双向的cross-attention;并且也有prompt embedding的self-attention。也有MLP和linear classifier分类分割区域。mask decoder, transformer.py#L151这里的queries是query embedding,keys是image embedding,query_pe和queries一样,key_pe是需要加到image embedding上的位置编码。query embedding会经过self attention。query embedding和image embedding会做双向的cross-attention, 具体实现方式是如上代码所示,image embedding会作为query,query embedding会作为key和value;同样的,query embedding会作为query,image embedding会作为key和value。

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    41
    42
    43
    44
    45
    def forward(
    self,
    image_embedding: Tensor,
    image_pe: Tensor,
    point_embedding: Tensor,
    ) -> Tuple[Tensor, Tensor]:
    """
    Args:
    image_embedding (torch.Tensor): image to attend to. Should be shape
    B x embedding_dim x h x w for any h and w.
    image_pe (torch.Tensor): the positional encoding to add to the image. Must
    have the same shape as image_embedding.
    point_embedding (torch.Tensor): the embedding to add to the query points.
    Must have shape B x N_points x embedding_dim for any N_points.

    Returns:
    torch.Tensor: the processed point_embedding
    torch.Tensor: the processed image_embedding
    """
    # BxCxHxW -> BxHWxC == B x N_image_tokens x C
    bs, c, h, w = image_embedding.shape
    image_embedding = image_embedding.flatten(2).permute(0, 2, 1)
    image_pe = image_pe.flatten(2).permute(0, 2, 1)

    # Prepare queries
    queries = point_embedding
    keys = image_embedding

    # Apply transformer blocks and final layernorm
    for layer in self.layers:
    queries, keys = layer(
    queries=queries,
    keys=keys,
    query_pe=point_embedding,
    key_pe=image_pe,
    )

    # Apply the final attention layer from the points to the image
    q = queries + point_embedding
    k = keys + image_pe
    attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys)
    queries = queries + attn_out
    queries = self.norm_final_attn(queries)

    return queries, keys
  • 解决混淆的输入: 对于一个prompt,模型会输出3个mask,实际上也可以输出更多的分割结果,3个可以看作一个物体的整体、部分、子部分,基本能满足大多数情况。使用IOU的方式,排序mask。在反向传播时,参与计算的只有loss最小的mask相关的参数.

  • 高效: 这里主要指的是prompt encodermask decoder。在web浏览器上,CPU计算只用约50ms

  • loss和训练细节: 主要使用的是focal loss和dice loss。每一个mask,会随机产生11种prompt与之配对。

数据

数据引擎

不像CLIP中图像文本对通过互联网容易获取,分割的数据获取成本巨大。SAM开源了一个10亿张图片的分割数据集。在SAM中设计了一个数据引擎用于获取分割的数据,数据引擎主要分为以下三部分:

  • 辅助标注: 简单来说就是用可以获取到的开源分割数据训练一个初始的SAM模型V0版本,再用V0在没有分割标注的数据上生成预标注,人工check模型的结果并作修改和确认。得到新的数据后,再将新的数据加入到训练集重新训练SAM得到V1版本,再循环标注数据和迭代模型。总共进行6次训练。开始的时候数据集比较少,使用的ViT-B模型,最终会使用ViT-H模型。 这里面还有一些效率提升的数据,例如随着模型的迭代,每个mask的标注耗时从34s到14s。SAM模型在每张图片上生成的mask从20到44个。在该阶段数据集最终有12万张图片,430万个mask

  • 半自动化标注: 通过第一阶段后,已经有一个不错的SAM模型能生成分割结果。半自动化标注的目的是增加mask的多样性。具体做法是训练一个检测模型,用于检测SAM生成的mask结果是否可信,只保留可信的mask结果,然后将图片给人工标注。人工标注会在可信的mask基础上标注出其他的分割框。经过5次的迭代后,数据集新增了18万张图片,590万mask。

自动标注: 经过前面两个阶段后,SAM有了较好的结果,能分割出图片中的目标,并且对于混淆的prompt也有了较好的输出。这个模型可以自动的对一些图片做标注。自动标注的时候需要有一些筛选策略,模型输出的结果可能还是会出现一些错误。主要有以下三种方式做筛选

  • SAM模型有一个IOU prediction的模块能输出mask的confidence,如下图所示
  • stable mask的判断,具体的方法是在得到分割结果前对logit加正向和负向的扰动,如果两次扰动生成的分割结果IOU大于0.95,则认为生成的mask是可靠的

  • NMS过滤掉重复的mask

数据质量

图像: 包含11M高分辨率(33004950)的图像,其他的一些开源数据集,例如COCO分辨率较低(480640)
Mask: 包含1.1B的mask,99.1%都是模型生成的。作者实验了下,只使用模型生成的mask和即使用模型生成也使用人工标注的mask,模型的效果是相当的。因此发布的数据集里只包含模型生成的mask
Mask 质量: 抽取了一部分mask数据做人工的精标,精标前后有94%的mask具有90%以上的IOU。而其他的一些开源数据集只有85-91%的IOU

下面也从mask的数量,每种mask尺寸的占比和mask占外接矩形比例等多方面和其他数据集做了对比

数据来源分布

不同性别,肤色,年龄人群分割效果的差异对比

zero-short Transfer实验

评估的数据集都是SAM模型训练时的不同,并且包含水下,第一视角等没有在SAM中出现过场景的图片

point mask

这里对比的是用point作为prompt对比分割的结果, 在绝大部分数据集中都优于RITM(当前的SOTA)

边缘检测

SAM在训练的时候就是采用的包含point prompt的方式,作者这里还对比了一些在训练时没有包含的方式,边界检测就是其中一种。SAM在使用边界检测时,使用方式是在图片上铺上16*16均匀的point prompt,每个prompt产生3个mask,再经过NMS后。通过Sobel filtering得到边缘检测的结果。SAM的结果倾向于提取更丰富的边缘,因此在指标上recall和专门做边缘检测的模型相当,precision会低些。

目标检测

分割的结果取bbox,就能做目标检测了.整体指标低于ViTDet,但是在中等常见和不太常见的目标上效果优于ViTDet

实例分割

先用一个目标检测算法,用目标检测得到的box作为prompt输入到SAM,就可以做实例分割了。实验的结果分为了定量(用测试集的GT)和定性(人来评判好坏)两种。定量的指标不如BiTDet—H,定性的指标SAM优于ViTDet。作者给出的解释是COCO数据集标注效果一般(在人看来甚至不如SAM和ViTDet模型输出的结果),因此ViTDet在COCO上做训练时拟合到了一些错误的偏差,但错误的偏差和标注相似,因此定量的指标不如ViTDet

Text to Mask

这里指的是用文本作为prompt,然后分割出文本提到的目标。作者在训练的时候取的是图片中目标尺寸大于100*100的目标,用CLIP提取image embedding(text embedding也行,因为CLIP的image embedding和text embedding是对齐的),作为prompt encoder模块的输出,用于训练SAM模型。这一部分没有和其他方法对比,也由于效果不太稳定,在官方的demo中没有展示

消融实验

有以下的结论:
左边的图,数据来源的影响:

  • 加入半自动标注的数据和自动标注的数据性能都有很大的提升
  • 只用模型生成的数据与额外加上人工标注的数据差异不大

中间的图, 数据量的影响:

  • 数据量从0.1M到1M,模型性能提升很大
  • 数据量从1M到11M,模型性能变化不明显,实际使用中1M差不多足够

右边的图, image encoder的影响:

  • ViT-B到ViT-L提升很大
  • ViT-L到ViT-H提升一般,实际使用ViT-L足够

总结

SAM的热度也非常高,同样作为FB的工作,SAM仅仅放出来两个月,github上star的数量已经超过了detectron2三年的总和。SAM的期望是能将该模型作为图像领域的基础模型(foundation model),像CLIP那样能在各个领域大放异常,或者像GPT一样能统一NLP领域。SAM也确实在很多场景得到了应用,例如开源的SD中也融合了SAM,可以做很多有趣的应用,例如从假人模特身上用SAM得到衣服的mask,再结合ControlNet,就可以生成不同的人穿着同样的衣服。

最开始自媒体宣传的文章也是《CV领域不存在了》,《CV界的GPT3》类似的标题,SAM确实是在统一上迈出了很大的一步,但实际上CV领域的统一还有很多挑战。NLP领域中的Bert用完型填空和GPT预测下一个token的预训练在非常多的任务上表现了很好的泛化性,甚至在一些没有训练过的任务上能取得比一些专家模型更好的效果。

  • 任务和数据上的不统一,CV领域的分类是输出类别,检测输出bbox,分割输出mask。虽然个别任务可以复用,但是整体缺乏一个通用的任务。任务上的不统一,数据上也很难做到统一,分类的任务有很多数据,但是检测和分割的数据就要少非常多,并且标注成本巨大。单纯训练分类作为backbone也很难解决其他任务,检测和分割的算法依然需要做大量的优化

  • CV领域的任务缺乏孕育大模型的土壤,CV任务一直在考虑模型的计算量,显存占用。如果将每个像素看作一个token,一张512*512的图片就有26万个token。如果transformer最开始出现在CV领域,面临的问题是显存和计算量都比resnet差,并且效果也远不如resnet。如果没有transformer极大的促进了NLP领域的发展,CV领域可能也不会重新思考transfomer能增大感受野,能有更好的泛化能力。

  • 还没有找到CV领域【高维】的任务。NLP领域的完形填空和对话确实是一种很高维的任务。模型能完成这些任务,一些NER或者RE之类的底层任务也能很好的被解决。目前CV领域有一些尝试做foundation model,例如对比学习或者像SAM,在一些任务上表现了不错的泛化性,可能是这些方法能统一其他任务,但目前的发展还不太够,也可能是其他一些还没出现/发展起来的任务。但这种【高维】的任务一定能通过一些方式降维解决目前几乎所有的CV基础任务。

-------------本文结束感谢您的阅读-------------