Pytorch实现的手写数字mnist识别功能完整示例
本文实例讲述了Pytorch实现的手写数字mnist识别功能。分享给大家供大家参考,具体如下:
importtorch importtorchvisionastv importtorchvision.transformsastransforms importtorch.nnasnn importtorch.optimasoptim importargparse #定义是否使用GPU device=torch.device("cuda"iftorch.cuda.is_available()else"cpu") #定义网络结构 classLeNet(nn.Module): def__init__(self): super(LeNet,self).__init__() self.conv1=nn.Sequential(#input_size=(1*28*28) nn.Conv2d(1,6,5,1,2),#padding=2保证输入输出尺寸相同 nn.ReLU(),#input_size=(6*28*28) nn.MaxPool2d(kernel_size=2,stride=2),#output_size=(6*14*14) ) self.conv2=nn.Sequential( nn.Conv2d(6,16,5), nn.ReLU(),#input_size=(16*10*10) nn.MaxPool2d(2,2)#output_size=(16*5*5) ) self.fc1=nn.Sequential( nn.Linear(16*5*5,120), nn.ReLU() ) self.fc2=nn.Sequential( nn.Linear(120,84), nn.ReLU() ) self.fc3=nn.Linear(84,10) #定义前向传播过程,输入为x defforward(self,x): x=self.conv1(x) x=self.conv2(x) #nn.Linear()的输入输出都是维度为一的值,所以要把多维度的tensor展平成一维 x=x.view(x.size()[0],-1) x=self.fc1(x) x=self.fc2(x) x=self.fc3(x) returnx #使得我们能够手动输入命令行参数,就是让风格变得和Linux命令行差不多 parser=argparse.ArgumentParser() parser.add_argument('--outf',default='./model/',help='foldertooutputimagesandmodelcheckpoints')#模型保存路径 parser.add_argument('--net',default='./model/net.pth',help="pathtonetG(tocontinuetraining)")#模型加载路径 opt=parser.parse_args() #超参数设置 EPOCH=8#遍历数据集次数 BATCH_SIZE=64#批处理尺寸(batch_size) LR=0.001#学习率 #定义数据预处理方式 transform=transforms.ToTensor() #定义训练数据集 trainset=tv.datasets.MNIST( root='./data/', train=True, download=True, transform=transform) #定义训练批处理数据 trainloader=torch.utils.data.DataLoader( trainset, batch_size=BATCH_SIZE, shuffle=True, ) #定义测试数据集 testset=tv.datasets.MNIST( root='./data/', train=False, download=True, transform=transform) #定义测试批处理数据 testloader=torch.utils.data.DataLoader( testset, batch_size=BATCH_SIZE, shuffle=False, ) #定义损失函数lossfunction和优化方式(采用SGD) net=LeNet().to(device) criterion=nn.CrossEntropyLoss()#交叉熵损失函数,通常用于多分类问题上 optimizer=optim.SGD(net.parameters(),lr=LR,momentum=0.9) #训练 if__name__=="__main__": forepochinrange(EPOCH): sum_loss=0.0 #数据读取 fori,datainenumerate(trainloader): inputs,labels=data inputs,labels=inputs.to(device),labels.to(device) #梯度清零 optimizer.zero_grad() #forward+backward outputs=net(inputs) loss=criterion(outputs,labels) loss.backward() optimizer.step() #每训练100个batch打印一次平均loss sum_loss+=loss.item() ifi%100==99: print('[%d,%d]loss:%.03f' %(epoch+1,i+1,sum_loss/100)) sum_loss=0.0 #每跑完一次epoch测试一下准确率 withtorch.no_grad(): correct=0 total=0 fordataintestloader: images,labels=data images,labels=images.to(device),labels.to(device) outputs=net(images) #取得分最高的那个类 _,predicted=torch.max(outputs.data,1) total+=labels.size(0) correct+=(predicted==labels).sum() print('第%d个epoch的识别准确率为:%d%%'%(epoch+1,(100*correct/total))) #torch.save(net.state_dict(),'%s/net_%03d.pth'%(opt.outf,epoch+1))
更多关于Python相关内容可查看本站专题:《Python数学运算技巧总结》、《Python图片操作技巧总结》、《Python数据结构与算法教程》、《Python函数使用技巧总结》、《Python字符串操作技巧汇总》及《Python入门与进阶经典教程》
希望本文所述对大家Python程序设计有所帮助。
声明:本文内容来源于网络,版权归原作者所有,内容由互联网用户自发贡献自行上传,本网站不拥有所有权,未作人工编辑处理,也不承担相关法律责任。如果您发现有涉嫌版权的内容,欢迎发送邮件至:czq8825#qq.com(发邮件时,请将#更换为@)进行举报,并提供相关证据,一经查实,本站将立刻删除涉嫌侵权内容。