Ruby logistic regression giving wrong results


I'm working on a website where I collect the results of people's chess games. Looking at a player's rating and the difference between their rating and their opponent's, I plotted a graph where each point is a win (green), a draw (blue), or a loss (red).

With that information I also implemented a logistic regression algorithm to classify the cut-off for wins and for wins/draws. Using the rating and the difference as my two features, I get a classifier, and I then draw on the graph the boundary where the classifier changes its prediction.

Here is my code for gradient descent, the cost function, and the sigmoid function:

  def gradient_descent()
    oldJ = 0    
    newJ = J()
    alpha = 1.0     # Learning rate
    run = 0
    while (run < 100) do
      tmpTheta = Array.new
      for j in 0...numFeatures do
        sum = 0
        for i in 0...m do
          sum += ((h(training_data[:x][i]) - training_data[:y][i][0]) * training_data[:x][i][j])
        end
        tmpTheta[j] = Array.new
        tmpTheta[j][0] = theta[j, 0] - (alpha / m) * sum  # Alpha * partial derivative of J with respect to theta_j
      end
      self.theta = Matrix.rows(tmpTheta)
      oldJ = newJ
      newJ = J()
      run += 1
      if (run == 100 && (oldJ - newJ > 0.001)) then run -= 20 end   # Do 20 more if the error is still going down a fair amount.
      if (oldJ < newJ)
        alpha /= 10
      end
    end
  end

  def J()
    sum = 0
    for i in 0...m
      sum += ((training_data[:y][i][0] * Math.log(h(training_data[:x][i]))) 
          + ((1 - training_data[:y][i][0]) * Math.log(1 - h(training_data[:x][i]))))
    end
    return (-1.0 / m) * sum
  end

  def h(x)
    if (x.class != 'Matrix')    # In case it comes in as a row vector or an array
      x = Matrix.rows([x])      # [x] because if it's a row vector we want [[a, b]] to get an array whose first row is x.
    end
    x = x.transpose   # x is supposed to be a column vector, and theta^ a row vector, so theta^*x is a number.
    return g((theta.transpose * x)[0, 0])  # theta^ * x gives [[z]], so get [0, 0] of that for the number z.
  end

  def g(z)
    tmp = 1.0 / (1.0 + Math.exp(-z))   # Sigmoid function
    if (tmp == 1.0) then tmp = 0.99999 end    # These two things are here because ln(0) DNE, so we don't want to do ln(1 - 1.0) or ln(0.0)
    if (tmp == 0.0) then tmp = 0.00001 end
    return tmp
  end
For an object that produces the correct predictions, the output per iteration is:

J: 4.330234652497978  Alpha: 1.0  Theta: Matrix[[0.12388059701492538], [211.9910447761194], [-111.13731343283582]]
J: 4.330234652497978  Alpha: 0.1  Theta: Matrix[[0.08626965671641812], [152.3222144059701], [-118.07202388059702]]
J: 4.2958677406623815  Alpha: 0.1  Theta: Matrix[[0.048658716417910856], [92.65338403582082], [-125.0067343283582]]
J: 3.333594209265678  Alpha: 0.1  Theta: Matrix[[0.011644779104478219], [33.61767533134318], [-131.44443979104477]]
J: 0.4467735852246924  Alpha: 0.1  Theta: Matrix[[-0.014623104477611202], [-11.126378913433022], [-132.24166105074627]]
J: 3.333594209265678  Alpha: 0.1  Theta: Matrix[[0.01194378805970217], [31.177094038805805], [-126.89243925671643]]
J: 3.0930257965656063  Alpha: 0.01  Theta: Matrix[[0.009436400895523079], [26.892626149850567], [-126.92472924]]
J: 2.7493567080605392  Alpha: 0.01  Theta: Matrix[[0.007257365074627634], [23.13644550388053], [-126.8386038647761]]
J: 2.508788325211366  Alpha: 0.01  Theta: Matrix[[0.005466380895523164], [19.99261048238799], [-126.62851089164178]]
J: 2.405687589704577  Alpha: 0.01  Theta: Matrix[[0.004152999104478391], [17.61296913194023], [-126.28907722179103]]
J: 2.268219942362192  Alpha: 0.01  Theta: Matrix[[0.002959017910448543], [15.415473392238736], [-125.92224111492536]]
J: 2.1307522353180164  Alpha: 0.01  Theta: Matrix[[0.002093389253732125], [13.751072827761122], [-125.48597339134326]]
J: 2.027651529662123  Alpha: 0.01  Theta: Matrix[[0.0014367116417918252], [12.436814710149182], [-125.00961691402983]]
J: 1.9589177059909308  Alpha: 0.01  Theta: Matrix[[0.0009889847761201823], [11.44908667850739], [-124.49911195194028]]
J: 1.8558169406332465  Alpha: 0.01  Theta: Matrix[[0.0006606582089560022], [10.652638055522315], [-123.97004023522386]]
J: 1.8214500586485458  Alpha: 0.01  Theta: Matrix[[0.0004218823880604789], [9.988664770447688], [-123.42914782925371]]
J: 1.8214500884994413  Alpha: 0.01  Theta: Matrix[[0.0002428068653197179], [9.416182220312082], [-122.88082274064425]]
J: 1.8214500884994413  Alpha: 0.001  Theta: Matrix[[0.00023086931308091184], [9.369775500013574], [-122.82513353589798]]
J: 1.8214500884994413  Alpha: 0.001  Theta: Matrix[[0.00021893176084210577], [9.323368779715066], [-122.7694443311517]]
J: 1.8214500884994413  Alpha: 0.001  Theta: Matrix[[0.0002069942086032997], [9.276962059416558], [-122.71375512640543]]
J: 1.8214500884994413  Alpha: 0.001  Theta: Matrix[[0.00019505665636449364], [9.23055533911805], [-122.65806592165916]]
J: 1.8214500884994413  Alpha: 0.001  Theta: Matrix[[0.00018311910412568757], [9.184148618819542], [-122.60237671691289]]
J: 1.8214500884994413  Alpha: 0.001  Theta: Matrix[[0.0001711815518868815], [9.137741898521034], [-122.54668751216661]]
J: 1.8214500884994413  Alpha: 0.001  Theta: Matrix[[0.00015924399964807544], [9.091335178222526], [-122.49099830742034]]
J: 1.8214500884994413  Alpha: 0.001  Theta: Matrix[[0.00014730641755852312], [9.04492840598372], [-122.43530910670393]]
J: 1.8677695240029366  Alpha: 0.001  Theta: Matrix[[0.0001353688354689708], [8.998521633744915], [-122.37961990598751]]
J: 1.8462563443835032  Alpha: 0.0001  Theta: Matrix[[0.0001341750742749415], [8.993880951437452], [-122.374050986289]]
J: 1.8247430163841476  Alpha: 0.0001  Theta: Matrix[[0.00013298131308164604], [8.98924026913124], [-122.3684820665904]]
J: 1.803243007740144  Alpha: 0.0001  Theta: Matrix[[0.0001317875528781551], [8.984599588510665], [-122.36291314676808]]
J: 1.7875423426167685  Alpha: 0.0001  Theta: Matrix[[0.00013059512176735966], [8.979961171334951], [-122.35734406080917]]
J: 1.7870839229503594  Alpha: 0.0001  Theta: Matrix[[0.0001296573060241053], [8.97575636413016], [-122.35174314792931]]
J: 1.7870831481868632  Alpha: 0.0001  Theta: Matrix[[0.00012876197468911015], [8.971623907872633], [-122.34613692449842]]
J: 1.7870831468153818  Alpha: 0.0001  Theta: Matrix[[0.00012786672082037553], [8.967491583540149], [-122.34053069138426]]
J: 1.7870831468129538  Alpha: 0.0001  Theta: Matrix[[0.000126971467088789], [8.963359259441226], [-122.33492445825294]]
J: 1.7870831468129498  Alpha: 0.0001  Theta: Matrix[[0.0001260762133574453], [8.959226935342718], [-122.3293182251216]]
J: 1.7870831468129498  Alpha: 0.0001  Theta: Matrix[[0.00012518095962610202], [8.95509461124421], [-122.32371199199025]]
J: 1.7870831468129498  Alpha: 0.0001  Theta: Matrix[[0.00012428570589475874], [8.950962287145702], [-122.3181057588589]]
J: 1.7870831468129498  Alpha: 0.0001  Theta: Matrix[[0.00012339045216341546], [8.946829963047193], [-122.31249952572756]]
J: 1.7870831468129498  Alpha: 0.0001  Theta: Matrix[[0.00012249519843207218], [8.942697638948685], [-122.30689329259621]]
J: 1.7870831468129498  Alpha: 0.0001  Theta: Matrix[[0.00012159994470072888], [8.938565314850177], [-122.30128705946487]]
J: 1.7870831468129498  Alpha: 0.0001  Theta: Matrix[[0.00012070469096938559], [8.934432990751668], [-122.29568082633352]]
J: 1.7870831468129498  Alpha: 0.0001  Theta: Matrix[[0.0001198094372380423], [8.93030066665316], [-122.29007459320218]]
J: 1.7870831468129498  Alpha: 0.0001  Theta: Matrix[[0.000118914183506699], [8.926168342554652], [-122.28446836007083]]
J: 1.7870831468129498  Alpha: 0.0001  Theta: Matrix[[0.00011801892977535571], [8.922036018456144], [-122.27886212693949]]
......
Edit: I've noticed that the hypothesis always predicts 0.5 on the first iteration, because the thetas are all 0. But after that it always predicts 1 or 0 (0.00001 or 0.99999, to avoid the undefined logarithms in my code). That doesn't seem right to me (far too confident), and it may well be the key to why this isn't working.
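That saturation is easy to reproduce in isolation. A small sketch (the theta and feature values below are made up, with magnitudes roughly like those in the output above) of how an unscaled rating drives theta^T * x into the thousands, so the sigmoid is numerically 0.0 or 1.0 and the clamps in g() always fire:

    # Made-up numbers, with magnitudes similar to the logged theta values above.
    def g(z)
      1.0 / (1.0 + Math.exp(-z))
    end

    theta = [0.1, 9.0, -122.0]    # roughly the magnitudes seen in the output above
    x     = [1.0, 1500.0, 20.0]   # bias, rating, rating difference (hypothetical)
    z     = theta.zip(x).map { |t, xi| t * xi }.sum

    puts z      # => 11060.1
    puts g(z)   # => 1.0, so the 0.99999 clamp is always applied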

Floating point: try this instead; comparing floating-point numbers with == makes little sense to me:

  def g(z)
    tmp = 1.0 / (1.0 + Math.exp(-z))   # Sigmoid function
    if (tmp >= 0.99999) then tmp = 0.99999 end    # These two things are here because ln(0) DNE, so we don't want to do ln(1 - 1.0) or ln(0.0)
    if (tmp <= 0.00001) then tmp = 0.00001 end
    return tmp
  end

A few thoughts:

  • I think it would be very useful if you could show the values of J() and alpha for some of the iterations.
  • Are you including a constant (bias) term as a feature? If I remember correctly, if you don't, your (straight) line where h() == 0.5 is forced to pass through zero. (See the sketch after this list.)
  • Your function J() looks like it returns the negative log-likelihood (so you want to minimize it). But with if (oldJ < newJ), i.e. when J() has become larger, i.e. worse, you lower the learning rate.
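Regarding the bias point above, here is a minimal, hypothetical sketch of adding a constant feature; it assumes the feature rows are plain Ruby arrays like [rating, diff], which the question does not show:

    require 'matrix'

    # Prepend a constant 1.0 (bias) feature to every row.
    # The [rating, diff] rows below are made up; adapt to however
    # training_data[:x] is actually stored.
    x_rows = [[1500.0, 20.0], [1480.0, -35.0]]
    x_rows = x_rows.map { |row| [1.0] + row }   # => [[1.0, 1500.0, 20.0], [1.0, 1480.0, -35.0]]

    # theta then needs one extra component; a zero-initialised column vector:
    theta = Matrix.column_vector(Array.new(x_rows.first.length, 0.0))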


A few things about your implementation are not standard:

  • First, the logistic regression objective is usually given as a minimization problem:

    lr(x[n], y[n]) = log(1 + exp(-y[n] * dot(w[n], x[n])))

    where y[n] is 1 or -1.

    You appear to be using the equivalent maximization formulation:

    lr(x[n], y[n]) = -y[n] * log(1 + exp(-dot(w[n], x[n]))) + (1 - y[n]) * (-dot(w[n], x[n]) - log(1 + exp(-dot(w[n], x[n]))))

    where y[n] is 0 or 1 (y[n] = 0 in this formulation corresponds to y[n] = -1 in the first one).

    So you should make sure that in your data set the labels are 0 or 1, not 1 and -1.

  • Next, the LR objective is usually not divided by m (the size of the data set). That scale factor is not correct when you view logistic regression as a probabilistic model.

  • Finally, your implementation may have some numerical problems (which you are trying to patch up in your g function). The code at http://leon.bottou.org/projects/sgd does a more stable computation of the loss function and its derivative, as follows (it's C code, and it uses the first LR formulation I mentioned):

    /* logloss(a,y) = log(1+exp(-a*y)) */
    double loss(double a, double y)
    {
      double z = a * y;
      if (z > 18) {
        return exp(-z);
      }
      if (z < -18) {
        return -z;
      }
      return log(1 + exp(-z));
    }

    /*  -dloss(a,y)/da */
    double dloss(double a, double y)
    {
      double z = a * y;
      if (z > 18) {
        return y * exp(-z);
      }
      if (z < -18) {
        return y;
      }
      return y / (1 + exp(z));
    }
    

  • You should also consider running a stock L-BFGS routine (I'm not familiar with Ruby implementations), so that you can concentrate on getting the objective and gradient computations right without having to worry about things like the learning rate.

I think you need to normalize the initial data set with feature normalization ((X - mu) / sigma) first, and then do what you were planning to do.

Without feature normalization, gradient descent will go wrong on large data sets because it behaves erratically.
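A minimal sketch of that (X - mu) / sigma normalization in Ruby; it assumes the feature rows are plain arrays like [rating, diff] (the question does not show how training_data[:x] is built, so this is only illustrative):

    # Per-column (X - mu) / sigma feature normalization on plain array rows.
    def normalize!(rows)
      cols = rows.first.length
      (0...cols).each do |j|
        values = rows.map { |r| r[j] }
        mu     = values.sum / values.length.to_f
        sigma  = Math.sqrt(values.map { |v| (v - mu)**2 }.sum / values.length.to_f)
        next if sigma == 0.0                  # leave constant columns (e.g. a bias of 1.0) alone
        rows.each { |r| r[j] = (r[j] - mu) / sigma }
      end
      rows
    end

    rows = [[1500.0, 20.0], [1480.0, -35.0], [1650.0, 120.0]]   # made-up ratings/diffs
    normalize!(rows)   # each column now has mean ~0 and standard deviation ~1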

Comments:

  • Consider posting this question on another site; if a moderator can move it there, I'd appreciate it. I've flagged it as off-topic so it can be reviewed and moved.
  • The only thing I can add to the off-topic point is that this is pretty rough Ruby code. Please use things like each and other collection methods such as map, etc.
  • My suggestion for finding the problem: compare your output against an implementation you expect to be correct. In this case, consider mnrfit.
  • I use equality with the floats because if they are exactly 1.0 or 0.0 you get a log(0) error when you do 1.0 - n or take log(0.0); so as long as they're not exactly those values it's fine, and checking for exact equality is what I want here. It's one of the rare cases where comparing floats with == is OK.
  • I'm using multi-class logistic regression here because there are three possible outcomes, grouped into two sets: (win, draw) = 1 and (win) = 1. So the blue line is the cut-off for the first group and the green one for the second. Blue is higher because it puts more results in the y = 1 category.
  • Why would feature scaling or a fixed learning rate help? Not a criticism, just trying to understand. I'll try it when I get home.
  • Using a fixed learning rate might help with debugging; that's just my two cents. I assume the rating difference has a significantly different distribution from the rating itself, so from a gradient descent point of view feature scaling may give you a better contour ("better" meaning the contour isn't extremely tall or wide along your feature dimensions).
  • It's actually not a significantly different distribution.