Python TF-IDF的实现
我想知道为什么我的TF-IDF实现与sklearn实现的结果略有不同 以下是我的实现:Python TF-IDF的实现,python,pandas,scikit-learn,tf-idf,tfidfvectorizer,Python,Pandas,Scikit Learn,Tf Idf,Tfidfvectorizer,我想知道为什么我的TF-IDF实现与sklearn实现的结果略有不同 以下是我的实现: text = ["aa bb cc dd ee", "bb cc dd dd"] terms = [Counter(t.split(' ')) for t in text] tf = pd.DataFrame(terms) tf = tf.fillna(0) num_docs = len(text) idf = np.log(num_docs / tf[tf >= 1].count()) + 1
text = ["aa bb cc dd ee", "bb cc dd dd"]
terms = [Counter(t.split(' ')) for t in text]
tf = pd.DataFrame(terms)
tf = tf.fillna(0)
num_docs = len(text)
idf = np.log(num_docs / tf[tf >= 1].count()) + 1
tf_idf = tf * idf
norm = np.sqrt((tf_idf ** 2).sum(axis=1))
norm_tf_idf = tf_idf.div(norm, axis=0)
>>> norm_tf_idf
aa bb cc dd ee
0 0.572929 0.338381 0.338381 0.338381 0.572929
1 0.000000 0.408248 0.408248 0.816497 0.000000
但是,如果我使用sklearn:
tf = TfidfVectorizer(smooth_idf=False, stop_words=None, sublinear_tf=True)
x = tf.fit_transform(text)
sk = pd.DataFrame(x.toarray())
sk.columns = tf.get_feature_names()
sk
>>> sk
aa bb cc dd ee
0 0.572929 0.338381 0.338381 0.338381 0.572929
1 0.000000 0.453295 0.453295 0.767495 0.000000
或者如果我们减去它们:
>>> norm_tf_idf - sk
aa bb cc dd ee
0 0.0 0.000000 0.000000 0.000000 0.0
1 0.0 -0.045046 -0.045046 0.049002 0.0
编辑: 我发现sklearn idf与我的idf不完全相同,但我们可以将其归因于浮点精度,我认为:
sklearn idf: [1.69314718 1. 1. 1. 1.69314718]
my idf: [1.693147 1.000000 1.000000 1.000000 1.693147]
即使我使用sklearn idf,我仍然会得到不同的结果
此外,如果我不规范化并使用sklearn idf值,则只有第二个文档的dd
的TF-idf不同:
sk_tfv = TfidfVectorizer(smooth_idf=False, stop_words=None, token_pattern=r"(?u)\b\w+\b", sublinear_tf=True, norm=None)
x = sk_tf.fit_transform(text)
sk_tf_idf = pd.DataFrame(x.toarray())
...
idf = sk_tfv.idf_
tf_idf = tf * idf
>>> tf_idf - sk_tf_idf
aa bb cc dd ee
0 0.0 0.0 0.0 0.000000 0.0
1 0.0 0.0 0.0 0.306853 0.0
这意味着两件事之一:1.问题是我的TF。然而,这是很容易检查,似乎不是这样。或者,
2.sklearn不仅仅做TF*IDF,还做了更多的事情,我必须仔细研究。我很愚蠢。在深入sklearn源代码之后,我注意到
sublinear\u tf
参数。将此参数设置为True时,术语频率将被log(TF)+1
替换,而我恰好将此参数设置为True
:)
为了在熊猫中实施次线性TF,这应该起作用:
tf[tf > 0] = np.log(tf[tf > 0] ) + 1