Keras带掩蔽的平均层

Keras带掩蔽的平均层,keras,merge,average,masking,Keras,Merge,Average,Masking,Keras附带的Average layer已经支持掩蔽,但是,查看Average layer的源代码,我不清楚如何以及是否应用了掩蔽 我有一个输入列表,每个都有自己的屏蔽(例如,来自嵌入层)。我想要的平均层应该取那些没有被屏蔽的输入的平均值。换句话说,如果一个输入被屏蔽了,它就不应该对计算出的平均值有任何发言权。如果所有输入都被屏蔽,则输出被屏蔽并传递到下一层 一个相关的问题是,库附带的平均层只支持输入列表的合并函数。是否有库支持沿特定维度合并张量?是否可以将张量切片到输入列表中,以输入到平均层

Keras附带的Average layer已经支持掩蔽,但是,查看Average layer的源代码,我不清楚如何以及是否应用了掩蔽

我有一个输入列表,每个都有自己的屏蔽(例如,来自嵌入层)。我想要的平均层应该取那些没有被屏蔽的输入的平均值。换句话说,如果一个输入被屏蔽了,它就不应该对计算出的平均值有任何发言权。如果所有输入都被屏蔽,则输出被屏蔽并传递到下一层

一个相关的问题是,库附带的平均层只支持输入列表的合并函数。是否有库支持沿特定维度合并张量?是否可以将张量切片到输入列表中,以输入到平均层?如果没有,如何在存在掩蔽的情况下沿某个维度取张量的平均值

我倾向于编写一个自定义的平均层来计算掩码,并在计算输出时使用掩码,但是从文档来看,不清楚如何做到这一点


非常感谢任何指针或代码示例。

如果您查看平均层的源代码,它实际上是“\u Merge”层的子类,因为平均层不会覆盖“compute\u mask”函数,因此它将继承“\u Merge”层的“compute\u mask”函数,如下所示:

  def compute_mask(self, inputs, mask=None):
    if mask is None:
      return None
    if not isinstance(mask, list):
      raise ValueError('`mask` should be a list.')
    if not isinstance(inputs, list):
      raise ValueError('`inputs` should be a list.')
    if len(mask) != len(inputs):
      raise ValueError('The lists `inputs` and `mask` '
                       'should have the same length.')
    if all(m is None for m in mask):
      return None
    masks = [array_ops.expand_dims(m, axis=0) for m in mask if m is not None]
    return K.all(K.concatenate(masks, axis=0), axis=0, keepdims=False)

最后4行表示:如果所有输入掩码均为None,则返回None。否则,输出掩码首先连接所有非无的掩码,然后执行“all”操作,这意味着如果其中一个输入掩码被屏蔽(False),则结果掩码被屏蔽(False),并且仅当所有输入掩码未被屏蔽(True)时,结果掩码才为True(非屏蔽).

如果您查看平均层的源代码,它实际上是从“\u Merge”层派生的子类,因为平均层不会覆盖“compute\u mask”函数,因此它将继承“\u Merge”层的“compute\u mask”函数,如下所示:

  def compute_mask(self, inputs, mask=None):
    if mask is None:
      return None
    if not isinstance(mask, list):
      raise ValueError('`mask` should be a list.')
    if not isinstance(inputs, list):
      raise ValueError('`inputs` should be a list.')
    if len(mask) != len(inputs):
      raise ValueError('The lists `inputs` and `mask` '
                       'should have the same length.')
    if all(m is None for m in mask):
      return None
    masks = [array_ops.expand_dims(m, axis=0) for m in mask if m is not None]
    return K.all(K.concatenate(masks, axis=0), axis=0, keepdims=False)
最后4行表示:如果所有输入掩码均为None,则返回None。否则,输出掩码首先连接所有非无的掩码,然后执行“all”操作,这意味着如果其中一个输入掩码被屏蔽(False),则结果掩码被屏蔽(False),并且仅当所有输入掩码未被屏蔽(True)时,结果掩码才为True(非屏蔽)