Ruby logistic regression gives the wrong results
I'm working on a website where I collect the results of people's chess games. Looking at a player's rating and the difference between their rating and their opponent's, I plotted a graph with dots for wins (green), draws (blue), and losses (red). With this data I also implemented a logistic regression algorithm to classify the cutoffs for wins and for wins/draws. Using rating and rating difference as my two features, I fit a classifier and then draw on the graph the boundary where the classifier changes its prediction. Below is my code for gradient descent, the cost function, and the sigmoid function:
def gradient_descent()
  oldJ = 0
  newJ = J()
  alpha = 1.0 # Learning rate
  run = 0
  while (run < 100) do
    tmpTheta = Array.new
    for j in 0...numFeatures do
      sum = 0
      for i in 0...m do
        sum += ((h(training_data[:x][i]) - training_data[:y][i][0]) * training_data[:x][i][j])
      end
      tmpTheta[j] = Array.new
      tmpTheta[j][0] = theta[j, 0] - (alpha / m) * sum # Alpha * partial derivative of J with respect to theta_j
    end
    self.theta = Matrix.rows(tmpTheta)
    oldJ = newJ
    newJ = J()
    run += 1
    if (run == 100 && (oldJ - newJ > 0.001)) then run -= 20 end # Do 20 more if the error is still going down a fair amount.
    if (oldJ < newJ)
      alpha /= 10
    end
  end
end
def J()
  sum = 0
  for i in 0...m
    sum += ((training_data[:y][i][0] * Math.log(h(training_data[:x][i]))) +
            ((1 - training_data[:y][i][0]) * Math.log(1 - h(training_data[:x][i]))))
  end
  return (-1.0 / m) * sum
end
def h(x)
  if !x.is_a?(Matrix) # In case it comes in as a row vector or an array. (Note: comparing x.class to the string 'Matrix' is always true; compare against the class itself.)
    x = Matrix.rows([x]) # [x] because if it's a row vector we want [[a, b]] to get an array whose first row is x.
  end
  x = x.transpose # x is supposed to be a column vector, so theta^T * x is a number.
  return g((theta.transpose * x)[0, 0]) # theta^T * x gives [[z]], so take [0, 0] of that to get the number z.
end

def g(z)
  tmp = 1.0 / (1.0 + Math.exp(-z)) # Sigmoid function
  if (tmp == 1.0) then tmp = 0.99999 end # These two clamps are here because ln(0) is undefined, so we don't want to compute ln(1 - 1.0) or ln(0.0)
  if (tmp == 0.0) then tmp = 0.00001 end
  return tmp
end
Output (J, alpha, and theta per iteration) for an instance that produces the correct predictions:
J: 4.330234652497978 Alpha: 1.0 Theta: Matrix[[0.12388059701492538], [211.9910447761194], [-111.13731343283582]]
J: 4.330234652497978 Alpha: 0.1 Theta: Matrix[[0.08626965671641812], [152.3222144059701], [-118.07202388059702]]
J: 4.2958677406623815 Alpha: 0.1 Theta: Matrix[[0.048658716417910856], [92.65338403582082], [-125.0067343283582]]
J: 3.333594209265678 Alpha: 0.1 Theta: Matrix[[0.011644779104478219], [33.61767533134318], [-131.44443979104477]]
J: 0.4467735852246924 Alpha: 0.1 Theta: Matrix[[-0.014623104477611202], [-11.126378913433022], [-132.24166105074627]]
J: 3.333594209265678 Alpha: 0.1 Theta: Matrix[[0.01194378805970217], [31.177094038805805], [-126.89243925671643]]
J: 3.0930257965656063 Alpha: 0.01 Theta: Matrix[[0.009436400895523079], [26.892626149850567], [-126.92472924]]
J: 2.7493567080605392 Alpha: 0.01 Theta: Matrix[[0.007257365074627634], [23.13644550388053], [-126.8386038647761]]
J: 2.508788325211366 Alpha: 0.01 Theta: Matrix[[0.005466380895523164], [19.99261048238799], [-126.62851089164178]]
J: 2.405687589704577 Alpha: 0.01 Theta: Matrix[[0.004152999104478391], [17.61296913194023], [-126.28907722179103]]
J: 2.268219942362192 Alpha: 0.01 Theta: Matrix[[0.002959017910448543], [15.415473392238736], [-125.92224111492536]]
J: 2.1307522353180164 Alpha: 0.01 Theta: Matrix[[0.002093389253732125], [13.751072827761122], [-125.48597339134326]]
J: 2.027651529662123 Alpha: 0.01 Theta: Matrix[[0.0014367116417918252], [12.436814710149182], [-125.00961691402983]]
J: 1.9589177059909308 Alpha: 0.01 Theta: Matrix[[0.0009889847761201823], [11.44908667850739], [-124.49911195194028]]
J: 1.8558169406332465 Alpha: 0.01 Theta: Matrix[[0.0006606582089560022], [10.652638055522315], [-123.97004023522386]]
J: 1.8214500586485458 Alpha: 0.01 Theta: Matrix[[0.0004218823880604789], [9.988664770447688], [-123.42914782925371]]
J: 1.8214500884994413 Alpha: 0.01 Theta: Matrix[[0.0002428068653197179], [9.416182220312082], [-122.88082274064425]]
J: 1.8214500884994413 Alpha: 0.001 Theta: Matrix[[0.00023086931308091184], [9.369775500013574], [-122.82513353589798]]
J: 1.8214500884994413 Alpha: 0.001 Theta: Matrix[[0.00021893176084210577], [9.323368779715066], [-122.7694443311517]]
J: 1.8214500884994413 Alpha: 0.001 Theta: Matrix[[0.0002069942086032997], [9.276962059416558], [-122.71375512640543]]
J: 1.8214500884994413 Alpha: 0.001 Theta: Matrix[[0.00019505665636449364], [9.23055533911805], [-122.65806592165916]]
J: 1.8214500884994413 Alpha: 0.001 Theta: Matrix[[0.00018311910412568757], [9.184148618819542], [-122.60237671691289]]
J: 1.8214500884994413 Alpha: 0.001 Theta: Matrix[[0.0001711815518868815], [9.137741898521034], [-122.54668751216661]]
J: 1.8214500884994413 Alpha: 0.001 Theta: Matrix[[0.00015924399964807544], [9.091335178222526], [-122.49099830742034]]
J: 1.8214500884994413 Alpha: 0.001 Theta: Matrix[[0.00014730641755852312], [9.04492840598372], [-122.43530910670393]]
J: 1.8677695240029366 Alpha: 0.001 Theta: Matrix[[0.0001353688354689708], [8.998521633744915], [-122.37961990598751]]
J: 1.8462563443835032 Alpha: 0.0001 Theta: Matrix[[0.0001341750742749415], [8.993880951437452], [-122.374050986289]]
J: 1.8247430163841476 Alpha: 0.0001 Theta: Matrix[[0.00013298131308164604], [8.98924026913124], [-122.3684820665904]]
J: 1.803243007740144 Alpha: 0.0001 Theta: Matrix[[0.0001317875528781551], [8.984599588510665], [-122.36291314676808]]
J: 1.7875423426167685 Alpha: 0.0001 Theta: Matrix[[0.00013059512176735966], [8.979961171334951], [-122.35734406080917]]
J: 1.7870839229503594 Alpha: 0.0001 Theta: Matrix[[0.0001296573060241053], [8.97575636413016], [-122.35174314792931]]
J: 1.7870831481868632 Alpha: 0.0001 Theta: Matrix[[0.00012876197468911015], [8.971623907872633], [-122.34613692449842]]
J: 1.7870831468153818 Alpha: 0.0001 Theta: Matrix[[0.00012786672082037553], [8.967491583540149], [-122.34053069138426]]
J: 1.7870831468129538 Alpha: 0.0001 Theta: Matrix[[0.000126971467088789], [8.963359259441226], [-122.33492445825294]]
J: 1.7870831468129498 Alpha: 0.0001 Theta: Matrix[[0.0001260762133574453], [8.959226935342718], [-122.3293182251216]]
J: 1.7870831468129498 Alpha: 0.0001 Theta: Matrix[[0.00012518095962610202], [8.95509461124421], [-122.32371199199025]]
J: 1.7870831468129498 Alpha: 0.0001 Theta: Matrix[[0.00012428570589475874], [8.950962287145702], [-122.3181057588589]]
J: 1.7870831468129498 Alpha: 0.0001 Theta: Matrix[[0.00012339045216341546], [8.946829963047193], [-122.31249952572756]]
J: 1.7870831468129498 Alpha: 0.0001 Theta: Matrix[[0.00012249519843207218], [8.942697638948685], [-122.30689329259621]]
J: 1.7870831468129498 Alpha: 0.0001 Theta: Matrix[[0.00012159994470072888], [8.938565314850177], [-122.30128705946487]]
J: 1.7870831468129498 Alpha: 0.0001 Theta: Matrix[[0.00012070469096938559], [8.934432990751668], [-122.29568082633352]]
J: 1.7870831468129498 Alpha: 0.0001 Theta: Matrix[[0.0001198094372380423], [8.93030066665316], [-122.29007459320218]]
J: 1.7870831468129498 Alpha: 0.0001 Theta: Matrix[[0.000118914183506699], [8.926168342554652], [-122.28446836007083]]
J: 1.7870831468129498 Alpha: 0.0001 Theta: Matrix[[0.00011801892977535571], [8.922036018456144], [-122.27886212693949]]
......
Edit: I noticed that on the first iteration the hypothesis always predicts 0.5, since the thetas are all 0. But after that it always predicts 1 or 0 (0.00001 or 0.99999, to avoid the undefined logarithms in my code). That seems wrong to me — far too confident — and it may be the key to why this isn't working.
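The over-confidence described in this edit follows directly from the scale of the features: with raw chess ratings in the hundreds or thousands, even a tiny weight pushes z = theta^T * x far outside the range where the sigmoid is informative. A minimal standalone sketch (the sigmoid helper here is illustrative, not the g() from the code above):

```ruby
# The sigmoid saturates once |z| is much larger than ~20.
def sigmoid(z)
  1.0 / (1.0 + Math.exp(-z))
end

rating = 1500.0      # a typical raw rating feature
theta1 = 0.05        # even a tiny weight...
z = theta1 * rating  # ...gives z = 75

puts sigmoid(z)   # 1.0 -- indistinguishable from 1 in 64-bit floats
puts sigmoid(-z)  # ~2.7e-33 -- effectively 0
```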
Try this — comparing floating-point numbers for equality makes little sense to me:

def g(z)
  tmp = 1.0 / (1.0 + Math.exp(-z)) # Sigmoid function
  if (tmp >= 0.99999) then tmp = 0.99999 end # These two clamps are here because ln(0) is undefined, so we don't want to compute ln(1 - 1.0) or ln(0.0)
  if (tmp <= 0.00001) then tmp = 0.00001 end
  return tmp
end

Some thoughts:
- I think it would be very helpful if you could show the values of J() and alpha over some of the iterations.
- Are you including a constant (bias) term as a feature? If I remember correctly, if you don't, your (straight) line of h() == 0.5 is forced to pass through zero.
- Your function J() looks like it returns the negative log-likelihood (so you want to minimize it). But when (oldJ < newJ), i.e. when J() got larger, i.e. worse, you lower the learning rate.
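To make the bias point concrete, one common fix is to prepend a constant 1 to every feature vector so the first component of theta acts as an intercept. A sketch (the variable names are illustrative, not from the original code):

```ruby
require 'matrix'

# Raw features per game: [rating, rating_difference]
raw = [[1500.0, 100.0], [1400.0, -50.0]]

# Prepend a constant 1.0 so theta[0, 0] becomes a bias/intercept term.
with_bias = raw.map { |row| [1.0] + row }

x = Matrix.rows(with_bias)
puts x.row(0).to_a.inspect  # [1.0, 1500.0, 100.0]
```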
A couple of things about your implementation are non-standard. The usual formulation of the logistic loss is

lr(x[n], y[n]) = log(1 + exp(-y[n] * dot(w, x[n])))

where y[n] is 1 or -1. You appear to be using the equivalent maximization formulation

lr(x[n], y[n]) = -y[n] * log(1 + exp(-dot(w, x[n]))) + (1 - y[n]) * (-dot(w, x[n]) - log(1 + exp(-dot(w, x[n]))))

where y[n] is 0 or 1 (y[n] = 0 in this formulation corresponds to y[n] = -1 in the first). So you should make sure that in your dataset the labels are 0 or 1, not 1 or -1.
You also divide by m (the size of the dataset); this scale factor is not correct when you view logistic regression as a probabilistic model. For reference, here is a numerically stable way to compute the log loss and its derivative in C, using asymptotic approximations for large |z| to avoid overflow in exp:

/* logloss(a,y) = log(1+exp(-a*y)) */
double loss(double a, double y)
{
    double z = a * y;
    if (z > 18) {
        return exp(-z);
    }
    if (z < -18) {
        return -z;
    }
    return log(1 + exp(-z));
}

/* -dloss(a,y)/da */
double dloss(double a, double y)
{
    double z = a * y;
    if (z > 18) {
        return y * exp(-z);
    }
    if (z < -18) {
        return y;
    }
    return y / (1 + exp(z));
}
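The two label conventions really are equivalent; a quick numeric check (an illustrative Ruby translation of the formulas, not code from the question):

```ruby
# Logistic loss with labels y in {1, -1}
def loss_pm1(a, y)
  Math.log(1 + Math.exp(-a * y))
end

# Cross-entropy form with labels y in {0, 1}
def loss_01(a, y)
  sig = 1.0 / (1.0 + Math.exp(-a))
  -(y * Math.log(sig) + (1 - y) * Math.log(1 - sig))
end

a = 0.7
puts (loss_pm1(a,  1) - loss_01(a, 1)).abs < 1e-12  # true
puts (loss_pm1(a, -1) - loss_01(a, 0)).abs < 1e-12  # true
```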
You should also consider running a stock L-BFGS routine (I'm not familiar with Ruby implementations of it), so you can focus on getting the objective and gradient computations right without having to worry about things like the learning rate.
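Whichever optimizer you use, a cheap way to confirm the objective and gradient agree is a central finite-difference check. A sketch for the single-weight case (the names and values here are illustrative):

```ruby
# f(w) = log(1 + exp(-y * w * x)), the logistic loss for one sample.
def loss(w, x, y)
  Math.log(1 + Math.exp(-y * w * x))
end

# d/dw log(1+exp(-y*w*x)) = -y*x / (1 + exp(y*w*x))
def analytic_grad(w, x, y)
  -y * x / (1 + Math.exp(y * w * x))
end

w, x, y = 0.3, 2.0, -1.0
eps = 1e-6
numeric = (loss(w + eps, x, y) - loss(w - eps, x, y)) / (2 * eps)
puts (numeric - analytic_grad(w, x, y)).abs < 1e-6  # true: gradient is consistent
```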
I think you need to normalize the initial dataset with feature normalization ((X - mu) / sigma) and then do what you intended to do. Without feature normalization, gradient descent misbehaves on datasets with large feature values because the cost contours become badly scaled.

Comments:

- Consider posting this question elsewhere; I'd appreciate it if a moderator could move it there. I've flagged it as off-topic so it can be reviewed and moved.
- The only off-topic thing I can say: this is very poor Ruby code. Please use iterators like each and other collection methods such as map, inject, etc.
- My suggestion for finding the problem: compare your output with an implementation you expect to be correct. In this case, consider mnrfit().
- I used equality with floats because if they are exactly 1.0 or 0.0 you get a log(0) error when computing 1.0 - n or 0.0 - n; if they are not exactly equal, nothing goes wrong, so exact equality is fine here. There are rare occasions when comparing floats for equality is acceptable.
- I used multi-class logistic regression here because there are three possible outcomes, split into two groups: (win, draw) = 1 and (win) = 1. So blue represents the cutoff for the first group and green for the second. Blue is higher because it includes more results in the y = 1 category.
- Why would feature scaling or a fixed learning rate help? I'm not criticizing, just trying to understand. I'll try it when I get home.
- Using a fixed learning rate could help with debugging; just my two cents. I assumed the rating difference follows a significantly different distribution than the rating itself, so from gradient descent's point of view feature scaling might give you a better-conditioned contour ("better" meaning the contour isn't extremely tall or wide along your feature dimensions).
- It's actually not a significantly different distribution.
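The normalization suggested above, (X - mu) / sigma per feature column, can be sketched as follows (a standalone helper, not part of the original class):

```ruby
# Z-score normalization of one feature column: (x - mean) / std-dev.
def normalize(column)
  n = column.size.to_f
  mu = column.sum / n
  sigma = Math.sqrt(column.map { |v| (v - mu)**2 }.sum / n)
  column.map { |v| (v - mu) / sigma }
end

ratings = [1200.0, 1500.0, 1800.0]
puts normalize(ratings).inspect  # approximately [-1.2247, 0.0, 1.2247]
```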