Tensorflow map_fn给出错误:ValueError:没有名为'的属性_XlaCompile';

Tensorflow map_fn给出错误:ValueError:没有名为'的属性_XlaCompile';,tensorflow,Tensorflow,我尝试实现中描述的“批量硬”批处理,以使用三重丢失。因此,输入是[batch_size,32]的形状,输出应该是一个表示三元组的列表,所以[[batch_size,32],[batch_size,32],[batch_size,32]],当每个示例的大小为(32,)时 我使用以下函数实现了这一点,因此基本上使用tf.map_fn: def batch_hard(inputs): """ Batch Hard triplets as described in https://ar

我尝试实现中描述的“批量硬”批处理,以使用三重丢失。因此,输入是[batch_size,32]的形状,输出应该是一个表示三元组的列表,所以[[batch_size,32],[batch_size,32],[batch_size,32]],当每个示例的大小为(32,)时

我使用以下函数实现了这一点,因此基本上使用tf.map_fn:

def batch_hard(inputs):
    """ 
    Batch Hard triplets as described in https://arxiv.org/pdf/1703.07737.pdf.
    For each sample in input the hardest positive and hardest negative
    in the given batch will be selected. A triplet is returned.
    """
    class_ids, f_anchor = inputs[0], inputs[1]

    def body(x):
        class_id, f = x[0], x[1]

        same_class = tf.equal(class_ids, class_id)

        positive = same_class
        negative = tf.logical_not(same_class)

        positive = tf.squeeze(positive)
        negative = tf.squeeze(negative)

        positive.set_shape([None])
        negative.set_shape([None])

        samples_pos = tf.boolean_mask(f_anchor, positive)
        samples_neg = tf.boolean_mask(f_anchor, negative)

        # Select hardest positive example
        distances = euclidean_distance(samples_pos, f)
        hardest_pos = samples_pos[tf.argmax(distances)]

        # Select hardest negative example
        distances = euclidean_distance(samples_neg, f)
        hardest_neg = samples_neg[tf.argmin(distances)]

        return [hardest_pos, hardest_neg]

    [f_pos, f_neg] = tf.map_fn(body, inputs, dtype=[tf.float32, tf.float32])
    return [f_anchor, f_pos, f_neg]
当我只进行向前传球,没有指定训练时,这种方法非常有效。但是,当我添加这一行时,会出现以下错误:

Traceback (most recent call last):
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/ops/gradients_impl.py", line 348, in _MaybeCompile
    xla_compile = op.get_attr("_XlaCompile")
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/ops.py", line 2003, in get_attr
    raise ValueError("No attr named '" + name + "' in " + str(self._node_def))
ValueError: No attr named '_XlaCompile' in name: "map/while/strided_slice"
op: "StridedSlice"
input: "map/while/boolean_mask/Gather"
input: "map/while/strided_slice/stack"
input: "map/while/strided_slice/stack_1"
input: "map/while/strided_slice/Cast"
attr {
  key: "Index"
  value {
    type: DT_INT64
  }
}
attr {
  key: "T"
  value {
    type: DT_FLOAT
  }
}
attr {
  key: "begin_mask"
  value {
    i: 0
  }
}
attr {
  key: "ellipsis_mask"
  value {
    i: 0
  }
}
attr {
  key: "end_mask"
  value {
    i: 0
  }
}
attr {
  key: "new_axis_mask"
  value {
    i: 0
  }
}
attr {
  key: "shrink_axis_mask"
  value {
    i: 1
  }
}
有人知道哪里出了问题吗

这里有一个完整的例子

编辑:问题似乎就在这几行

hardest_pos = samples_pos[tf.argmax(distances)]
用类似

hardest_pos = tf.zeros(32)

没有错误,但是如何解决这个问题?

我认为您的问题可能与argmax不可微这一事实有关。因此,您只有在尝试优化网络时才会观察到此错误。