Python 一直在尝试使这个tensorflow脚本工作

Python 一直在尝试使这个tensorflow脚本工作,python,tensorflow,Python,Tensorflow,我正在尝试将我的scikit学习python脚本移动到tensorflow代码中。不断地被错误缠住。请帮忙 import pandas as pd import numpy as np import tensorflow as tf # read csv df = pd.read_csv("/Downloads/iris-2.csv", header=0) # get header names as array features = l

我正在尝试将我的scikit学习python脚本移动到tensorflow代码中。不断地被错误缠住。请帮忙

    import pandas as pd
    import numpy as np
    import tensorflow as tf

    # read csv
    df = pd.read_csv("/Downloads/iris-2.csv", header=0)

    # get header names as array
    features = list(df.columns.values)
    label = features.pop()
    classes = len(df[label].unique())

    # encode target
    X = df[features]
    y = df[label]

    # convert feature headers into tf
    for index,value in enumerate(features):
        features[index] = tf.feature_column.numeric_column(value)

    # initialize classifier
    classifier = tf.estimator.DNNClassifier(
        feature_columns=features,
        hidden_units=[10, 10],
        n_classes=classes)

    # train the classifier
    dataset = tf.data.Dataset.from_tensor_slices((dict(X), y))
    dataset = dataset.shuffle(1000).repeat().batch(0)
    data = dataset.make_one_shot_iterator().get_next()
    classifier.train(input_fn=lambda:data,steps=3)
    predictions = classifier.predict([5.1,3.0,4.2,1.2])
    print(predictions)
我遇到的最新错误是:

ValueError: Passed Tensor("dnn/head/weighted_loss/Sum:0", shape=(), dtype=float32) should have graph attribute that is equal to current graph <tensorflow.python.framework.ops.Graph object at 0x10dd9a190>.
ValueError:Passed Tensor(“dnn/head/weighted_loss/Sum:0”,shape=(),dtype=float32)应具有与当前图形相等的图形属性。
这是我正在使用的数据集:

无法预计算输入张量(变量数据和数据集)。它们需要在调用train()时传递给input_fn的函数中进行计算,以便张量位于估计器(分类器)在调用train()期间创建的图中。因此,对于最后一个街区,您可以使用:

# train the classifier
def my_input_fn():
    dataset = tf.data.Dataset.from_tensor_slices((dict(X), y))
    dataset = dataset.shuffle(1000).repeat().batch(0)
    return dataset.make_one_shot_iterator().get_next()
classifier.train(input_fn=my_input_fn, steps=3)
predictions = classifier.predict([5.1,3.0,4.2,1.2])
print(predictions)

是否创建了包含图形/检查点的临时目录?你试过删除那个吗?嗯。。temp目录中似乎没有任何文件。请在创建DNNClassifier时尝试指定一个文件。看看我这里的示例:如果您更改NN的形状,当尝试从检查点读回图形时,它会变得混乱,因此如果您这样做,清除temp目录总是好的。因为你没有在这里指定一个,我不确定它是从哪里读来的。