Keras: How to Use Custom Loss Functions
Keras lets you define your own loss functions. The one thing to watch is the parameter list. Keras fixes it, so a custom loss must have the following form:
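```python
from keras import backend as K

def my_loss(y_true, y_pred):
    # y_true: True labels. TensorFlow/Theano tensor
    # y_pred: Predictions. TensorFlow/Theano tensor of the same shape as y_true
    # ... compute the loss here; mean squared error is used only as an example
    return K.mean(K.square(y_pred - y_true), axis=-1)  # a scalar for each data point
```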
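Because Keras always calls the loss with exactly these two arguments, a loss that needs an extra hyperparameter is commonly built with a closure. An outer function takes the parameter and returns an inner function with the required `(y_true, y_pred)` signature. The sketch below assumes TensorFlow 2.x and uses `tf` ops directly instead of the `K` backend. `make_weighted_mse` and `weight` are hypothetical names chosen for illustration.

```python
import tensorflow as tf


def make_weighted_mse(weight):
    """Return a loss with the fixed (y_true, y_pred) signature.

    `weight` is a hypothetical extra hyperparameter. Keras never passes it,
    so the inner function captures it from the enclosing scope instead.
    """
    def weighted_mse(y_true, y_pred):
        y_true = tf.cast(y_true, y_pred.dtype)
        # Mean over the last axis gives one loss value per sample
        return weight * tf.reduce_mean(tf.square(y_pred - y_true), axis=-1)
    return weighted_mse


loss_fn = make_weighted_mse(2.0)
y_true = tf.constant([[0.0, 1.0], [1.0, 1.0]])
y_pred = tf.constant([[0.0, 0.0], [1.0, 1.0]])
print(loss_fn(y_true, y_pred).numpy())  # per-sample losses
```

The returned `weighted_mse` is passed to `compile` just like `my_loss`.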
Then pass `my_loss` to `model.compile`, for example:
```python
model.compile(loss=my_loss, optimizer='sgd')
```
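Here is the whole flow end to end as a minimal sketch. It assumes TensorFlow 2.x (`tf.keras`). The toy data, the one-layer model, and the mean-absolute-error body of `my_loss` are illustrative choices and do not come from the original article.

```python
import numpy as np
import tensorflow as tf
from tensorflow import keras


def my_loss(y_true, y_pred):
    # Same contract as above: two tensors in, one loss value per sample out.
    # Mean absolute error is used here purely as an example.
    y_true = tf.cast(y_true, y_pred.dtype)
    return tf.reduce_mean(tf.abs(y_pred - y_true), axis=-1)


# Toy regression data (random, purely illustrative)
rng = np.random.default_rng(0)
x = rng.normal(size=(64, 3)).astype("float32")
y = x.sum(axis=1, keepdims=True)

model = keras.Sequential([keras.Input(shape=(3,)), keras.layers.Dense(1)])
model.compile(loss=my_loss, optimizer="sgd")  # pass the function object itself
history = model.fit(x, y, epochs=2, batch_size=16, verbose=0)
print(history.history["loss"])
```

Note that `loss=my_loss` passes the function object itself, without parentheses. Keras calls it on each batch's labels and predictions and averages the per-sample values into the reported loss.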
For reference, see how Keras defines its built-in metrics in keras/metrics.py. They follow the same `(y_true, y_pred)` signature:
"""Built-inmetrics. """ from__future__importabsolute_import from__future__importdivision from__future__importprint_function importsix from.importbackendasK from.lossesimportmean_squared_error from.lossesimportmean_absolute_error from.lossesimportmean_absolute_percentage_error from.lossesimportmean_squared_logarithmic_error from.lossesimporthinge from.lossesimportlogcosh from.lossesimportsquared_hinge from.lossesimportcategorical_crossentropy from.lossesimportsparse_categorical_crossentropy from.lossesimportbinary_crossentropy from.lossesimportkullback_leibler_divergence from.lossesimportpoisson from.lossesimportcosine_proximity from.utils.generic_utilsimportdeserialize_keras_object from.utils.generic_utilsimportserialize_keras_object defbinary_accuracy(y_true,y_pred): returnK.mean(K.equal(y_true,K.round(y_pred)),axis=-1) defcategorical_accuracy(y_true,y_pred): returnK.cast(K.equal(K.argmax(y_true,axis=-1), K.argmax(y_pred,axis=-1)), K.floatx()) defsparse_categorical_accuracy(y_true,y_pred): #reshapeincaseit'sinshape(num_samples,1)insteadof(num_samples,) ifK.ndim(y_true)==K.ndim(y_pred): y_true=K.squeeze(y_true,-1) #convertdensepredictionstolabels y_pred_labels=K.argmax(y_pred,axis=-1) y_pred_labels=K.cast(y_pred_labels,K.floatx()) returnK.cast(K.equal(y_true,y_pred_labels),K.floatx()) deftop_k_categorical_accuracy(y_true,y_pred,k=5): returnK.mean(K.in_top_k(y_pred,K.argmax(y_true,axis=-1),k),axis=-1) defsparse_top_k_categorical_accuracy(y_true,y_pred,k=5): #Iftheshapeofy_trueis(num_samples,1),flattento(num_samples,) returnK.mean(K.in_top_k(y_pred,K.cast(K.flatten(y_true),'int32'),k), axis=-1) #Aliases mse=MSE=mean_squared_error mae=MAE=mean_absolute_error mape=MAPE=mean_absolute_percentage_error msle=MSLE=mean_squared_logarithmic_error cosine=cosine_proximity defserialize(metric): returnserialize_keras_object(metric) defdeserialize(config,custom_objects=None): returndeserialize_keras_object(config, module_objects=globals(), custom_objects=custom_objects, 
printable_module_name='metricfunction') defget(identifier): ifisinstance(identifier,dict): config={'class_name':str(identifier),'config':{}} returndeserialize(config) elifisinstance(identifier,six.string_types): returndeserialize(str(identifier)) elifcallable(identifier): returnidentifier else: raiseValueError('Couldnotinterpret' 'metricfunctionidentifier:',identifier)
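In the source above, `get` returns any callable unchanged. This suggests that a custom metric with the same `(y_true, y_pred)` signature can be passed straight to `compile`. Here is a minimal sketch of such a metric, assuming TensorFlow 2.x. `within_tolerance` and its 0.5 threshold are hypothetical and serve only as an illustration.

```python
import tensorflow as tf


def within_tolerance(y_true, y_pred):
    """Hypothetical metric: fraction of outputs within 0.5 of the target.

    Uses the same (y_true, y_pred) signature as binary_accuracy above.
    """
    y_true = tf.cast(y_true, y_pred.dtype)
    # Comparison yields booleans; cast to float so they can be averaged
    hits = tf.cast(tf.abs(y_pred - y_true) <= 0.5, y_pred.dtype)
    return tf.reduce_mean(hits, axis=-1)  # one value per sample


y_true = tf.constant([[1.0, 0.0], [2.0, 2.0]])
y_pred = tf.constant([[1.2, 0.9], [2.0, 2.4]])
print(within_tolerance(y_true, y_pred).numpy())

# A callable is accepted as-is, so it can be listed directly, e.g.:
# model.compile(loss="mse", optimizer="sgd", metrics=[within_tolerance])
```

Unlike a loss, a metric is only reported and never differentiated. It can therefore use non-differentiable operations such as the comparison here.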