frogfish 发表于 2008-9-21 07:47

利用SVM实现一个三类分类问题(转贴)

一.任务要求
用SVM求解一个三类分类问题,实验数据为“鸢尾属植物数据集”,核函数为径向基核函数(RBF),误差评测标准为K折交叉确认误差。

二.实验方案
1. 用quadprog函数实现C-SVC来进行分类
——quadprog是matlab中一个求解二次规划的函数,通过适当的参数设置,可以利用quadprog函数实现C-SVC
2. 用matlab自带的SVM工具包来实现分类
——matlab2006版本中集成了SVM工具包,可以通过调用工具包中的svmtrain和svmclassify函数来进行训练和分类
3. 三类问题的分类方法
——将三类问题转化为三个两类问题,分别求出相应的决策函数即可(优点:方法简单易行;缺点:容易形成死区)

三.实验程序
1. 用Quadprog实现
clear all
% Load the data and select features for classification
load fisheriris;
data = meas;
%Get the size of the data
N = size(data,1);
% Extract the Setosa class
groups_temp = ismember(species,'versicolor');%versicolor,virginica,setosa
%convert the group to 1 & -1
groups = 2*groups_temp - ones(N,1);

indices = crossvalind('Kfold', groups);

ErrorMin = 1;
for r=1:1:5
    for C=1:1:5
      ErrorNum = 0;      
      for i=1:5
            %Use K-fold to get train data and test data
            test = (indices == i); train = ~test;
            
            traindata = data(train,:);
            traingroup = groups(train,:);
            trainlength = length(traingroup);
            
            testdata = data(test,:);
            testgroup = groups(test,:);
            testlength = length(testgroup);
            
            %Get matrix H of the problem
            kfun = [];
            for i=1:1:trainlength
                for j=1:1:trainlength
                  %rbf kernel
                  kfun(i,j)=exp(-1/(r^2)*(traindata(i,:)-traindata(j,:))*(traindata(i,:)-traindata(j,:))');
                end
            end

            %count parameters of quadprog function
            H = (traingroup*traingroup').*kfun;
            xstart = zeros(trainlength,1);
            f = -ones(trainlength,1);
            Aeq = traingroup';
            beq = 0;
            lb = zeros(trainlength,1);
            ub = C*ones(trainlength,1);
            
             = quadprog(H,f,[],[],Aeq,beq,lb,ub,xstart);
            
            %Get one of the non-zero part of vector alpha to count b
            j = 1;
            for i=1:size(alpha)
                if(alpha(i)>(1e-5))
                  SvmClass_temper(j,:) = traingroup(i);
                  SvmAlpha_temper(j,:) = alpha(i);
                  SvmVector_temper(j,:)= traindata(i,:);
                  j = j + 1;
                  tag = i;
                end
            end
            
            b=traingroup(tag)-(alpha.*traingroup)'*kfun(:,tag);
            
            %Use the function to test the test data
            kk = [];
            for i=1:testlength
                for j=1:trainlength
                  kk(i,j)=exp(-1/(r^2)*(testdata(i,:)-traindata(j,:))*(testdata(i,:)-traindata(j,:))');
                end
            end

            %then count the function
            f=(alpha.*traingroup)'*kk' + b;         
            for i=1:length(f)
                if(f(i)>(1e-5))
                  f(i)=1;
                else
                  f(i)=-1;
                end
            end         
            
            for i=1:length(f)
                if(testgroup(i)~=f(i))
                  ErrorNum = ErrorNum + 1;
                end
            end         
      end
      
      ErrorRate = ErrorNum / N;
      
      if(ErrorRate<ErrorMin)
            SvmClass = SvmClass_temper;
            SvmAlpha = SvmAlpha_temper;
            SvmVector = SvmVector_temper;
            ErrorMin = ErrorRate;
            CorrectRate = 1 - ErrorRate;
            Coptimal = C;
            Roptimal = r;
      end
      
    end
end            

frogfish 发表于 2008-9-21 07:47

2. 用SVM工具包实现
clear all
% Load the data and select features for classification
load fisheriris
% data = ;
data=meas;
% Extract the Setosa class
groups = ismember(species,'versicolor');%versicolor,virginica,setosa
% Randomly select training and test sets
index = crossvalind('Kfold',groups);
cp = classperf(groups);

fr=0;
fc=0;
fcorrect=0;
correct5=0;

for r=1:1:10
    for c=1:1:100
      for i=1:5
            test = (index == i); train = ~test;
            % Use a RBF support vector machine classifier
            %         svmStruct = svmtrain(data(train,:),groups(train),'KERNEL_FUNCTION','rbf','kfunargs',5,'boxconstraint',1000,'showplot',true);
            %         classes = svmclassify(svmStruct,data(test,:),'showplot',true);
            svmStruct = svmtrain(data(train,:),groups(train),'KERNEL_FUNCTION','rbf','kfunargs',1/(r^2),'boxconstraint',c);
            classes = svmclassify(svmStruct,data(test,:));
            % See how well the classifier performed
            classperf(cp,classes,test);
            %             cp.CorrectRate
            correct5=correct5+cp.CorrectRate/5;
      end
      r
      c
      correct5
      if(fcorrect<correct5)
            fcorrect=correct5
            fr=r
            fc=c
      end
      correct5=0;
    end
end

frogfish 发表于 2008-9-21 07:47

四.实验结果
1. Quadprog实现
(1)类别:versicolor 参数:r(1-10) C(1-100)
      运行结果:
CorrectRate =0.9696 Roptimal =1   Coptimal =2
(2)类别:virginica 参数:r(1-10) C(1-100)
      运行结果:
CorrectRate =0.9430 Roptimal =1   Coptimal =2
(3)类别:setosa    参数:r(1-10) C(1-100)
      运行结果:
CorrectRate =1 Roptimal =1   Coptimal =1
2. SVM工具包实现
(1)类别:versicolor 参数:r(1-5) C(1-50)
      运行结果:
CorrectRate =1 Roptimal =2   Coptimal =22
(2)类别:virginica参数:r(1-5) C(1-50)
      运行结果:
CorrectRate =0.9867 Roptimal =10   Coptimal =44
(3)类别:setosa    参数:r(1-10) C(1-100)
      运行结果:
CorrectRate =1 Roptimal =1   Coptimal =1

navyman 发表于 2008-9-26 12:21

无法运行,indices = crossvalind('Kfold', groups);出错,请求帮助!

xiaokang 发表于 2009-11-11 21:48

crossvalind这是个什么函数啊?那个库里有呢?

天下大平 发表于 2009-12-7 20:20

怎么了

Error: The input character is not valid in MATLAB statements or expressions.

junjunyeti 发表于 2012-7-16 09:02

如果只有两种类型的训练数据,能不能分出三类,也就是除了分出这两点之外,能不能把阈值之外的算作第三类,第三类是虚拟的,三是没有特征数据

晨风无影 发表于 2012-8-2 23:59

特征数据
页: [1]
查看完整版本: 利用SVM实现一个三类分类问题(转贴)