《机器学习实战》—K-近邻算法

基本原理:通过计算新数据与给定的样本数据之间的距离,来确定相似度排名;然后取前K个最相似的样本,统计这k(一般不大于20)个样本中出现最多的分类,设为新数据的分类。

关键词:新数据,训练样本集,样本数据标签(即分类),最近邻(前k个最相似数据),最近邻标签

算法实施:

首先提取要比较的特征,确定各特征的权重,进行“归一化”,自动将数字特征转化为0-1之间——autoNorm()函数。

其次准备好测试分类器。通常选用已有数据90%作为训练样本来训练分类器,剩下10%去测试分类器。即已知测试数据的分类,再去拿分类器验证,求出错误率。该函数应该自包含(可以在任何时候使用来测试分类效果)——classTest()函数。

然后算法使用,使用欧氏距离公式,计算两个向量点的距离。该向量就是数据中各个特征值归一化之后的数值组成的一维矩阵,可以用numpy进行运算;随后将距离按照从小到大排序;再确定前k个距离最小元素坐在的主要分类,返回发生频率最高的元素标签——classify()函数。

KNN算法:

简单案例,包含全部流程

#coding: utf-8
from numpy import *
import operator
def createDataSet():
	# 用户自封闭测试的数据
	group = array([[1.0,1.1],[1.0,1.0],[0.0,0.0],[0,0.1],[5,10],[10,5]])
	labels = ['A','A','B','B','C','C']
	return group,labels
def createDataTest():
	tGroup = array([[1.1,1.0],[1.1,1.2],[0.1,0.0],[0.1,0.1],[5,9],[9.5,5]])
	tLabels = ['A','A','B','B','C','C']
	return tGroup,tLabels
def autoNorm(dataSet):		
	# 按行处理,归一化							
	minX = dataSet.min(0)
	maxX = dataSet.max(0)
	ranges = maxX - minX
	normDataSet = zeros(shape(dataSet))
	m = dataSet.shape[0]
	normDataSet = (dataSet - tile(minX, (m,1)))/tile(ranges, (m,1))
	return normDataSet
def classify(inX, dataSet, labels, k):
	#用于输入的向量inX,输入训练样本集dataset,标签响亮labels,邻居数k
	dataSetSize = dataSet.shape[0]						# shape函数获取第一维度的长度
	sqDiffMat = (tile(inX, (dataSetSize, 1)) - dataSet)**2	# 取差的平方
	distances = (sqDiffMat.sum(axis = 1)) ** 0.5			# inX与不同数据之间距离
	sortD = distances.argsort()								# 排序,返回索引值
	classCount = {}
	for i in range(k):
		voteLabel = labels[sortD[i]]
		classCount[voteLabel] = classCount.get(voteLabel,0)+1	#计数加1
	classSort = sorted(classCount.items(), \
		key = operator.itemgetter(1), reverse=True)			# 标签统计数从大到小排
	return classSort[0][0]
def classTest(dataSet,labels,tDataSet, tLabels):
	errorCount = 0.0
	m = dataSet.shape[0]
	for i in range(m):
		result = classify(tDataSet[i], dataSet,labels,6)
		if(result != tLabels[i]): errorCount += 1
	print("the total error rate is: %f"%(errorCount/float(m)))

group, labels = createDataSet()
# group = autoNorm(group)
label = classify([0,0], group, labels, 3)
print(label)
# 错误率测试
tGroup, tLabels = createDataTest()
# tGroup = autoNorm(tGroup)
classTest(group,labels,tGroup, tLabels)

发布了397 篇原创文章 · 获赞 541 · 访问量 255万+
展开阅读全文

没有更多推荐了,返回首页

©️2019 CSDN 皮肤主题: 编程工作室 设计师: CSDN官方博客

分享到微信朋友圈

×

扫一扫,手机浏览