JavaScript：无法理解 TensorFlow.js 预测返回的张量

JavaScript：无法理解 TensorFlow.js 预测返回的张量,javascript,python,react-native,tensorflow,Javascript,Python,React Native,Tensorflow,最近，我使用 TensorFlow 在 Python 中训练了一个图像分类模型。我将其保存为 SavedModel，并转换为 TF.js 模型，得到一个权重文件和一个 JSON 模型文件。我的目标是在 React Native 上使用 TF.js，让用户能够上传自己的图片进行测试。然而，模型加载完成后，预测返回的张量让我看不懂。我的分类任务只有两个类别。这是因为我没有正确地把数据传给 predict 吗？以下是我的源代码，感谢大家的帮助：import React, { useState, useEffect } from "react"; …

最近，我使用 TensorFlow 在 Python 中训练了一个图像分类模型。我将其保存为 SavedModel，并转换为 TF.js 模型，得到一个权重文件（weights.bin）和一个 JSON 模型文件（model.json）。我的目标是在 React Native 上使用 TF.js，让用户能够上传自己的图片进行测试。

然而，模型加载完成后，预测返回的张量让我看不懂。我的分类任务只有两个类别。

这是因为我没有正确地把数据传给 predict 吗？

以下是我的源代码，感谢大家的帮助：

import React, { useState, useEffect } from "react";
import { StyleSheet, View, TouchableOpacity, Text } from "react-native";
import * as tf from "@tensorflow/tfjs";
import { fetch, bundleResourceIO } from "@tensorflow/tfjs-react-native";
import Constants from "expo-constants";
import * as Permissions from "expo-permissions";
import * as ImagePicker from "expo-image-picker";
import * as jpeg from "jpeg-js";
import * as FileSystem from "expo-file-system";
import Output from "../components/Output";

/**
 * Load a JPEG image from disk and turn it into a batched model input tensor.
 *
 * Steps:
 * 1. Read the file as base64 and decode the JPEG to RGBA pixels.
 * 2. Drop the alpha channel.
 * 3. Center-crop to a square.
 * 4. Resize to `targetSize` x `targetSize` with bilinear resizing.
 * 5. Add a batch dimension and scale pixel values from [0, 255] to [0, 1].
 *
 * @param {{ uri: string }} source - object whose `uri` points at a JPEG file
 *   readable by expo-file-system.
 * @param {number} [targetSize=400] - side length, in pixels, of the square model input.
 * @returns {Promise<tf.Tensor4D>} float32 tensor of shape
 *   [1, targetSize, targetSize, 3]. The caller owns it and should dispose it.
 */
async function imageToTensor(source, targetSize = 400) {
  const imgB64 = await FileSystem.readAsStringAsync(source.uri, {
    encoding: FileSystem.EncodingType.Base64,
  });
  // Pass the Uint8Array itself rather than its `.buffer`. A typed array may be a
  // view into a larger ArrayBuffer, and `.buffer` would expose the extra bytes.
  const rawImageData = tf.util.encodeString(imgB64, "base64");
  const { width, height, data } = jpeg.decode(rawImageData, {
    useTArray: true, // return a Uint8Array instead of a Buffer
  });

  // The decoded data carries 4 bytes (RGBA) per pixel; keep only the RGB bytes.
  const rgb = new Uint8Array(width * height * 3);
  for (let src = 0, dst = 0; dst < rgb.length; src += 4, dst += 3) {
    rgb[dst] = data[src];
    rgb[dst + 1] = data[src + 1];
    rgb[dst + 2] = data[src + 2];
  }

  // tf.tidy disposes every intermediate tensor and keeps only the returned one.
  return tf.tidy(() => {
    // Pixel data is stored row by row, so the shape is [rows, cols, channels],
    // i.e. [height, width, 3], not [width, height, 3].
    const img = tf.tensor3d(rgb, [height, width, 3]);

    // Square center crop. Offsets are floored so slice() gets integer indices.
    const side = Math.min(width, height);
    const top = Math.floor((height - side) / 2);
    const left = Math.floor((width - side) / 2);
    // slice() takes (begin, size), not (begin, end).
    const cropped = img.slice([top, left, 0], [side, side, 3]);

    const resized = tf.image.resizeBilinear(cropped, [targetSize, targetSize]);

    // Add the batch dimension and scale pixel values from [0, 255] to [0, 1].
    // NOTE(review): confirm this matches the preprocessing used during training.
    return resized.expandDims(0).cast("float32").div(tf.scalar(255));
  });
}

const NeuralNetworkScreen = ({ navigation }) => {
  const [isTfReady, setTfReady] = useState(false); // gets and sets the Tensorflow.js module loading status
  const [model, setModel] = useState(null); // gets and sets the locally saved Tensorflow.js model
  const [image, setImage] = useState(null); // gets and sets the image selected from the user
  const [predictions, setPredictions] = useState(null); // gets and sets the predicted value from the model
  const [error, setError] = useState(false); // gets and sets any errors

  useEffect(() => {
    (async () => {
      await tf.ready(); // wait for Tensorflow.js to get ready
      setTfReady(true); // set the state

      // bundle the model files and load the model:
      const model = require("../assets/model.json");
      const weights = require("../assets/weights.bin");
      const loadedModel = await tf.loadGraphModel(
        bundleResourceIO(model, weights)
      );
      setModel(loadedModel); // load the model to the state
    })();
  }, []);

  async function handlerSelectImage() {
    try {
      let response = await ImagePicker.launchImageLibraryAsync({
        mediaTypes: ImagePicker.MediaTypeOptions.Images,
        allowsEditing: true, // on Android user can rotate and crop the selected image; iOS users can only crop
        quality: 1, // go for highest quality possible
        aspect: [4, 3], // maintain aspect ratio of the crop area on Android; on iOS crop area is always a square
      });

      if (!response.cancelled) {
        const source = { uri: response.uri };
        setImage(source); // put image path to the state
        const imageTensor = await imageToTensor(source); // prepare the image
        const predictions = await model.predict(imageTensor, {batchSize=1}); // send the image to the model
        predictions.data().then((data) => console.log(data));

        setPredictions(predictions); // put model prediction to the state
      }
    } catch (error) {
      setError(error);
    }
  }

  function reset() {
    setPredictions(null);
    setImage(null);
    setError(false);
  }

  let status, statusMessage, showReset;
  const resetLink = (
    <Text onPress={reset} style={styles.reset}>
      Restart
    </Text>
  );

  if (!error) {
    if (isTfReady && model && !image && !predictions) {
      status = "modelReady";
      statusMessage = "Model is ready.";
    } else if (model && image && predictions) {
      status = "finished";
      statusMessage = "Prediction finished.";
      showReset = true;
    } else if (model && image && !predictions) {
      status = "modelPredict";
      statusMessage = "Model is predicting...";
    } else {
      status = "modelLoad";
      statusMessage = "Model is loading...";
    }
  } else {
    statusMessage = "Unexpected error occured.";
    showReset = true;
    console.log(error);
  }

  return (
    <View style={styles.container}>
      <View style={styles.innercontainer}>
        <Text style={styles.status}>
          {statusMessage} {showReset ? resetLink : null}
        </Text>
        <TouchableOpacity
          style={styles.imageContainer}
          onPress={model && !predictions ? handlerSelectImage : () => {}} // Activates handler only if the model has been loaded and there are no predictions done yet
        >
          <Output
            status={status}
            image={image}
            predictions={predictions}
            error={error}
          />
        </TouchableOpacity>
      </View>
    </View>
  );
};
Tensor {
  "dataId": Object {},
  "dtype": "float32",
  "id": 54,
  "isDisposedInternal": false,
  "kept": false,
  "rankType": "2",
  "scopeId": 10,
  "shape": Array [
    1,
    5,
  ],
  "size": 5,
  "strides": Array [
    5,
  ],
}
9.361248016357422
-2.4212443828582764
-3.309708833694458
-2.687634229660034
-0.9071207046508789
And those are the 5 values from the tensor.