使用Beit提取图片特征

当今,计算机视觉技术在各个领域中得到了广泛应用,其中图片特征提取是计算机视觉任务中的重要一环。而在2021年的研究中,微软推出了一种预训练模型——Beit,它在提取图片特征方面表现出了卓越的性能。本文将介绍Beit模型并探讨如何使用它来提取图片特征。

一、图片Resize

在计算机视觉任务中,不同大小的图片通常需要被处理成相同的大小,以便进行后续的特征提取或者模型训练。此外,在某些情况下,如移动端或者网络传输时,更小的图片也可以提高性能和效率。因此,图片Resize是计算机视觉任务中不可或缺的一个步骤。

在Python中,我们可以使用Pillow(Python Imaging Library,也叫做PIL)库来进行图片的Resize操作。在使用Beit提取特征时,我们先将图片Resize未[3, 224, 224]的大小。代码如下:

import torch
import torchvision
from torch import nn
from torchvision import transforms

# Define image transform
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )
])

# Load image
image = Image.open("image.jpg")

# Apply transform to image
image = transform(image)

二、初始化模型并加载参数

在Python中,我们可以使用PyTorch库来初始化Beit模型,并加载预训练参数。具体操作如下:

from transformers import BeitForImageClassification
from PIL import Image
import requests

url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
image = Image.open(requests.get(url, stream=True).raw)

model = BeitForImageClassification.from_pretrained('microsoft/beit-large-patch16-224-pt22k-ft22k')

三、去除原模型分类器

BeitForImageClassification原模型最后一层是一个将1024维映射到21841维的MLP,以用于预测图片和词典中的什么token最相似,并给出结果。那么我们如果要提取图片特征,也就是倒数第二层中输出的1024维的特征,我们则需要将最后一层分类器去除。步骤如下:

观察模型:

print(model)

去除最后一层:

import torch
new_model = torch.nn.Sequential( *( list(model.children())[:-1] ) )

再次观察:

print(new_model)

那么,我们便得到了一个Beit提取特征模型。使用如下:

feature = new_model(image)

发表回复