PyTorch 实现 Focal Loss:多分类和二分类示例
废话不多说,直接上代码!
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


# Supports multi-class and binary classification
class FocalLoss(nn.Module):
    """Multi-class focal loss with optional label smoothing.

    Implements the loss proposed in
    'Focal Loss for Dense Object Detection' (https://arxiv.org/abs/1708.02002):

        Focal_Loss = -1 * alpha * (1 - pt) ^ gamma * log(pt)

    :param num_class: number of classes; must match ``input.size(1)`` in
        ``forward``.
    :param alpha: per-class weighting factor. ``None`` weights every class
        by 1. A ``list`` / ``np.ndarray`` of length ``num_class`` is
        normalised so that it sums to 1. A ``float`` is assigned to class
        ``balance_index`` while every other class gets ``1 - alpha``.
    :param gamma: (float, double) focusing parameter; ``gamma > 0`` reduces
        the relative loss for well-classified examples (p > 0.5), putting
        more focus on hard, misclassified examples.
    :param balance_index: (int) index of the class that receives ``alpha``
        when ``alpha`` is a float.
    :param smooth: (float, double) label-smoothing value; the one-hot target
        is clamped to ``[smooth, 1 - smooth]``. ``None`` or ``0`` disables it.
    :param size_average: (bool, optional) if ``True`` (default) the
        per-element losses are averaged, otherwise they are summed.
    :raises TypeError: if ``alpha`` is not ``None``, a list, an ndarray or a
        float.
    :raises ValueError: if ``smooth`` is outside ``[0, 1]``.
    """

    def __init__(self, num_class, alpha=None, gamma=2, balance_index=-1,
                 smooth=None, size_average=True):
        super(FocalLoss, self).__init__()
        self.num_class = num_class
        self.alpha = alpha
        self.gamma = gamma
        self.smooth = smooth
        self.size_average = size_average

        # Normalise every accepted form of ``alpha`` to a (num_class, 1) tensor.
        if self.alpha is None:
            self.alpha = torch.ones(self.num_class, 1)
        elif isinstance(self.alpha, (list, np.ndarray)):
            assert len(self.alpha) == self.num_class
            weights = torch.tensor(alpha, dtype=torch.float32)
            weights = weights.view(self.num_class, 1)
            self.alpha = weights / weights.sum()
        elif isinstance(self.alpha, float):
            weights = torch.ones(self.num_class, 1) * (1 - self.alpha)
            weights[balance_index] = self.alpha
            self.alpha = weights
        else:
            raise TypeError('Not support alpha type')

        if self.smooth is not None:
            if self.smooth < 0 or self.smooth > 1.0:
                raise ValueError('smooth value should be in [0,1]')

    def forward(self, input, target):
        """Compute the focal loss.

        :param input: raw scores with the class dimension at dim 1, i.e.
            ``(N, C)`` or ``(N, C, d1, d2, ...)``; softmax is applied here.
        :param target: integer class indices with one entry per input
            position (it is flattened to ``(-1, 1)``).
        :return: scalar tensor, the mean or the sum of the per-element loss
            depending on ``size_average``.
        """
        prob = F.softmax(input, dim=1)

        if prob.dim() > 2:
            # N,C,d1,d2 -> N,C,m (m=d1*d2*...) -> N,m,C -> N*m,C
            prob = prob.view(prob.size(0), prob.size(1), -1)
            prob = prob.permute(0, 2, 1).contiguous()
            prob = prob.view(-1, prob.size(-1))

        epsilon = 1e-10
        idx = target.view(-1, 1).long().to(prob.device)

        # Build the one-hot target directly on the device of the scores.
        one_hot_key = torch.zeros(idx.size(0), self.num_class,
                                  device=prob.device)
        one_hot_key.scatter_(1, idx, 1.0)
        if self.smooth:
            one_hot_key = torch.clamp(
                one_hot_key, self.smooth, 1.0 - self.smooth)

        # Probability assigned to the (optionally smoothed) target.
        pt = (one_hot_key * prob).sum(1) + epsilon
        logpt = pt.log()

        # Gather one weight per element and flatten it to shape (N,). Leaving
        # it as (N, 1, 1) would broadcast against the (N,) focal term into an
        # (N, 1, N) tensor and mix the class weights of different samples.
        alpha = self.alpha.to(prob.device)
        alpha = alpha[idx.view(-1)].view(-1)

        loss = -1 * alpha * torch.pow((1 - pt), self.gamma) * logpt

        if self.size_average:
            loss = loss.mean()
        else:
            loss = loss.sum()
        return loss


class BCEFocalLoss(torch.nn.Module):
    """Binary focal loss with a fixed alpha.

    :param gamma: focusing parameter applied to the modulating factor.
    :param alpha: weight of the positive term; the negative term is weighted
        by ``1 - alpha``.
    :param reduction: ``'elementwise_mean'`` (or its current PyTorch name
        ``'mean'``) averages the loss, ``'sum'`` sums it, any other value
        returns the unreduced element-wise loss.
    """

    def __init__(self, gamma=2, alpha=0.25, reduction='elementwise_mean'):
        super().__init__()
        self.gamma = gamma
        self.alpha = alpha
        self.reduction = reduction

    def forward(self, _input, target):
        """Compute the binary focal loss.

        :param _input: raw scores; sigmoid is applied here.
        :param target: targets broadcastable against ``_input``, with 1 for
            the positive class and 0 for the negative class.
        :return: the reduced (or unreduced) loss tensor.
        """
        pt = torch.sigmoid(_input)
        # log(sigmoid(x)) and log(1 - sigmoid(x)) = log(sigmoid(-x)), taken
        # through logsigmoid so a saturated sigmoid cannot yield log(0) = -inf
        # (which turned into NaN once multiplied by a zero factor).
        log_pt = F.logsigmoid(_input)
        log_one_minus_pt = F.logsigmoid(-_input)
        alpha = self.alpha

        loss = -alpha * (1 - pt) ** self.gamma * target * log_pt - \
            (1 - alpha) * pt ** self.gamma * (1 - target) * log_one_minus_pt

        if self.reduction in ('elementwise_mean', 'mean'):
            loss = torch.mean(loss)
        elif self.reduction == 'sum':
            loss = torch.sum(loss)
        return loss
以上这篇《PyTorch 实现 Focal Loss:多分类和二分类示例》就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持毛票票。
声明:本文内容来源于网络,版权归原作者所有,内容由互联网用户自发贡献自行上传,本网站不拥有所有权,未作人工编辑处理,也不承担相关法律责任。如果您发现有涉嫌版权的内容,欢迎发送邮件至:czq8825#qq.com(发邮件时,请将#更换为@)进行举报,并提供相关证据,一经查实,本站将立刻删除涉嫌侵权内容。