Matlab 如何更正网格搜索?

Matlab 如何更正网格搜索?,matlab,svm,hyperparameters,Matlab,Svm,Hyperparameters,尝试使用网格搜索为我的svm模型找到最优的超参数,但它只是为超参数返回1 function evaluations = inner_kfold_trainer(C,q,k,features_xy,labels) features_xy_flds = kdivide(features_xy, k); labels_flds = kdivide(labels, k); evaluations = zeros(k,3); for i = 1:k fprintf('Fold %i of

尝试使用网格搜索为我的svm模型找到最优的超参数,但它只是为超参数返回1

function evaluations = inner_kfold_trainer(C,q,k,features_xy,labels)

features_xy_flds = kdivide(features_xy, k);
labels_flds = kdivide(labels, k);

evaluations = zeros(k,3);

for i = 1:k

    fprintf('Fold %i of %i\n',i,k);

    train_data =  cell2mat(features_xy_flds(1:end ~= i));
    train_labels = cell2mat(labels_flds(1:end ~= i));
    test_data = cell2mat(features_xy_flds(i));
    test_labels = cell2mat(labels_flds(i));

    %AU1 
    train_labels = train_labels(:,1);
    test_labels = test_labels(:,1);


    [k,~] = size(test_labels);

    %train
    sv = fitcsvm(train_data,train_labels, 'KernelFunction','polynomial', 'PolynomialOrder',q,'BoxConstraint',C);
    sv.predict(test_data);

    %Calculate evaluative measures
    %svm_outputs = zeros(k,1);
    sv_predictions = sv.predict(test_data);


    [precision,recall,F1] = evaluation(sv_predictions,test_labels);
    evaluations(i,1) = precision;
    evaluations(i,2) = recall;
    evaluations(i,3) = F1;


end

save('eval.mat', 'evaluations');

end
内折叠交叉验证函数 在网格函数下面,似乎有什么地方出了问题

function [q,C] = grid_search(features_xy,labels,k)

% n x n grid
n = 3;

q_grid = linspace(1,19,n);
C_grid = linspace(1,59,n);

tic

evals = zeros(n,n,3);

for i = 1:n
    for j = 1:n
        fprintf('## i=%i, j=%i ##\n', i, j);
        svm_results = inner_kfold_trainer(C_grid(i), q_grid(j),k,features_xy,labels);
        evals(i,j,:) = mean(svm_results(:,:));
        % precision only
        %evals(i,j,:) = max(svm_results(:,1));

        toc
    end
end

f = evals;

% retrieving the best value of the hyper parameters, to use in the outer
% fold
[M1,I1] = max(f);
[~,I2] = max(M1(1,1,:));
index = I1(:,:,I2);
C = C_grid(index(1))
q = q_grid(index(2))

end

例如,当我运行
grid\u search(features\u xy,labels,8)
时,对于任意k(折叠数)值,我得到C=1和q=1。还有一个特点是xy是一个500*98的矩阵。

您可能需要稍微修改一下代码,以帮助我们遵循它。这里有一些多余的语句(例如
sv.predict()
),奇怪的命令(例如
(:,:)
)和不必要的复杂性(例如,我怀疑k-foldness是否是重现代码所必需的,与
ti
toc
或类似的语句相同)。因此,如果不进行这些更改,则很难遵循此代码?也许您可以提供一个最小的工作示例,其中包括a)虚拟数据和b)将代码锐化到最小行数。这将有助于提高人们阅读你的文章的可能性,并有助于我认为你的最大(f);line得到的是最大精度值,因此值1在这方面是有意义的。如果要获取超参数,则需要返回超参数值。这些可能是SV变量中的某个地方,所以我会考虑在函数结果值中包含它。