Tensorflow 虹膜数据集的预测
我有一个爱尔兰数据集的基本分类代码Tensorflow 虹膜数据集的预测,tensorflow,deep-learning,Tensorflow,Deep Learning,我有一个爱尔兰数据集的基本分类代码 import tensorflow as tf import pandas as pd COLUMN_NAMES = [ 'SepalLength', 'SepalWidth', 'PetalLength', 'PetalWidth', 'Species' ] # Import training dataset training_dataset = pd.
import tensorflow as tf
import pandas as pd
COLUMN_NAMES = [
'SepalLength',
'SepalWidth',
'PetalLength',
'PetalWidth',
'Species'
]
# Import training dataset
training_dataset = pd.read_csv('iris_training.csv', names=COLUMN_NAMES, header=0)
train_x = training_dataset.iloc[:, 0:4]
train_y = training_dataset.iloc[:, 4]
# Import testing dataset
test_dataset = pd.read_csv('iris_test.csv', names=COLUMN_NAMES, header=0)
test_x = test_dataset.iloc[:, 0:4]
test_y = test_dataset.iloc[:, 4]
columns_feat = [
tf.feature_column.numeric_column(key='SepalLength'),
tf.feature_column.numeric_column(key='SepalWidth'),
tf.feature_column.numeric_column(key='PetalLength'),
tf.feature_column.numeric_column(key='PetalWidth')
]
classifier = tf.estimator.DNNClassifier(
feature_columns=columns_feat,
# Two hidden layers of 10 nodes each.
hidden_units=[10, 10],
# The model is classifying 3 classes
n_classes=3)
def train_function(inputs, outputs, batch_size):
dataset = tf.data.Dataset.from_tensor_slices((dict(inputs), outputs))
dataset = dataset.shuffle(1000).repeat().batch(batch_size)
return dataset.make_one_shot_iterator().get_next()
# Train the Model.
classifier.train(
input_fn=lambda:train_function(train_x, train_y, 100),
steps=1000)
def evaluation_function(attributes, classes, batch_size):
attributes=dict(attributes)
if classes is None:
inputs = attributes
else:
inputs = (attributes, classes)
dataset = tf.data.Dataset.from_tensor_slices(inputs)
assert batch_size is not None, "batch_size must not be None"
dataset = dataset.batch(batch_size)
return dataset.make_one_shot_iterator().get_next()
# Evaluate the model.
eval_result = classifier.evaluate(
input_fn=lambda:evaluation_function(test_x, test_y, 100))
我对结果进行了评估,但我如何才能对我的数据进行预测,因为现在我只得到损失和年代的控制台信息,准确性。例如,如果我什么都有,除了物种。我想给出我自己的萼片长度等,这样我就可以预测物种,这将是另一个变量。我是否必须创建诸如pred_x或pred_y(熊猫数据帧)之类的变量,然后将它们放入评估结果中?与所有估计器类一样,
DNNClassifier
类有一个predict
方法来进行真实世界的预测。文档是。这就是你的意思吗?例如:new\u samples=np.array([[6.4,3.2,4.5,1.5],[5.8,3.1,5.0,1.7]],dtype=np.float32)
如果您希望这样的新数据进行预测,那么您可以参考此代码