Python 从tensorflow中的csv文件预测值

Python 从tensorflow中的csv文件预测值,python,tensorflow,Python,Tensorflow,我创建了一个线性分类模型,并在硬盘上存储了一个检查点 在另一个python文件中,我希望恢复模型,并使用存储在CSV文件中的数据的模型预测值。我有以下代码: import tensorflow as tf import pandas as pd # file with data that must be predicted file = 'census_predict.csv' x_test = pd.read_csv(file) print(x_test) # Create some v

我创建了一个线性分类模型,并在硬盘上存储了一个检查点

在另一个python文件中,我希望恢复模型,并使用存储在CSV文件中的数据的模型预测值。我有以下代码:

import tensorflow as tf
import pandas as pd

# file with data that must be predicted
file = 'census_predict.csv'
x_test = pd.read_csv(file)

print(x_test)

# Create some variables.
v1 = tf.get_variable("v1", shape=[3])
v2 = tf.get_variable("v2", shape=[5])

# Add ops to save and restore all the variables.
saver = tf.train.Saver()

# Later, launch the model, use the saver to restore variables from disk, and
# do some work with the model.
with tf.Session() as sess:
    # Restore variables from disk.
    saver.restore(sess, "models/1/1.ckpt")
    print("Model restored.")


    pred_fn = tf.estimator.inputs.pandas_input_fn(x=x_test, batch_size=len(x_test), shuffle=False)
    predictions = list(sess.predict(input_fn=pred_fn))
    print(predictions)

# close TF session in case a new one must be created later on
tf.reset_default_graph()
sess.close()
现在我得到了以下错误:

AttributeError: 'Session' object has no attribute 'predict'

我一直在stackoverflow和其他网站上寻找,但似乎无法找到正确的方法。有人能给我指引正确的方向吗?

不确定,但错误意味着您的会话对象上没有预测方法。文档应该说明方法的位置或者拼写是否不同。我认为这是因为会话对象上没有预测方法。但我不知道在哪里使用预测方法。