Python 调用保存的分类器时无法强制转换数组数据

Python 调用保存的分类器时无法强制转换数组数据,python,machine-learning,scikit-learn,joblib,Python,Machine Learning,Scikit Learn,Joblib,我使用示例创建了一个分类器。 为了训练分类器,我使用以下代码 import os import numpy NEWLINE = '\n' SKIP_FILES = set(['cmds']) def read_files(path): for root, dir_names, file_names in os.walk(path): for path in dir_names: read_files(os.path.join(root, path)) for

我使用示例创建了一个分类器。 为了训练分类器,我使用以下代码

import os
import numpy

NEWLINE = '\n'
SKIP_FILES = set(['cmds'])


def read_files(path):
    """Yield ``(file_path, body_text)`` for every message file under *path*.

    The directory tree is traversed with ``os.walk``, which already descends
    into sub-directories, so no explicit recursion is needed.  Files whose
    name is listed in ``SKIP_FILES`` are ignored.

    For each file, everything up to and including the first blank line is
    treated as the header and dropped; the remaining lines form the body.
    Body lines keep their own trailing newline and are additionally joined
    with ``NEWLINE``.  A file with no blank line yields an empty body.

    The file is decoded as cp1252 while reading, and undecodable bytes are
    dropped.  Text mode normalises CRLF line endings to ``'\\n'``, so the
    blank header separator is also found in CRLF-terminated files.

    Args:
        path: Root directory to scan.

    Yields:
        Tuples of ``(file_path, text)`` where ``text`` is the decoded body.
    """
    # Local import keeps this block self-contained; io.open behaves the same
    # on Python 2 and Python 3 and returns decoded text on both.
    import io

    for root, _dir_names, file_names in os.walk(path):
        for file_name in file_names:
            if file_name in SKIP_FILES:
                continue
            file_path = os.path.join(root, file_name)
            if not os.path.isfile(file_path):
                continue
            past_header, lines = False, []
            # The context manager closes the file even if reading fails.
            with io.open(file_path, encoding='cp1252', errors='ignore') as f:
                for line in f:
                    if past_header:
                        lines.append(line)
                    elif line == NEWLINE:
                        # First blank line marks the end of the header.
                        past_header = True
            yield file_path, NEWLINE.join(lines)

from pandas import DataFrame

def build_data_frame(path, classification):
    """Load every message under *path* into a labelled DataFrame.

    Args:
        path: Directory passed on to ``read_files``.
        classification: Label stored in the ``'class'`` column of every row.

    Returns:
        A DataFrame with columns ``'text'`` and ``'class'``, one row per
        file yielded by ``read_files``, indexed by the file path.
    """
    file_names, texts = [], []
    for file_name, text in read_files(path):
        file_names.append(file_name)
        texts.append(text)
    # Build the frame once: appending row by row copies the whole frame on
    # every iteration, and DataFrame.append no longer exists in pandas 2.x.
    return DataFrame(
        {'text': texts, 'class': [classification] * len(texts)},
        index=file_names)

from pandas import concat

# Class labels stored in the 'class' column.
HAM = 0
SPAM = 1

# (directory, label) pairs making up the training corpus.
SOURCES = [
    ('data/spam',         SPAM),
    ('data/easy_ham',     HAM),
    ('data/hard_ham',     HAM),
    ('data/beck-s',       HAM),
    ('data/farmer-d',     HAM),
    ('data/kaminski-v',   HAM),
    ('data/kitchen-l',    HAM),
    ('data/lokay-m',      HAM),
    ('data/williams-w3',  HAM),
    ('data/BG',           SPAM),
    ('data/GP',           SPAM),
    ('data/SH',           SPAM),
    ]

# Concatenate the per-directory frames in one step; DataFrame.append was
# removed in pandas 2.0 and copied the accumulated frame on every call.
data = concat(
    [build_data_frame(path, classification)
     for path, classification in SOURCES])
# Shuffle the rows by reindexing with a random permutation of the index.
data = data.reindex(numpy.random.permutation(data.index))

import numpy
from sklearn.feature_extraction.text import CountVectorizer

# Fit the vectorizer on the training texts and turn them into a count matrix.
count_vectorizer = CountVectorizer()
counts = count_vectorizer.fit_transform(numpy.asarray(data['text']))

from sklearn.naive_bayes import MultinomialNB

# Train the classifier on the count matrix, using the 'class' column as targets.
classifier = MultinomialNB()
targets = numpy.asarray(data['class'])
clf = classifier.fit(counts, targets)

from sklearn.externals import joblib
# Only the fitted classifier is written to disk here; count_vectorizer is not
# saved, so code that loads this file has no vectorizer to transform new text.
# NOTE(review): sklearn.externals.joblib is absent from recent scikit-learn
# releases (plain `import joblib` replaces it) - confirm for your version.
joblib.dump(clf, 'my_trained_data.pkl', compress=9)
如果我在这个文件中测试一个示例,那么它可以正常工作。但当我把分类器保存到 my_trained_data.pkl 中,然后按如下方式加载并调用它时:

from sklearn.externals import joblib
clf = joblib.load('my_trained_data.pkl')

# NOTE: examples holds raw strings and is passed to predict() unchanged;
# this predict() call is the one that raises the TypeError reported below.
examples = ['Free Viagra call today!', "I'm going to attend the Linux users group tomorrow."]
predictions = clf.predict(examples)
这将导致以下错误

TypeError: Cannot cast array data from dtype('float64') to dtype('S32') according to the rule 'safe'
以下是跟踪

In [12]: runfile('/home/harpreet/Machine_learning/untitled0.py', wdir='/home/harpreet/Machine_learning') MultinomialNB(alpha=1.0, class_prior=None, fit_prior=True) Traceback (most recent call last):

  File "<ipython-input-12-521f3ed1e6da>", line 1, in <module>
    runfile('/home/harpreet/Machine_learning/untitled0.py', wdir='/home/harpreet/Machine_learning')

  File "/home/harpreet/anaconda/lib/python2.7/site-packages/spyderlib/widgets/externalshell/sitecustomize.py", line 682, in runfile
    execfile(filename, namespace)

  File "/home/harpreet/anaconda/lib/python2.7/site-packages/spyderlib/widgets/externalshell/sitecustomize.py", line 78, in execfile
    builtins.execfile(filename, *where)

  File "/home/harpreet/Machine_learning/untitled0.py", line 13, in <module>
    clf.predict(examples)

  File "/home/harpreet/anaconda/lib/python2.7/site-packages/sklearn/naive_bayes.py", line 62, in predict
    jll = self._joint_log_likelihood(X)

  File "/home/harpreet/anaconda/lib/python2.7/site-packages/sklearn/naive_bayes.py", line 441, in _joint_log_likelihood
    return (safe_sparse_dot(X, self.feature_log_prob_.T)

  File "/home/harpreet/anaconda/lib/python2.7/site-packages/sklearn/utils/extmath.py", line 180, in safe_sparse_dot
    return fast_dot(a, b)

TypeError: Cannot cast array data from dtype('float64') to dtype('S32') according to the rule 'safe'
In [12]: runfile('/home/harpreet/Machine_learning/untitled0.py', wdir='/home/harpreet/Machine_learning')
MultinomialNB(alpha=1.0, class_prior=None, fit_prior=True)
Traceback (most recent call last):
  File "<ipython-input-12-521f3ed1e6da>", line 1, in <module>
    runfile('/home/harpreet/Machine_learning/untitled0.py', wdir='/home/harpreet/Machine_learning')
  File "/home/harpreet/anaconda/lib/python2.7/site-packages/spyderlib/widgets/externalshell/sitecustomize.py", line 682, in runfile
    execfile(filename, namespace)
  File "/home/harpreet/anaconda/lib/python2.7/site-packages/spyderlib/widgets/externalshell/sitecustomize.py", line 78, in execfile
    builtins.execfile(filename, *where)
  File "/home/harpreet/Machine_learning/untitled0.py", line 13, in <module>
    clf.predict(examples)
  File "/home/harpreet/anaconda/lib/python2.7/site-packages/sklearn/naive_bayes.py", line 62, in predict
    jll = self._joint_log_likelihood(X)
  File "/home/harpreet/anaconda/lib/python2.7/site-packages/sklearn/naive_bayes.py", line 441, in _joint_log_likelihood
    return (safe_sparse_dot(X, self.feature_log_prob_.T)
  File "/home/harpreet/anaconda/lib/python2.7/site-packages/sklearn/utils/extmath.py", line 180, in safe_sparse_dot
    return fast_dot(a, b)
TypeError: Cannot cast array data from dtype('float64') to dtype('S32') according to the rule 'safe'

您需要使用训练时拟合的同一个矢量器(count_vectorizer)先转换测试文档,然后再进行预测。例如:

# Vectorize the raw strings with the already-fitted count_vectorizer, then
# predict on the resulting vectors instead of on the strings themselves.
examples_vectors = count_vectorizer.transform(examples)
clf.predict(examples_vectors)
一般来说,使用管道更容易:

from sklearn.pipeline import make_pipeline

# Chain the vectorizer and the classifier so that fit() is given the raw
# texts and labels directly.
pipeline = make_pipeline(CountVectorizer(), MultinomialNB())
pipeline.fit(data['text'].values, data['class'].values)
随后:

# The pipeline is handed the raw strings; no separate transform call is made.
pipeline.predict(examples)

评论:请包含错误消息的完整回溯。
回复:已添加回溯,@ogrisel。