声振论坛

 找回密码
 我要加入

QQ登录

只需一步,快速开始

查看: 6142|回复: 7

[人工智能] 利用SVM实现一个三类分类问题(转贴)

[复制链接]
发表于 2008-9-21 07:47 | 显示全部楼层 |阅读模式

马上注册,结交更多好友,享用更多功能,让你轻松玩转社区。

您需要 登录 才可以下载或查看,没有账号?我要加入

x
一.任务要求
用SVM求解一个三类分类问题,实验数据为“鸢尾属植物数据集”,核函数为径向基核函数(RBF),误差评测标准为K折交叉确认误差。

二.实验方案
1. 用quadprog函数实现C-SVC来进行分类
——quadprog是matlab中一个求解二次规划的函数,通过适当的参数设置,可以利用quadprog函数实现C-SVC
2. 用matlab自带的SVM工具包来实现分类
——matlab2006版本中集成了SVM工具包,可以通过调用工具包中的svmtrain和svmclassify函数来进行训练和分类
3. 三类问题的分类方法
——将三类问题转化为三个两类问题,分别求出相应的决策函数即可(优点:方法简单易行;缺点:容易形成死区)

三.实验程序
1. 用Quadprog实现
  1. clear all
  2. % Load the data and select features for classification
  3. load fisheriris;
  4. data = meas;
  5. %Get the size of the data
  6. N = size(data,1);
  7. % Extract the Setosa class
  8. groups_temp = ismember(species,'versicolor');%versicolor,virginica,setosa
  9. %convert the group to 1 & -1
  10. groups = 2*groups_temp - ones(N,1);

  11. indices = crossvalind('Kfold', groups);

  12. ErrorMin = 1;
  13. for r=1:1:5
  14.     for C=1:1:5
  15.         ErrorNum = 0;        
  16.         for i=1:5
  17.             %Use K-fold to get train data and test data
  18.             test = (indices == i); train = ~test;
  19.             
  20.             traindata = data(train,:);
  21.             traingroup = groups(train,:);
  22.             trainlength = length(traingroup);
  23.             
  24.             testdata = data(test,:);
  25.             testgroup = groups(test,:);
  26.             testlength = length(testgroup);
  27.             
  28.             %Get matrix H of the problem
  29.             kfun = [];
  30.             for i=1:1:trainlength
  31.                 for j=1:1:trainlength
  32.                     %rbf kernel
  33.                     kfun(i,j)=exp(-1/(r^2)*(traindata(i,:)-traindata(j,:))*(traindata(i,:)-traindata(j,:))');
  34.                 end
  35.             end

  36.             %count parameters of quadprog function
  37.             H = (traingroup*traingroup').*kfun;
  38.             xstart = zeros(trainlength,1);
  39.             f = -ones(trainlength,1);
  40.             Aeq = traingroup';
  41.             beq = 0;
  42.             lb = zeros(trainlength,1);
  43.             ub = C*ones(trainlength,1);
  44.             
  45.             [alpha,fval] = quadprog(H,f,[],[],Aeq,beq,lb,ub,xstart);
  46.             
  47.             %Get one of the non-zero part of vector alpha to count b
  48.             j = 1;
  49.             for i=1:size(alpha)
  50.                 if(alpha(i)>(1e-5))
  51.                     SvmClass_temper(j,:) = traingroup(i);
  52.                     SvmAlpha_temper(j,:) = alpha(i);
  53.                     SvmVector_temper(j,:)= traindata(i,:);
  54.                     j = j + 1;
  55.                     tag = i;
  56.                 end
  57.             end
  58.             
  59.             b=traingroup(tag)-(alpha.*traingroup)'*kfun(:,tag);
  60.             
  61.             %Use the function to test the test data
  62.             kk = [];
  63.             for i=1:testlength
  64.                 for j=1:trainlength
  65.                     kk(i,j)=exp(-1/(r^2)*(testdata(i,:)-traindata(j,:))*(testdata(i,:)-traindata(j,:))');
  66.                 end
  67.             end

  68.             %then count the function
  69.             f=(alpha.*traingroup)'*kk' + b;           
  70.             for i=1:length(f)
  71.                 if(f(i)>(1e-5))
  72.                     f(i)=1;
  73.                 else
  74.                     f(i)=-1;
  75.                 end
  76.             end         
  77.             
  78.             for i=1:length(f)
  79.                 if(testgroup(i)~=f(i))
  80.                     ErrorNum = ErrorNum + 1;
  81.                 end
  82.             end         
  83.         end
  84.         
  85.         ErrorRate = ErrorNum / N;
  86.         
  87.         if(ErrorRate<ErrorMin)
  88.             SvmClass = SvmClass_temper;
  89.             SvmAlpha = SvmAlpha_temper;
  90.             SvmVector = SvmVector_temper;
  91.             ErrorMin = ErrorRate;
  92.             CorrectRate = 1 - ErrorRate;
  93.             Coptimal = C;
  94.             Roptimal = r;
  95.         end
  96.         
  97.     end
  98. end            
复制代码

评分

1

查看全部评分

回复
分享到:

使用道具 举报

 楼主| 发表于 2008-9-21 07:47 | 显示全部楼层
2. 用SVM工具包实现
  1. clear all
  2. % Load the data and select features for classification
  3. load fisheriris
  4. % data = [meas(:,3),meas(:,4)];
  5. data=meas;
  6. % Extract the Setosa class
  7. groups = ismember(species,'versicolor');%versicolor,virginica,setosa
  8. % Randomly select training and test sets
  9. index = crossvalind('Kfold',groups);
  10. cp = classperf(groups);

  11. fr=0;
  12. fc=0;
  13. fcorrect=0;
  14. correct5=0;

  15. for r=1:1:10
  16.     for c=1:1:100
  17.         for i=1:5
  18.             test = (index == i); train = ~test;
  19.             % Use a RBF support vector machine classifier
  20.             %         svmStruct = svmtrain(data(train,:),groups(train),'KERNEL_FUNCTION','rbf','kfunargs',5,'boxconstraint',1000,'showplot',true);
  21.             %         classes = svmclassify(svmStruct,data(test,:),'showplot',true);
  22.             svmStruct = svmtrain(data(train,:),groups(train),'KERNEL_FUNCTION','rbf','kfunargs',1/(r^2),'boxconstraint',c);
  23.             classes = svmclassify(svmStruct,data(test,:));
  24.             % See how well the classifier performed
  25.             classperf(cp,classes,test);
  26.             %             cp.CorrectRate
  27.             correct5=correct5+cp.CorrectRate/5;
  28.         end
  29.         r
  30.         c
  31.         correct5
  32.         if(fcorrect<correct5)
  33.             fcorrect=correct5
  34.             fr=r
  35.             fc=c
  36.         end
  37.         correct5=0;
  38.     end
  39. end
复制代码
 楼主| 发表于 2008-9-21 07:47 | 显示全部楼层
四.实验结果
1. Quadprog实现
(1)类别:versicolor 参数:r(1-10) C(1-100)
      运行结果:
CorrectRate =0.9696 Roptimal =1   Coptimal =2
(2)类别:virginica 参数:r(1-10) C(1-100)
      运行结果:
CorrectRate =0.9430 Roptimal =1   Coptimal =2
(3)类别:setosa    参数:r(1-10) C(1-100)
      运行结果:
CorrectRate =1 Roptimal =1   Coptimal =1
2. SVM工具包实现
(1)类别:versicolor 参数:r(1-5) C(1-50)
      运行结果:
CorrectRate =1 Roptimal =2   Coptimal =22
(2)类别:virginica  参数:r(1-5) C(1-50)
      运行结果:
CorrectRate =0.9867 Roptimal =10   Coptimal =44
(3)类别:setosa    参数:r(1-10) C(1-100)
      运行结果:
CorrectRate =1 Roptimal =1   Coptimal =1
发表于 2008-9-26 12:21 | 显示全部楼层
无法运行,indices = crossvalind('Kfold', groups);出错,请求帮助!
发表于 2009-11-11 21:48 | 显示全部楼层
crossvalind  这是个什么函数啊?那个库里有呢?
发表于 2009-12-7 20:20 | 显示全部楼层

怎么了

Error: The input character is not valid in MATLAB statements or expressions.
发表于 2012-7-16 09:02 | 显示全部楼层
如果只有两种类型的训练数据,能不能分出三类,也就是除了分出这两点之外,能不能把阈值之外的算作第三类,第三类是虚拟的,三是没有特征数据
发表于 2012-8-2 23:59 | 显示全部楼层
特征数据
您需要登录后才可以回帖 登录 | 我要加入

本版积分规则

QQ|小黑屋|Archiver|手机版|联系我们|声振论坛

GMT+8, 2024-11-17 09:28 , Processed in 0.060847 second(s), 19 queries , Gzip On.

Powered by Discuz! X3.4

Copyright © 2001-2021, Tencent Cloud.

快速回复 返回顶部 返回列表