Tensorflow Lite预训练模型在Android演示中不起作用

Tensorflow Lite预训练模型在Android演示中不起作用,android,tensorflow,tensorflow-lite,Android,Tensorflow,Tensorflow Lite,Tensorflow Lite Android演示程序使用它提供的原始模型:mobilenet_quant_v1_224.tflite。见: 他们还提供了其他预训练的lite模型: 但是,我从上面的链接下载了一些较小的模型,例如mobilenet_v1_0.25_224.tflite,并在演示应用程序中用此模型替换了原始模型,只需更改model_PATH=“mobilenet_v1_0.25_224.tflite”在ImageClassifier.java中。应用程序因以下原因崩溃: 12-11

Tensorflow Lite Android演示程序使用它提供的原始模型:mobilenet_quant_v1_224.tflite。见:

他们还提供了其他预训练的lite模型:

但是,我从上面的链接下载了一些较小的模型,例如mobilenet_v1_0.25_224.tflite,并在演示应用程序中用此模型替换了原始模型,只需更改
model_PATH=“mobilenet_v1_0.25_224.tflite”
ImageClassifier.java
中。应用程序因以下原因崩溃:

12-11 12:52:34.222 17713-17729/?E/AndroidRuntime:致命异常: 摄影背景 进程:android.example.com.tflitecamerademo,PID:17713 java.lang.IllegalArgumentException:无法获取输入维度。 第0个输入应该有602112个字节,但找到了150528个字节。 位于org.tensorflow.lite.NativeInterpreterWrapper.getInputDims(Native (方法) 位于org.tensorflow.lite.NativeInterpreterWrapper.run(NativeInterpreterWrapper.java:82) 位于org.tensorflow.lite.explorer.runForMultipleInputsOutputs(explorer.java:112) 位于org.tensorflow.lite.explorer.run(explorer.java:93) 位于com.example.android.tflitecamerademo.ImageClassifier.classifyFrame(ImageClassifier.java:108) 位于com.example.android.tflitecamerademo.Camera2BasicFragment.classifyFrame(Camera2BasicFragment.java:663) 在com.example.android.tflitecamerademo.Camera2BasicFragment.access$900(Camera2BasicFragment.java:69) 位于com.example.android.tflitecamerademo.Camera2BasicFragment$5.run(Camera2BasicFragment.java:558) 位于android.os.Handler.handleCallback(Handler.java:751) 位于android.os.Handler.dispatchMessage(Handler.java:95) 位于android.os.Looper.loop(Looper.java:154) 运行(HandlerThread.java:61)

原因似乎是模型所需的输入尺寸比图像尺寸大四倍。因此,我将
DIM\u BATCH\u SIZE=1
修改为
DIM\u BATCH\u SIZE=4
。现在的错误是:

致命异常:CameraBackground 进程:android.example.com.tflitecamerademo,PID:18241 java.lang.IllegalArgumentException:无法转换TensorFlowLite 将FLOAT32类型的张量转换为[[B]类型的Java对象(即 与TensorFlowLite类型UINT8兼容 位于org.tensorflow.lite.Tensor.copyTo(Tensor.java:36) 位于org.tensorflow.lite.explorer.runForMultipleInputsOutputs(explorer.java:122) 位于org.tensorflow.lite.explorer.run(explorer.java:93) 位于com.example.android.tflitecamerademo.ImageClassifier.classifyFrame(ImageClassifier.java:108) 位于com.example.android.tflitecamerademo.Camera2BasicFragment.classifyFrame(Camera2BasicFragment.java:663) 在com.example.android.tflitecamerademo.Camera2BasicFragment.access$900(Camera2BasicFragment.java:69) 位于com.example.android.tflitecamerademo.Camera2BasicFragment$5.run(Camera2BasicFragment.java:558) 位于android.os.Handler.handleCallback(Handler.java:751) 位于android.os.Handler.dispatchMessage(Handler.java:95) 位于android.os.Looper.loop(Looper.java:154) 运行(HandlerThread.java:61)

我的问题是如何让简化的MobileNet tflite模型与TF lite Android演示版配合使用。


(实际上,我还尝试了其他方法,比如使用提供的工具将TF冻结图转换为TF lite模型,即使使用与中完全相同的示例代码,但转换后的tflite模型仍然无法在Android演示中工作。)

Tensorflow Lite Android演示中包含的ImageClassifier.java需要一个量化的模型。截至目前,只有一个Mobilenet模型以量化形式提供:Mobilenet 1.0 224 Quant

要使用其他浮点模型,请从Tensorflow for Lite演示源代码中交换ImageClassifier.java。这是为浮点模型编写的。

做一个比较,你会发现在实现上有几个重要的区别

考虑的另一个选项是使用浮点转换将浮点模型转换为量化:


我也遇到了和苗木一样的错误。 我已经为Mobilenet Float模型创建了一个新的图像分类器包装器。 现在工作正常。您可以直接在图像分类器演示中添加这个类,并使用它在Camera2BasicFragment中创建分类器

classifier = new ImageClassifierFloatMobileNet(getActivity());
下面是Mobilenet浮点模型的图像分类器类包装

    /**
 * This classifier works with the Float MobileNet model.
 */
public class ImageClassifierFloatMobileNet extends ImageClassifier {

  /**
   * An array to hold inference results, to be feed into Tensorflow Lite as outputs.
   * This isn't part of the super class, because we need a primitive array here.
   */
  private float[][] labelProbArray = null;

  private static final int IMAGE_MEAN = 128;
  private static final float IMAGE_STD = 128.0f;

  /**
   * Initializes an {@code ImageClassifier}.
   *
   * @param activity
   */
  public ImageClassifierFloatMobileNet(Activity activity) throws IOException {
    super(activity);
    labelProbArray = new float[1][getNumLabels()];
  }

  @Override
  protected String getModelPath() {
    // you can download this file from
    // https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_224_android_quant_2017_11_08.zip
//    return "mobilenet_quant_v1_224.tflite";
    return "retrained.tflite";
  }

  @Override
  protected String getLabelPath() {
//    return "labels_mobilenet_quant_v1_224.txt";
    return "retrained_labels.txt";
  }

  @Override
  public int getImageSizeX() {
    return 224;
  }

  @Override
  public int getImageSizeY() {
    return 224;
  }

  @Override
  protected int getNumBytesPerChannel() {
    // the Float model uses a 4 bytes
    return 4;
  }

  @Override
  protected void addPixelValue(int val) {
    imgData.putFloat((((val >> 16) & 0xFF)-IMAGE_MEAN)/IMAGE_STD);
    imgData.putFloat((((val >> 8) & 0xFF)-IMAGE_MEAN)/IMAGE_STD);
    imgData.putFloat((((val) & 0xFF)-IMAGE_MEAN)/IMAGE_STD);
  }

  @Override
  protected float getProbability(int labelIndex) {
    return labelProbArray[0][labelIndex];
  }

  @Override
  protected void setProbability(int labelIndex, Number value) {
    labelProbArray[0][labelIndex] = value.byteValue();
  }

  @Override
  protected float getNormalizedProbability(int labelIndex) {
    return labelProbArray[0][labelIndex];
  }

  @Override
  protected void runInference() {
    tflite.run(imgData, labelProbArray);
  }
}

你能不能在文章的正文中提出一个明确的问题(不仅仅是标题)?请看一看。