Tensorflow 为什么使用Google AutoML的浏览器中的结果与我的android应用程序中导出的tflite文件如此不同?

Tensorflow 为什么使用Google AutoML的浏览器中的结果与我的android应用程序中导出的tflite文件如此不同?,tensorflow,deep-learning,google-cloud-automl,tensorflow-lite,Tensorflow,Deep Learning,Google Cloud Automl,Tensorflow Lite,我正在用Google AutoML训练一个模型。然后我导出了一个TFLite模型,在android应用程序上运行 在应用程序和浏览器中使用了相同的图像(Google AutoML让您有机会在浏览器中测试您的模型),我在浏览器中得到了非常准确的答案,但在我的应用程序中,我得到了不太准确的答案 有没有想过为什么会发生这种情况 相关代码: public void predict(String imagePath, final Promise promise) { imgData.order

我正在用Google AutoML训练一个模型。然后我导出了一个TFLite模型,在android应用程序上运行

在应用程序和浏览器中使用了相同的图像(Google AutoML让您有机会在浏览器中测试您的模型),我在浏览器中得到了非常准确的答案,但在我的应用程序中,我得到了不太准确的答案

有没有想过为什么会发生这种情况

相关代码:

  public void predict(String imagePath, final Promise promise) {
    imgData.order(ByteOrder.nativeOrder());
    labelProbArray = new byte[1][RESULTS_TO_SHOW];

    try {

      labelList = loadLabelList();
    } catch (Exception ex) {
      ex.printStackTrace();
    }
    Log.w("FIND_ ", imagePath);

    Bitmap bitmap = BitmapFactory.decodeFile(imagePath);
    convertBitmapToByteBuffer(bitmap);


    try {
      tflite = new Interpreter(loadModelFile());
    } catch (Exception ex) {
      Log.w("FIND_exception in loading tflite", "1");      
    }

    tflite.run(imgData, labelProbArray);
    promise.resolve(getResult());
  }


  private WritableNativeArray getResult() {
    WritableNativeArray result = new WritableNativeArray();

    //ArrayList<JSONObject> result = new ArrayList<JSONObject>();
    try {

      for (int i = 0; i < RESULTS_TO_SHOW; i++) {
          WritableNativeMap map = new WritableNativeMap();

          map.putString("label", Integer.toString(i));
          float output = (float)(labelProbArray[0][i] & 0xFF) / 255f;
          map.putString("prob", String.valueOf(output));
          result.pushMap(map);

          Log.w("FIND_label ", Integer.toString(i));
          Log.w("FIND_prob ", String.valueOf(labelProbArray[0][i]));
      }
    } catch (Exception ex) {
      ex.printStackTrace();
    }

    return result;
  }

  private List<String> loadLabelList() throws IOException {
    Activity activity = getCurrentActivity();
      List<String> labelList = new ArrayList<String>();
      BufferedReader reader =
              new BufferedReader(new InputStreamReader(activity.getAssets().open(LABEL_PATH)));
      String line;
      while ((line = reader.readLine()) != null) {
          labelList.add(line);
      }
      reader.close();
      return labelList;
  }

  private void convertBitmapToByteBuffer(Bitmap bitmap) {
      if (imgData == null) {
          return;
      }
      imgData.rewind();
      Log.w("FIND_bitmap width ", String.valueOf(bitmap.getWidth()));
      Log.w("FIND_bitmap height ", String.valueOf(bitmap.getHeight()));

      bitmap.getPixels(intValues, 0, bitmap.getWidth(), 0, 0, bitmap.getWidth(), bitmap.getHeight());
      // Convert the image to floating point.
      int pixel = 0;
      //long startTime = SystemClock.uptimeMillis();
      for (int i = 0; i < DIM_IMG_SIZE_X; ++i) {
          for (int j = 0; j < DIM_IMG_SIZE_Y; ++j) {
              final int val = intValues[pixel++];

              imgData.put((byte) ( ((val >> 16) & 0xFF)));
              imgData.put((byte) ( ((val >> 8) & 0xFF)));
              imgData.put((byte) ( (val & 0xFF)));
              Log.w("FIND_i", String.valueOf(i));
              Log.w("FIND_j", String.valueOf(j));
          }
      }
  }

  private MappedByteBuffer loadModelFile() throws IOException {
    Activity activity = getCurrentActivity();
    AssetFileDescriptor fileDescriptor = activity.getAssets().openFd(MODEL_PATH);
    FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor());
    FileChannel fileChannel = inputStream.getChannel();
    long startOffset = fileDescriptor.getStartOffset();
    long declaredLength = fileDescriptor.getDeclaredLength();
    return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength);
  }
public void predict(字符串imagePath,最终承诺){
imgData.order(ByteOrder.nativeOrder());
labelProbArray=新字节[1][RESULTS_TO_SHOW];
试一试{
labelList=loadLabelList();
}捕获(例外情况除外){
例如printStackTrace();
}
Log.w(“查找”,imagePath);
位图位图=BitmapFactory.decodeFile(imagePath);
convertBitmapToByteBuffer(位图);
试一试{
tflite=新的解释器(loadModelFile());
}捕获(例外情况除外){
w(“在加载tflite时发现异常”,“1”);
}
tflite.run(imgData,labelProbArray);
promise.resolve(getResult());
}
私有可写NativeArray getResult(){
WritableNativeArray结果=新的WritableNativeArray();
//ArrayList结果=新建ArrayList();
试一试{
for(int i=0;i>16)和0xFF));
imgData.put((字节)((val>>8)和0xFF));
imgData.put((字节)((val&0xFF));
Log.w(“FIND_i”,String.valueOf(i));
Log.w(“FIND_j”,String.valueOf(j));
}
}
}
私有MappedByteBuffer loadModelFile()引发IOException{
活动活动=getCurrentActivity();
AssetFileDescriptor fileDescriptor=activity.getAssets().openFd(模型路径);
FileInputStream inputStream=新的FileInputStream(fileDescriptor.getFileDescriptor());
FileChannel FileChannel=inputStream.getChannel();
long startOffset=fileDescriptor.getStartOffset();
long declaredLength=fileDescriptor.getDeclaredLength();
返回fileChannel.map(fileChannel.MapMode.READ_ONLY,startOffset,declaredLength);
}
我不知道是否正确的一件事是如何将输出数据调整为概率。在
getResult()
方法中,我编写了以下代码:

float输出=(float)(labelProbArray[0][i]&0xFF)/255f

为了得到一个介于0和1之间的数字,我除以255,但我不知道这是否正确


如果有帮助的话,.tflite模型是一个量化的uint8模型

你能分享一下你在Android上推理之前用来处理图像的代码吗?@ShubhamPanchal我刚刚发布了一些我用来做预测的相关方法。这些方法包括加载模型、处理图像和获得结果。让我知道我是否应该添加更多细节。您是否通过将像素值除以255来规范化像素值?您的意思是在convertBitmapToByteBuffer方法中?在这里我说“imgData.put((字节)((val>>16)和0xFF));”?我没有正常化。我只是在写无符号像素值。需要除以255吗?在大多数模型中,像素值是标准化的,这样梯度在训练时不会爆炸。您对哪种型号的AutoML进行过培训