首页 > 模式识别 > 性别识别专题二——Fisher分类器

性别识别专题二——Fisher分类器

2013年5月22日 发表评论 阅读评论

好吧,我又来填坑了。。。为啥呢?因为这周我的任务好像差不多了,硬件客观条件所限,要干活最快也要到周六,想象自己给自己挖了这么多坑,补天于心不安啊!!

说到坑这个东西,大家可能有所不知,在这个BLOG里,外表看上去我挖的坑不多(但也不是没有),但是殊不知我的draft里面躺了一大堆草稿,觉得这个可以写来玩玩,那个又想写来分享什么的,而且有些想法发现像个无底洞,所以就被永久雪藏了,比如之前搞的那个统计信号处理专题,出了1之后就在也没动过了。。。

填坑,是个良心活!!

闲话就说这么多吧。。。


线性判别式分析(Linear Discriminant Analysis, LDA),也叫做Fisher线性判别,也就是数模里面常见的Fisher分类器,之前讲到的那个PCA,经常和LDA一起出现,解决了无数的问题,看成模式识别界最平民化的模范COUPLE,平民化一是因为用的多,二是简单,粗暴,易懂。。

这个分类器的思想是什么呢?其实很简单,其实我是想自己做个图的,但是懒,所以就上网找了个,下图如若侵权,那我先在此道歉了。。。我错了(低头)。。。fisherexample
比如说上图,我们有两个类,他们是二维的点,Fisher分类器是想怎么做呢?它想找一条直线(一般而言是一个比原数据维数更低维的空间,如面到点),然后把源数据投影到这个线上,使得数据在这条低维空间上具有很好的“可区分性”,像上图中画得最佳投影方向,如果投影的是上图中的最不利的投影方向的话,那么数据在线上就基本无法区分了,对吧。

Fisher分类器就是根据你输入的数据,算出这一条“线”来,识别的时候就是往线上投影,根据阈值来判断属于哪一类。

那么Fisher分类器是怎么算出那条线来的呢?下面是简要说明,详细复杂的计算证明大家自己找文献去。。。

为了方便说明,我们用二维例子来说,也就是上面的二维的点,我们可以用一个1×2的向量w来表示一条线,这大家没异议吧。

Fisher的思想就是它定义了一个类内散度矩阵和类间散度矩阵这两个东西,顾名思义,就是投影之后每个类是不是很“聚拢”,类与类之间是不是尽量的远,首先我们可以算出每个类的均值,假设有两个类,均值分别是上图中的μ1,μ2,然后还有一个所有样本的均值μ,然后把这三个均值也投影到w这条线上,假设投影后分别是μ1,μ2,μ;

然后呢,类间散度SB就是:

N11)2+N22)2

其中两个N就是每一类的数目的个数,然后下面给出一个结论,不给证明,就是类间散度还可以等价表示成下面的形式:

SB = wTSBw

其中SB是: 

12)(μ12)T

上面证明略。。。因为博客敲公式麻烦。。不是Latex。。。(摊手~)

接下来就是类内散度了,下面语言描述,注意断句,每个类的类内散度的值就等于这个类里面(每个数据投影后和类均值投影后的差的平方)的和,懂了?

然后我们要求的总的类散度就是每个类的类散度的和。

然后下面还是不加证明的给出类内散度SW的化简形式:

SW = wTSWw

其中,SW是:

∑∑(xpji)(xpji)T

两个求和第一重是对所有类求和,第二重是对每个类内部每个元素求和,xpj表示的是第p类的第j个点在w上的投影;

有了这两个定义,下面是目标,就是要让类间散度除以类内散度,这个值J达到最大!!

下面干一件比较没节操的事情,就是,这个最大值怎么求呢?我直接告诉你,至于怎么出来的,查看一下文献,很容易就明白了,为什么这么干呢?还是还是那个原因,结果形式简单明了,中间过程敲公式太复杂,如果哪一天这里支持Latex语法的话,那我就谢天谢地了。。。

让这个J达到最大的w的形式是:

(Sw)-112)

嗯,你没看错,就是那么简单,正是因为结果那么简单,所以我上面才不想说那么多废话啊。。。

投影之后的判断的阈值一般可以用(μ12)Tw/2或者(N1μ1+N2μ2)Tw/2来算。


嘛,算法描述就到这里,下面照旧,简单的实验;

也是为了简单起见,我这个实验,简单,粗暴。。。

我直接画了一条直线,然后做出两个类,如下图所示:

fisher这些点是二维的点,投影后变成一维的一个数,这些数分别如下图所分布:
fisherresult红线是阈值对应的线,可以看出,目的基本达成。

好吧,下面是代码。。。不懂得孩子再自己研究研究吧。。。


code:

data = createdata();

opp = data(data(:,3)==1,1:2);
neg = data(data(:,3)==-1,1:2);

plot(opp(:,1),opp(:,2),'r.');
hold on;
plot(neg(:,1),neg(:,2),'b.');

mean_opp = mean(opp);
mean_neg = mean(neg);

diff_opp = opp-repmat(mean_opp,size(opp,1),1);
S1 = diff_opp'*diff_opp;
diff_neg = neg-repmat(mean_neg,size(neg,1),1);
S2 = diff_neg'*diff_neg;
Sw = S1+S2;

w = (mean_opp-mean_neg)*Sw^-1;
thred = (mean_opp+mean_neg)/2*w';

%投影后的结果
opp_new = opp*w';
neg_new = neg*w';

figure;
plot(opp_new,'g.');
hold on;
plot(neg_new,'b.');
plot([1 5000],[thred thred],'r');

function data = createdata()
data = [];
N = 10000;
x = 10*rand(1,N);
y = 10*rand(1,N);
for i = 1 : N
    if(2.5+0.5*x(i)-y(i)>0)
        data = [data;x(i),y(i),1];
    else
        data = [data;x(i),y(i),-1];
    end
end

【完】

本文内容遵从CC版权协议,转载请注明出自http://www.kylen314.com

  1. 本文目前尚无任何评论.
验证码:0 + 3 = ?

友情提示:留言可以使用大部分html标签和属性;

添加代码示例:[code lang="cpp"]your code...[/code]

添加公式请用Latex代码,前后分别添加两个$$