Python 如何使用tensorflow';s ncf模型预测?

Python 如何使用tensorflow';s ncf模型预测?,python,tensorflow,neural-network,recommendation-engine,tensorflow-model-garden,Python,Tensorflow,Neural Network,Recommendation Engine,Tensorflow Model Garden,嗨,我是tensorflow和神经网络的新手。试图理解tensorflow官方回购模型中的风险 我的理解是,您构建了一个包含输入层和学习层的模型。然后创建批量数据来训练模型,然后使用测试数据来评估模型。这是在这种情况下完成的 但是,我很难理解输入层 它显示在代码中 user_input = tf.keras.layers.Input( shape=(1,), name=movielens.USER_COLUMN, dtype=tf.int32) 据我所知,您可以一次输入一个参数 但

嗨,我是tensorflow和神经网络的新手。试图理解tensorflow官方回购模型中的风险

我的理解是,您构建了一个包含输入层和学习层的模型。然后创建批量数据来训练模型,然后使用测试数据来评估模型。这是在这种情况下完成的

但是,我很难理解输入层

它显示在代码中

user_input = tf.keras.layers.Input(
      shape=(1,), name=movielens.USER_COLUMN, dtype=tf.int32)
据我所知,您可以一次输入一个参数

但是,我只能使用以下虚拟数据来调用predict\u on\u batch

user_input = np.full(shape=(256,),fill_value=1, dtype=np.int32)
item_input = np.full(shape=(256,),fill_value=1, dtype=np.int32)
valid_pt_mask_input = np.full(shape=(256,),fill_value=True, dtype=np.bool)
dup_mask_input = np.full(shape=(256,),fill_value=1, dtype=np.int32)
label_input = np.full(shape=(256,),fill_value=True, dtype=np.bool)
test_input_list = [user_input,item_input,valid_pt_mask_input,dup_mask_input,label_input]

tf.print(keras_model.predict_on_batch(test_input_list))
当我运行以下代码时:

    user_input = np.full(shape=(1,),fill_value=1, dtype=np.int32)
    item_input = np.full(shape=(1,),fill_value=1, dtype=np.int32)
    valid_pt_mask_input = np.full(shape=(1,),fill_value=True, dtype=np.bool)
    dup_mask_input = np.full(shape=(1,),fill_value=1, dtype=np.int32)
    label_input = np.full(shape=(1,),fill_value=True, dtype=np.bool)
    test_input_list = [user_input,item_input,valid_pt_mask_input,dup_mask_input,label_input]

    classes = _model.predict(test_input_list)
    tf.print(classes)
我得到了这个错误:

tensorflow.python.framework.errors_impl.InvalidArgumentError:  Input to reshape is a tensor with 1 values, but the requested shape requires a multiple of 256
     [[{{node model_1/metric_layer/StatefulPartitionedCall/StatefulPartitionedCall/Reshape_1}}]] [Op:__inference_predict_function_2828]
有人能帮我用这个模型预测单输入吗?
还有,为什么在进行预测时需要使用用户id的item_id?你不应该提供一个用户列表吗?模型返回一个项目列表

我以前没有使用过ncf模型,但看起来您输入的训练数据是一个样本,包含256个特征,而不是256个样本,每个样本包含一个特征。只需翻转numpy阵列,确保特征矩阵是二维的,并且特征的数量是第一个维度

user_input = np.full(shape=(1,256),fill_value=1, dtype=np.int32)
…等等。(好的,标签应该保持1D,因为你有它们)

同样,确保预测输入中的特征矩阵为2D:

user_input = np.full(shape=(1,1),fill_value=1, dtype=np.int32)

我以前没有使用过ncf模型,但看起来您输入的训练数据是一个样本,包含256个特征,而不是256个样本,每个样本包含一个特征。只需翻转numpy阵列,确保特征矩阵是二维的,并且特征的数量是第一个维度

user_input = np.full(shape=(1,256),fill_value=1, dtype=np.int32)
…等等。(好的,标签应该保持1D,因为你有它们)

同样,确保预测输入中的特征矩阵为2D:

user_input = np.full(shape=(1,1),fill_value=1, dtype=np.int32)

如果你是TensorFlow和深度学习的新手,这个推荐项目可能不是一个好的开始。代码没有文档化,架构可能会让人感到困惑

无论如何,为了回答您的问题,该模型不采用单一输入进行预测。查看代码,有5个输入(用户id、项目id、重复掩码、有效掩码、标签),但这实际上是为了培训。如果您只想进行预测,实际上只需要用户id和项目id。此模型基于用户id和项目id交互进行预测,这就是您需要两者的原因。但是,除非在进行预测时切掉模型中不必要的部分,否则不能直接这样做。下面是关于如何在名为
keras_model
的模型对象经过培训后执行此操作的代码(我使用了tf model官方版本2.5.0,它运行良好):

因此,如果进行预测,则需要创建所有项目用户组合,然后按降序对预测进行排序,同时跟踪每个用户的索引。该用户的第一项是该用户最有可能与之交互的模型预测,第二项是第二个最有可能与之交互的模型预测,依此类推

在运行这个ncf_keras_main.py文件时,会创建一个名为“summaries”的文件夹。如果将tensorboard指向该文件夹,则可以在左上角的“图形”选项卡下浏览模型体系结构。这可能有助于更好地理解代码。要运行tensorboard,请打开一个终端并键入

tensorboard --logdir location_of_summaries_folder_here

如果您是TensorFlow和深度学习的新手,这个推荐项目可能不是一个好的起点。代码没有文档化,架构可能会让人感到困惑

无论如何,为了回答您的问题,该模型不采用单一输入进行预测。查看代码,有5个输入(用户id、项目id、重复掩码、有效掩码、标签),但这实际上是为了培训。如果您只想进行预测,实际上只需要用户id和项目id。此模型基于用户id和项目id交互进行预测,这就是您需要两者的原因。但是,除非在进行预测时切掉模型中不必要的部分,否则不能直接这样做。下面是关于如何在名为
keras_model
的模型对象经过培训后执行此操作的代码(我使用了tf model官方版本2.5.0,它运行良好):

因此,如果进行预测,则需要创建所有项目用户组合,然后按降序对预测进行排序,同时跟踪每个用户的索引。该用户的第一项是该用户最有可能与之交互的模型预测,第二项是第二个最有可能与之交互的模型预测,依此类推

在运行这个ncf_keras_main.py文件时,会创建一个名为“summaries”的文件夹。如果将tensorboard指向该文件夹,则可以在左上角的“图形”选项卡下浏览模型体系结构。这可能有助于更好地理解代码。要运行tensorboard,请打开一个终端并键入

tensorboard --logdir location_of_summaries_folder_here

感谢您的快速回复。我尝试了你的两个建议,但仍然有错误,这一次对这两个都是。以上不是一个完整的解决方案,而是如何继续的建议。我希望您必须对几个输入矩阵进行修改。谢谢,我将继续尝试修改。谢谢您的快速响应。我尝试了你的两个建议,但仍然有错误,这一次对这两个都是。以上不是一个完整的解决方案,而是如何继续的建议。我希望您必须对几个输入矩阵进行修改。谢谢,我将继续尝试修改。我也很难理解如何对它们的模型进行预测。如果你有结论的话,我会对你的结论非常感兴趣。老实说,我认为代码实现得很差。我发现它毫无理由地令人费解。嗨@aaaahaaaa我