使用Beit提取图片特征
当今,计算机视觉技术在各个领域中得到了广泛应用,其中图片特征提取是计算机视觉任务中的重要一环。而在2021年的研究中,微软推出了一种预训练模型——Beit,它在提取图片特征方面表现出了卓越的性能。本文将介绍Beit模型并探讨如何使用它来提取图片特征。
一、图片Resize
在计算机视觉任务中,不同大小的图片通常需要被处理成相同的大小,以便进行后续的特征提取或者模型训练。此外,在某些情况下,如移动端或者网络传输时,更小的图片也可以提高性能和效率。因此,图片Resize是计算机视觉任务中不可或缺的一个步骤。
在Python中,我们可以使用Pillow(Python Imaging Library,也叫做PIL)库来进行图片的Resize操作。在使用Beit提取特征时,我们先将图片Resize未[3, 224, 224]的大小。代码如下:
import torch
import torchvision
from torch import nn
from torchvision import transforms
# Define image transform
transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize(
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]
)
])
# Load image
image = Image.open("image.jpg")
# Apply transform to image
image = transform(image)
二、初始化模型并加载参数
在Python中,我们可以使用PyTorch库来初始化Beit模型,并加载预训练参数。具体操作如下:
from transformers import BeitForImageClassification
from PIL import Image
import requests
url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
image = Image.open(requests.get(url, stream=True).raw)
model = BeitForImageClassification.from_pretrained('microsoft/beit-large-patch16-224-pt22k-ft22k')
三、去除原模型分类器
BeitForImageClassification原模型最后一层是一个将1024维映射到21841维的MLP,以用于预测图片和词典中的什么token最相似,并给出结果。那么我们如果要提取图片特征,也就是倒数第二层中输出的1024维的特征,我们则需要将最后一层分类器去除。步骤如下:
观察模型:
print(model)
去除最后一层:
import torch
new_model = torch.nn.Sequential( *( list(model.children())[:-1] ) )
再次观察:
print(new_model)
那么,我们便得到了一个Beit提取特征模型。使用如下:
feature = new_model(image)