Keras实现支持masking的Flatten层代码
不知道为什么,我总是需要实现某种骚操作,而这种骚操作往往是Keras不支持的。例如,我有一个padding过的矩阵,那么它一定是带masking的,然后我想要把它Flatten,再输入到Dense层。然而Keras的Flatten层不支持masking。
Keras原本Flatten的实现
classFlatten(Layer): def__init__(self,**kwargs): super(Flatten,self).__init__(**kwargs) self.input_spec=InputSpec(min_ndim=3) defcompute_output_shape(self,input_shape): ifnotall(input_shape[1:]): raiseValueError('Theshapeoftheinputto"Flatten"' 'isnotfullydefined' '(got'+str(input_shape[1:])+'.' 'Makesuretopassacomplete"input_shape"' 'or"batch_input_shape"argumenttothefirst' 'layerinyourmodel.') return(input_shape[0],np.prod(input_shape[1:])) defcall(self,inputs): returnK.batch_flatten(inputs)
自定义支持masking的实现
事实上,Keras层的mask有时候是需要参与运算的,比如Dense之类的,有时候则只是做某种变换然后传递给后面的层。Flatten属于后者,因为mask总是与input有相同的shape,所以我们要做的就是在compute_mask函数里对mask也做flatten。
fromkerasimportbackendasK fromkeras.engine.topologyimportLayer importtensorflowastf importnumpyasnp classMyFlatten(Layer): def__init__(self,**kwargs): self.supports_masking=True super(MyFlatten,self).__init__(**kwargs) defcompute_mask(self,inputs,mask=None): ifmask==None: returnmask returnK.batch_flatten(mask) defcall(self,inputs,mask=None): returnK.batch_flatten(inputs) defcompute_output_shape(self,input_shape): return(input_shape[0],np.prod(input_shape[1:]))
正确性检验
fromkeras.layersimport* fromkeras.modelsimportModel fromMyFlattenimportMyFlatten fromMySumLayerimportMySumLayer fromkeras.initializersimportones data=[[1,0,0,0], [1,2,0,0], [1,2,3,0], [1,2,3,4]] A=Input(shape=[4])#None*4 emb=Embedding(5,3,mask_zero=True,embeddings_initializer=ones())(A)#None*4*3 fla=MyFlatten()(emb)#None*12 out=MySumLayer(axis=1)(fla)#None*1 model=Model(inputs=[A],outputs=[out]) printmodel.predict(data)
输出:
[3.6.9.12.]
补充知识:pytorch中的reshape()、view()、transpose()和flatten()
1、torch.reshape()
reshape()可以由torch.reshape(),也可由torch.Tensor.reshape()调用
其作用是在不改变tensor元素数目的情况下改变tensor的shape
importtorch importnumpyasnp a=np.arange(24) b=a.reshape(4,3,2) print(np.shape(a)) print(b,np.shape(b)) '''结果 (24,) [[[01] [23] [45]] [[67] [89] [1011]] [[1213] [1415] [1617]] [[1819] [2021] [2223]]](4,3,2) '''
2、view()
view()只可以由torch.Tensor.view()来调用
view()和reshape()在效果上是一样的,区别是view()只能操作contiguous的tensor,且view后的tensor和原tensor共享存储,reshape()对于是否contiuous的tensor都可以操作。
3、transpose()
torch.transpose(input,dim0,dim1)->Tensor
将输入数据input的第dim0维和dim1维进行交换
#官方例子 >>>x=torch.randn(2,3) >>>x tensor([[0.9068,1.8803,-0.5021], [-0.6576,0.6334,-0.8961]]) >>>torch.transpose(x,0,1) tensor([[0.9068,-0.6576], [1.8803,0.6334], [-0.5021,-0.8961]])
4、flatten()
torch.flatten()的输入是tensor
torch.flatten(input,start_dim=0,end_dim=-1)→Tensor
其作用是将输入tensor的第start_dim维到end_dim维之间的数据“拉平”成一维tensor,
#官方例子 >>>t=torch.tensor([[[1,2], [3,4]], [[5,6], [7,8]]]) >>>torch.flatten(t) tensor([1,2,3,4,5,6,7,8]) >>>torch.flatten(t,start_dim=1) tensor([[1,2,3,4], [5,6,7,8]])
torch.nn.Flatten()可以理解为一种网络结构,类似Conv2d、Linear。一般放在卷积层和全连接层之间,将卷积层输出“拉平”成一维,
>>>m=torch.nn.Sequential( torch.nn.Conv2d(1,32,5,1,1), torch.nn.Flatten(), torch.nn.Linear(160,10)) >>>m Sequential( (0):Conv2d(1,32,kernel_size=(5,5),stride=(1,1),padding=(1,1)) (1):Flatten() (2):Linear(in_features=160,out_features=10,bias=True) )
以上这篇Keras实现支持masking的Flatten层代码就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持毛票票。