PyTorch 实现用 ResNet 提取特征并保存为 txt 文件的方法
接触 PyTorch 一天，发现 PyTorch 上手的确比 TensorFlow 更快，可以更方便地用预训练的网络提取特征。
以下是提取单张 jpg 图像特征的程序：
# -*- coding: utf-8 -*-
"""Extract a ResNet-50 feature vector from one JPEG image and save it as text.

The 2048-d vector is written as a single comma-separated row to
``<features_dir>/<image file name>.txt`` and then read back as a (1, 2048) array.
"""
import os.path

import numpy as np
import torch
import torch.nn as nn
from PIL import Image
from torchvision import models, transforms

features_dir = './features'
img_path = "hymenoptera_data/train/ants/0013035.jpg"
file_name = os.path.basename(img_path)
feature_path = os.path.join(features_dir, file_name + '.txt')
# np.savetxt does not create missing directories.
os.makedirs(features_dir, exist_ok=True)

transform1 = transforms.Compose([
    transforms.Resize(256),  # `Scale` is the removed, older name of `Resize`
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    # Pretrained torchvision models expect ImageNet-normalized input.
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

# convert('RGB') so grayscale / CMYK JPEGs still give a 3-channel tensor;
# the context manager closes the underlying file.
with Image.open(img_path) as img:
    img = img.convert('RGB')
img1 = transform1(img)

# Other torchvision ResNets (e.g. resnet18, resnet152) can be swapped in here;
# the saved row length then follows that model's feature width.
resnet50_feature_extractor = models.resnet50(pretrained=True)
# Replace the classifier with a pass-through so forward() returns the pooled
# features. An identity-initialised nn.Linear would still add its randomly
# initialised bias to every feature.
resnet50_feature_extractor.fc = nn.Identity()
# eval(): BatchNorm must use its stored running statistics, not the
# statistics of this single-image batch.
resnet50_feature_extractor.eval()
for param in resnet50_feature_extractor.parameters():
    param.requires_grad = False

x = torch.unsqueeze(img1, dim=0).float()  # add the batch axis -> (1, 3, 224, 224)
with torch.no_grad():
    y = resnet50_feature_extractor(x)
y = y.numpy()
np.savetxt(feature_path, y, delimiter=',')

# loadtxt returns a 1-D row for a single-line file; restore the batch axis.
y_ = np.loadtxt(feature_path, delimiter=',').reshape(1, 2048)
以下是提取一个文件夹下所有 jpg、jpeg 图像特征的程序：
# -*- coding: utf-8 -*-
"""Extract ResNet-50 features for every .jpg/.jpeg image under a folder.

``<data_dir>/<rel>/<name>.jpg`` is saved as the comma-separated text file
``<features_dir>/<basename of data_dir>/<rel>/<name>.jpg.txt``.
"""
import os

import numpy as np
import torch
import torch.nn as nn
from PIL import Image
from torchvision import models, transforms

data_dir = './hymenoptera_data'
features_dir = './features'
# Matched case-insensitively, so ".JPG", ".Jpeg", ... are included and a file
# is never listed twice on case-insensitive filesystems.
IMAGE_EXTENSIONS = ('.jpg', '.jpeg')

# Preprocessing is built once instead of on every extractor() call.
_transform = transforms.Compose([
    transforms.Resize(256),  # `Scale` is the removed, older name of `Resize`
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    # Pretrained torchvision models expect ImageNet-normalized input.
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])


def extractor(img_path, saved_path, net, use_gpu):
    """Run *net* on one image and save its output as comma-separated text.

    Args:
        img_path: image file to read; it is converted to RGB first.
        saved_path: destination text file; missing parent directories are
            created.
        net: feature extractor. Callers should put it in eval mode.
        use_gpu: if true, move the input and *net* to CUDA for the forward pass.
    """
    # The context manager closes the file; convert('RGB') handles grayscale
    # and CMYK JPEGs.
    with Image.open(img_path) as img:
        img = img.convert('RGB')
    x = torch.unsqueeze(_transform(img), dim=0).float()  # add batch axis
    if use_gpu:
        x = x.cuda()
        net = net.cuda()
    with torch.no_grad():
        y = net(x).cpu().numpy()
    parent = os.path.dirname(saved_path)
    if parent:
        os.makedirs(parent, exist_ok=True)
    np.savetxt(saved_path, y, delimiter=',')


def _list_images(root):
    """Return the sorted paths of all image files anywhere under *root*."""
    return sorted(
        os.path.join(dirpath, name)
        for dirpath, _, names in os.walk(root)
        for name in names
        # glob's '*.jpg' never matched hidden files (e.g. macOS '._x.jpg').
        if not name.startswith('.') and name.lower().endswith(IMAGE_EXTENSIONS)
    )


def _feature_path(img_path):
    """Map an image under data_dir to its .txt path under features_dir."""
    rel = os.path.relpath(img_path, data_dir)
    root_name = os.path.basename(os.path.normpath(data_dir))
    return os.path.join(features_dir, root_name, rel + '.txt')


if __name__ == '__main__':
    if not os.path.isdir(data_dir):
        raise SystemExit('data_dir not found: ' + data_dir)

    resnet50_feature_extractor = models.resnet50(pretrained=True)
    # Pass-through head: an identity-initialised nn.Linear would still add
    # its randomly initialised bias to every feature.
    resnet50_feature_extractor.fc = nn.Identity()
    # eval(): BatchNorm must use running statistics, not per-image ones.
    resnet50_feature_extractor.eval()
    for param in resnet50_feature_extractor.parameters():
        param.requires_grad = False
    use_gpu = torch.cuda.is_available()
    if use_gpu:
        # Move the model once up front rather than inside the loop.
        resnet50_feature_extractor = resnet50_feature_extractor.cuda()

    for x_path in _list_images(data_dir):
        print(x_path)
        fx_path = _feature_path(x_path)
        extractor(x_path, fx_path, resnet50_feature_extractor, use_gpu)
另外，最近发现一种很简单的方法，可以得到去掉最后 FC 层的网络：
resnet = models.resnet152(pretrained=True)
# Keep every child module except the last one (the final fc layer).
modules = list(resnet.children())[:-1]
# Flatten everything after the batch axis so each image yields one row,
# the 2-D layout np.savetxt can write directly.
convnet = nn.Sequential(*modules, nn.Flatten())
# eval(): BatchNorm uses its running statistics during feature extraction.
convnet.eval()
另一种更简单的方法:
resnet = models.resnet152(pretrained=True)
# `del resnet.fc` would break resnet(x): ResNet.forward still calls self.fc.
# Swapping in a pass-through module keeps forward() working and makes it
# return the pooled features.
resnet.fc = nn.Identity()
# eval(): BatchNorm uses its running statistics during feature extraction.
resnet.eval()
以上这篇 PyTorch 实现用 ResNet 提取特征并保存为 txt 文件的方法就是小编分享给大家的全部内容了，希望能给大家一个参考，也希望大家多多支持毛票票。