PyTorch 实现用 ResNet 提取特征并保存为 txt 文件的方法
接触 PyTorch 一天,发现 PyTorch 上手的确比 TensorFlow 更快,可以更方便地用预训练的网络提取特征。
以下是提取一张jpg图像的特征的程序:
# -*- coding: utf-8 -*-
"""Extract the pooled ResNet-50 feature of one JPEG image and save it as text.

The feature is the 2048-dimensional input of the network's final ``fc``
layer, written as a single comma-separated row to
``<features_dir>/<image file name>.txt``.
"""
import os.path

import numpy as np
import torch
import torch.nn as nn
from PIL import Image
from torchvision import models, transforms

features_dir = './features'
img_path = "hymenoptera_data/train/ants/0013035.jpg"
file_name = img_path.split('/')[-1]
feature_path = os.path.join(features_dir, file_name + '.txt')

# Standard ImageNet preprocessing. The pretrained torchvision weights expect
# inputs normalised with these per-channel statistics.
transform1 = transforms.Compose([
    transforms.Resize(256),  # `transforms.Scale` is the deprecated old name
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

# Force three channels so grayscale / RGBA / palette files do not break the
# 3-channel first convolution, and close the file handle when done.
with Image.open(img_path) as img:
    img1 = transform1(img.convert('RGB'))

# Other torchvision ResNets (e.g. resnet18, resnet152) can be swapped in the
# same way; their pooled feature width may differ from 2048.
resnet50_feature_extractor = models.resnet50(pretrained=True)
# Replace the classifier with a true pass-through. An nn.Linear with an
# identity weight would still add its randomly initialised bias.
resnet50_feature_extractor.fc = nn.Identity()
# Inference mode: BatchNorm must use its stored running statistics rather
# than the statistics of this single-image batch.
resnet50_feature_extractor.eval()
for param in resnet50_feature_extractor.parameters():
    param.requires_grad = False

x = torch.unsqueeze(img1, dim=0).float()  # add the batch dimension
with torch.no_grad():
    y = resnet50_feature_extractor(x)
y = y.numpy()

# np.savetxt does not create missing directories.
os.makedirs(features_dir, exist_ok=True)
np.savetxt(feature_path, y, delimiter=',')

# Read the feature back to show the round trip; loadtxt returns a 1-D array
# for a single row, so restore the (1, 2048) shape.
y_ = np.loadtxt(feature_path, delimiter=',').reshape(1, 2048)
以下是提取一个文件夹下所有jpg、jpeg图像的程序:
# -*- coding: utf-8 -*-
"""Extract ResNet-50 features for every jpg/jpeg image under a data folder.

For each image found in a sub-directory of ``data_dir`` a text file named
``<image path>.txt`` is written below ``features_dir``, holding one
comma-separated row with the network output for that image.
"""
import glob
import os
import shutil

import numpy as np
import torch
import torch.nn as nn
from PIL import Image
from torchvision import models, transforms

data_dir = './hymenoptera_data'
features_dir = './features'

# Mirror the dataset tree under features_dir so every output folder exists.
# copytree raises if the destination is already there, so only copy on the
# first run. normpath drops a leading "./" without assuming it is present.
_mirror_dir = os.path.join(features_dir, os.path.normpath(data_dir))
if not os.path.exists(_mirror_dir):
    shutil.copytree(data_dir, _mirror_dir)

# Standard ImageNet preprocessing, built once instead of once per image.
# The pretrained torchvision weights expect this normalisation.
_TRANSFORM = transforms.Compose([
    transforms.Resize(256),  # `transforms.Scale` is the deprecated old name
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])


def extractor(img_path, saved_path, net, use_gpu):
    """Run ``net`` on one image and save the output as comma-separated text.

    Args:
        img_path: Path of the image file to read.
        saved_path: Path of the text file to write; its parent directory is
            created when missing.
        net: Network to evaluate. The caller is expected to have put it in
            ``eval()`` mode.
        use_gpu: When true, the input and ``net`` are moved to CUDA before
            the forward pass.
    """
    # Force three channels so grayscale / RGBA files do not break the
    # 3-channel first convolution; the context manager closes the file.
    with Image.open(img_path) as img:
        tensor = _TRANSFORM(img.convert('RGB'))

    x = torch.unsqueeze(tensor, dim=0).float()  # add the batch dimension
    if use_gpu:
        x = x.cuda()
        # Kept so callers may still pass a CPU-resident network.
        net = net.cuda()

    with torch.no_grad():
        y = net(x).cpu()

    out_dir = os.path.dirname(saved_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    np.savetxt(saved_path, y.numpy(), delimiter=',')


if __name__ == '__main__':
    extensions = ['jpg', 'jpeg', 'JPG', 'JPEG']

    # Every directory below data_dir; the first os.walk entry is data_dir
    # itself and is skipped, as images are looked up in sub-directories only.
    sub_dirs = [x[0] for x in os.walk(data_dir)]
    sub_dirs = sub_dirs[1:]

    files_list = []
    for sub_dir in sub_dirs:
        for extention in extensions:
            file_glob = os.path.join(sub_dir, '*.' + extention)
            files_list.extend(glob.glob(file_glob))
    # On case-insensitive filesystems '*.jpg' and '*.JPG' match the same
    # files; drop the duplicates while keeping the discovery order.
    files_list = list(dict.fromkeys(files_list))

    resnet50_feature_extractor = models.resnet50(pretrained=True)
    # A true pass-through instead of the classifier. An nn.Linear with an
    # identity weight would still add its randomly initialised bias.
    resnet50_feature_extractor.fc = nn.Identity()
    # Inference mode: BatchNorm must use its stored running statistics
    # rather than the statistics of each single-image batch.
    resnet50_feature_extractor.eval()
    for param in resnet50_feature_extractor.parameters():
        param.requires_grad = False

    use_gpu = torch.cuda.is_available()

    for x_path in files_list:
        print(x_path)
        fx_path = os.path.join(features_dir, os.path.normpath(x_path) + '.txt')
        extractor(x_path, fx_path, resnet50_feature_extractor, use_gpu)
另外最近发现一个很简单的提取不含FC层的网络的方法:
# Build a headless network from the pretrained ResNet-152: collect its
# top-level child modules, discard the last one (the fc layer) and chain the
# rest.
# NOTE(review): the result presumably keeps the pooled spatial dimensions
# (nothing flattens the output here) - confirm before saving features.
resnet = models.resnet152(pretrained=True)
modules = list(resnet.children())
modules.pop()  # delete the last fc layer
convnet = nn.Sequential(*modules)
另一种更简单的方法:
resnet = models.resnet152(pretrained=True)
# Do not `del resnet.fc`: the model's forward pass still references the `fc`
# attribute, so deleting it makes the first call fail with an AttributeError.
# Swapping in a pass-through module removes the classifier while keeping the
# model callable; it then returns the features that used to feed the fc layer.
resnet.fc = nn.Identity()
以上这篇pytorch实现用Resnet提取特征并保存为txt文件的方法就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持毛票票。