Java Tflite模型在Android(ml vision)和Python中提供不同的输出

Java Tflite模型在Android(ml vision)和Python中提供不同的输出,java,python,android,tensorflow,kotlin,Java,Python,Android,Tensorflow,Kotlin,我使用ML Vision api从FaceNet模型创建嵌入,然后比较两个嵌入之间的余弦距离。Android版本和Python版本的输出差别很大。Python版本的性能比android版本要好得多。问题是什么?我在两者中都使用FaceNet模型 我正在使用ML套件进行推断 我认为这可能是因为java读取图像的方式不同,因为android中生成的图像数组与python中相同的图像数组不同。因此,我在跟踪google文档时遇到了这个问题 在将图像馈送至分类器之前,图像已转换为浮点数组,如下所示: v

我使用ML Vision api从FaceNet模型创建嵌入,然后比较两个嵌入之间的余弦距离。Android版本和Python版本的输出差别很大。Python版本的性能比android版本要好得多。问题是什么?我在两者中都使用FaceNet模型

我正在使用ML套件进行推断


我认为这可能是因为java读取图像的方式不同,因为android中生成的图像数组与python中相同的图像数组不同。

因此,我在跟踪google文档时遇到了这个问题 在将图像馈送至分类器之前,图像已转换为浮点数组,如下所示:

val bitmap = Bitmap.createScaledBitmap(yourInputImage, 224, 224, true)

val batchNum = 0
val input = Array(1) { Array(224) { Array(224) { FloatArray(3) } } }
for (x in 0..223) {
    for (y in 0..223) {
        val pixel = bitmap.getPixel(x, y)
        // Normalize channel values to [-1.0, 1.0]. This requirement varies by
        // model. For example, some models might require values to be normalized
        // to the range [0.0, 1.0] instead.
        input[batchNum][x][y][0] = (Color.red(pixel) - 127) / 255.0f
        input[batchNum][x][y][1] = (Color.green(pixel) - 127) / 255.0f
        input[batchNum][x][y][2] = (Color.blue(pixel) - 127) / 255.0f
    }
}
然后,我逐一分析了每一步,发现像素的获取方式是错误的,并且与python执行所有操作的方式完全不同

然后我发现了这种方法,我用我的函数改变了它:

    private fun convertBitmapToByteBuffer(bitmap: Bitmap): ByteBuffer {
        val imgData = ByteBuffer.allocateDirect(4 * INPUT_SIZE * INPUT_SIZE * PIXEL_SIZE)
        imgData.order(ByteOrder.nativeOrder())
        val intValues = IntArray(INPUT_SIZE * INPUT_SIZE)


        imgData.rewind()
        bitmap.getPixels(intValues, 0, bitmap.width, 0, 0, bitmap.width, bitmap.height)
        // Convert the image to floating point.
        var pixel = 0
        for (i in 0 until INPUT_SIZE) {
            for (j in 0 until INPUT_SIZE) {
                val `val` = intValues[pixel++]
                imgData.putFloat(((`val`.shr(16) and 0xFF) - IMAGE_MEAN)/IMAGE_STD)
                imgData.putFloat(((`val`.shr(8) and 0xFF)- IMAGE_MEAN)/ IMAGE_STD)
                imgData.putFloat(((`val` and 0xFF) - IMAGE_MEAN)/IMAGE_STD)
            }
        }
        return imgData;
   }
成功了