博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
分类模型的评价方法
阅读量:5942 次
发布时间:2019-06-19

本文共 5357 字,大约阅读时间需要 17 分钟。

机器学习中对于分类模型常用混淆矩阵来进行效果评价,混淆矩阵中存在多个评价指标,这些评价指标可以从不同角度来评价分类结果的优劣,以下内容通过简单的理论概述和案例展示来详细解释分类模型中混淆矩阵的评价指标及其用途。

1、混淆矩阵的概念

2、衍生评价指标
3、ROC曲线、AUC指标
4、R&Python中混淆矩阵函数

1、混淆矩阵的基本概念

对于分类模型而言(这里仅以最简单的二分类为例,假设只有0和1两类),最终的判别结果无非就四种情况:

实际为0被正确预测为0,实际为0被错误预测为1,实际为1被错误误测为0,实际为1被正确预测为1。

ca08546e502214c3e58ab6453229296e331cc19d

以上四类判别结果展示在混淆矩阵上是一个两行两列的交叉矩阵,行分别代表实际的正例和负例,列分别代表预测的正例和负例。

那么在以上矩阵中:四个象限分别代表四种判别结果:

左上角被称为真阳性(True Positive,TP):样本实际为正(这里的正负仅仅是相对意义上我们想要研究的类别)例,且模型预测结果为正例;

右上角被称为假阴性(False Negative,FN):样本实际为正例,但模型预测为负例;
左下角被称为假阳性(False Positive,FP):样本实际类别为负例,但模型预测为正例;
右下角被称为真阴性(True Negative,TN):样本实际类别为负例,且模型预测为负例。

混淆矩阵的四个象限有明显的规律,左上角至右下角的对角线上是预测正确(以T开头),另一条对角线则预测错误(以F开头),左侧上下象限是预测为真的类别(以P结尾),右侧上下象限为预测错误的类别(以N结尾)。

这样真个混淆矩阵看起来就清洗多了,围绕着混淆矩阵有几个比较重要的指标需要掌握。

2、评价指标:

2.1 分类准确率(即所有分类中被正确分类的比例,也称识别率)

(TP + TN)/(TP + TN + FN + FN)

2.2 召回率-Recall(也称灵敏率、真正例识别率)

召回率的含义是指:正确识别的正例个数在实际为正例的样本数中的占比

Recall = TP/(TP + FN)

2.3 精确率

精确率的含义是指:预测为真的正样本占所有预测为正样本的比例。

Precision = TP/(TP + FP)

2.4 F度量(F1分数或者F分数)

F度量是是基于以上度量(精确率和召回率)衍生的计算指标,具体计算公式如下:

F度量 = 2PrecisionRecall/(Precision + Recall)

3、ROC曲线、AUC指标

ROC的全名叫做Receiver Operating Characteristic,主要通过平面坐标系上的曲线来衡量分类模型结果好坏——ROC curve。

eb8ac461706937a0389030c5c9e2619254cf95f7

横坐标是false positive rate(FPR),
纵坐标是true positive rate(TPR)。

以上纵坐标的TPR即是上述的指标召回率,FPR则指代负样本中的错判率(假警报率),FPR = FP/(FP + TN) 。

典型的ROC曲线是一个位于坐标点(0,0)和(1,1)对角线上方的曲线,因为对角线代表着随机分类器的分类效果。

ROC曲线只能通过图形来进行视觉判别,取法具体量化分类器的性能,于是AUC便出现了,它用来表示ROC曲线下的三角形面积大小,通常,AUC的值介于0.5到1.0之间,较大的AUC代表了较好的performance。

4、R&Python中的混淆矩阵及指标计算

4.1 R语言中的混淆矩阵

这里使用iris数据集来实现简单的knn分类,并使用R中的混淆矩阵来对其进行性能解读。

 
library("magrittr")
library("dplyr")
library("class")
library("caret")
library("scales")
library("gmodels")

为了方便演示二分类的混淆矩阵结果,这里我删掉一类,并将字符型的类别进行数字编码。

 
#处理分类编码
data(iris)
iris$Species <- as.character(iris$Species)
iris <- iris %>% filter(Species != "setosa")
iris$Species <- factor(iris$Species)
#特征标准化
iris_data <- iris
iris_data[,1:4] = apply(iris_data[,1:4],2,rescale,to = c(0,1))
#划分数据集
split1 <- createDataPartition(y=iris_data$Species,p=0.7,list = FALSE)
train_data <- iris_data[split1,1:4]
train_label<- iris_data[split1,5]
test_data <- iris_data[-split1,1:4]
test_label <- iris_data[-split1,5]
#训练模型并输出预测值
test_pre_labels <- knn(train_data,test_data,train_label,k =5,prob=TRUE)
#混淆矩阵输出:
confusionMatrix(test_label,test_pre_labels,dnn = c("Prediction","Actutal"))
table(test_label,test_pre_labels,dnn = c("Actutal","Prediction"))

caret包中的confusionMatrix函数可以非常快速的输出分类器分类结果的混淆矩阵。混淆矩阵中除了输出判别 矩阵之外,还给出了常用的判别指标。

 
TP = 12
FN = 0
FP = 3
TN = 15
Accuracy = (TN + TP)/(TN+TP+FN+FP)
(12+15)/(12+3+0+15) = 0.9
Recall = TP/(TP + FN) #对应混洗矩阵输出中的Sensitivity指标,也称灵敏性
12/(12+0) = 1
Precision = TP/(TP + FP) #对应混洗矩阵输出中的Pos Pred Value
12/(12+3) = 0.8
F1 = 2*Recall*Precision/(Recall + Precisio)
2*1*0.8/(1+0.8) = 0.8888889
library("pROC")
测数据
date_roc <- roc(test_label,ordered(test_pre_labels))
plot(date_roc, print.auc = TRUE, auc.polygon = TRUE, legacy.axes = TRUE,
grid = c(0.1, 0.2), grid.col = c("green", "red"), max.auc.polygon = TRUE,
auc.polygon.col = "skyblue", print.thres = TRUE, xlab = "特异度", ylab = "灵敏度",
main = "ROC曲线")

可以从ROC曲线图表输出上看到以上KNN分类结果的AUC值为0.9

35e3bb4f3049999c56049e018862889fc1d519ba

5.2 Python中的混淆矩阵与衍生指标计算

 
from sklearn.preprocessing import Imputer,LabelEncoder,OneHotEncoder
from sklearn import preprocessing
from sklearn.model_selection import train_test_split
from sklearn.datasets import load_iris
from sklearn import neighbors
from sklearn import metrics
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
#导入数据
iris = load_iris()
data = iris['data']
iris_data = pd.DataFrame(
data = data,
columns = ['sepal_length','sepal_width','petal_length','petal_width']
)
iris_data["Species"] = iris[ 'target']
iris_data = iris_data.loc[iris_data['Species'] != 0,:]
#数据集分割
x,y = iris_data.iloc[:,0:-1],iris_data.iloc[:,-1]
train_data,test_data,train_target,test_target = train_test_split(x,y,test_size = 0.3,stratify = y)
#特征标准化
min_max_scaler = preprocessing.MinMaxScaler()
#实例化0-1标准化方法
X_train = min_max_scaler.fit_transform(train_data.values)
X_test = min_max_scaler.transform(test_data.values)
#模型拟合
model_KNN = neighbors.KNeighborsClassifier()
model_KNN.fit(X_train,train_target)
#预测结果
Pre_label = model_KNN.predict(X_test)
#指标计算
1、混淆矩阵输出
metrics.confusion_matrix(test_target,Pre_label)
TP = 14
FN = 1
FP = 2
TN = 13
2、分类准确率计算
metrics.accuracy_score(test_target,Pre_label)
Accuracy = (TP + TN)/(TP + TN + FN + FP)
(14+13)/(14+1+2+13) = 0.9
3、召回率(Recall、或称灵敏性-Sensitivity)
metrics.recall_score(test_target,Pre_label)
Recall = TP/(TP + FN)
14/(14+1) = 0.9333333333333333
4、精确度(Precision)
metrics.precision_score(test_target,Pre_label)
Precision = TP/(TP + FP)
14/(14 + 2) = 0.875
5、F1度量
metrics.f1_score(test_target,Pre_label)
(2*Precision*Recall) / (Precision+Recall) = 0.9032258064516129
6、ROC曲线与AUC值
fpr,tpr,thresholds = metrics.roc_curve(np.array(test_target),Pre_label,pos_label=2)
plt.plot(fpr, tpr)
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.0])
plt.title('ROC curve for diabetes classifier')
plt.xlabel('False Positive Rate (1 - Specificity)')
plt.ylabel('True Positive Rate (Sensitivity)')
plt.grid(True)
metrics.roc_auc_score(test_target.values -1,Pre_label)
0.9
99d26ee6f5d3770fced15ba7b8d5063ac6d30a46

AUC指标用来评估分类器性能,可以兼顾样本中类别不平衡的情况,这一点上要比分类准确率更加具有参考价值;
整体而言,混淆矩阵给我们呈现了一个清晰可见的分类模型效果评估工具,而基于混淆矩阵的评估指标可以从不同侧面来评价分类器性性能,至于在实际操作中使用什么样的评估指标来进行评价,还要视具体的分析目标而定。

比如在文档检索方面,如果想要尽可能的提高检索到的文档中实际有价值的文档,就应该着手提高精确度,否则会面临大量冗余信息;在右键拦截领域,为了防止误伤重要右键,则需要适当提高召回率(查全率),否则会导致重要信息被遗漏。

原文发布时间为:2018-11-11

本文作者:杜雨

本文来自云栖社区合作伙伴“”,了解相关信息可以关注“”。

转载地址:http://wczxx.baihongyu.com/

你可能感兴趣的文章
substring
查看>>
Java抽象类和接口的区别(好长时间没看这种文章了)
查看>>
markdown语法
查看>>
oracle11g dataguard 完全手册
查看>>
关系模式数据库设计范式深入浅出
查看>>
打油诗 现代教育经济学
查看>>
隐马尔科夫模型(Hidden Markov Models) 系列之二
查看>>
OpenXml操作Word的一些操作总结.无word组件生成word.
查看>>
WPF模板
查看>>
java.lang.ClassCastException: sun.proxy.$Proxy11 cannot be cast to分析
查看>>
加载ConversationListActivity以及延迟的使用
查看>>
Extjs4.2 Grid搜索Ext.ux.grid.feature.Searching的使用
查看>>
GTK、KDE、Gnome、XWindows 图形界面
查看>>
hdu1231-最大连续子序列
查看>>
TMG阵列部署选择
查看>>
Repeater 控件 当数据源没有数据的时候显示 暂无数据 的两种方式
查看>>
大型网站的架构设计图分享-转
查看>>
Lambda应用设计模式
查看>>
const成员函数
查看>>
9.15游戏化体验的原则初探
查看>>