pytorch实现线性拟合方式
一维线性拟合
数据为y=4x+5加上噪音
结果:
importnumpyasnp frommpl_toolkits.mplot3dimportAxes3D frommatplotlibimportpyplotasplt fromtorch.autogradimportVariable importtorch fromtorchimportnn X=torch.unsqueeze(torch.linspace(-1,1,100),dim=1) Y=4*X+5+torch.rand(X.size()) classLinearRegression(nn.Module): def__init__(self): super(LinearRegression,self).__init__() self.linear=nn.Linear(1,1)#输入和输出的维度都是1 defforward(self,X): out=self.linear(X) returnout model=LinearRegression() criterion=nn.MSELoss() optimizer=torch.optim.SGD(model.parameters(),lr=1e-2) num_epochs=1000 forepochinrange(num_epochs): inputs=Variable(X) target=Variable(Y) #向前传播 out=model(inputs) loss=criterion(out,target) #向后传播 optimizer.zero_grad()#注意每次迭代都需要清零 loss.backward() optimizer.step() if(epoch+1)%20==0: print('Epoch[{}/{}],loss:{:.6f}'.format(epoch+1,num_epochs,loss.item())) model.eval() predict=model(Variable(X)) predict=predict.data.numpy() plt.plot(X.numpy(),Y.numpy(),'ro',label='OriginalData') plt.plot(X.numpy(),predict,label='FittingLine') plt.show()
多维:
fromitertoolsimportcount importtorch importtorch.autograd importtorch.nn.functionalasF POLY_DEGREE=3 defmake_features(x): """Buildsfeaturesi.e.amatrixwithcolumns[x,x^2,x^3].""" x=x.unsqueeze(1) returntorch.cat([x**iforiinrange(1,POLY_DEGREE+1)],1) W_target=torch.randn(POLY_DEGREE,1) b_target=torch.randn(1) deff(x): returnx.mm(W_target)+b_target.item() defget_batch(batch_size=32): random=torch.randn(batch_size) x=make_features(random) y=f(x) returnx,y #Definemodel fc=torch.nn.Linear(W_target.size(0),1) batch_x,batch_y=get_batch() print(batch_x,batch_y) forbatch_idxincount(1): #Getdata #Resetgradients fc.zero_grad() #Forwardpass output=F.smooth_l1_loss(fc(batch_x),batch_y) loss=output.item() #Backwardpass output.backward() #Applygradients forparaminfc.parameters(): param.data.add_(-0.1*param.grad.data) #Stopcriterion ifloss<1e-3: break defpoly_desc(W,b): """Createsastringdescriptionofapolynomial.""" result='y=' fori,winenumerate(W): result+='{:+.2f}x^{}'.format(w,len(W)-i) result+='{:+.2f}'.format(b[0]) returnresult print('Loss:{:.6f}after{}batches'.format(loss,batch_idx)) print('==>Learnedfunction:\t'+poly_desc(fc.weight.view(-1),fc.bias)) print('==>Actualfunction:\t'+poly_desc(W_target.view(-1),b_target))
以上这篇pytorch实现线性拟合方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持毛票票。
声明:本文内容来源于网络,版权归原作者所有,内容由互联网用户自发贡献自行上传,本网站不拥有所有权,未作人工编辑处理,也不承担相关法律责任。如果您发现有涉嫌版权的内容,欢迎发送邮件至:czq8825#qq.com(发邮件时,请将#更换为@)进行举报,并提供相关证据,一经查实,本站将立刻删除涉嫌侵权内容。