1. K近邻算法
什么是K近邻算法?简单的说就是:给定所有训练数据和一组测试数据之后,在训练数据中寻找离这组测试数据最近的K组邻居,根据这K组邻居的label来做voting/average,从而预测出测试数据的label。
如上图所示,绿色圆圈是测试数据,其余为训练数据。假设K == 3
,即寻找离测试数据最近的三个邻居,如图中黑色实线圆中的三个邻居(两个红色三角形,一个蓝色正方形)。在这三个邻居中,少数服从多数,所以测试数据就被分类为红色三角形。假设K == 5
,那虚线圆中的5个邻居就是离测试数据最近的5个邻居,里面有三个蓝色正方形和2个红色三角形,所以少数服从多数,测试数据被分类为蓝色正方形。
2. 距离度量
那么如何找到最近的K个邻居呢?换句话说,如何度量样本之间的距离?
最常用的距离度量方式自然是欧式距离,类似的还有马氏距离。本质上来说,寻找最近的邻居无非就是寻找最相似的邻居,所以余弦相似度这类相似度度量方式也完全可以在此处使用。
3. Search algorithms
接下来介绍查找最近K个邻居的算法:Brute Force, K-D tree。
3.1 Brute Force
最直观最暴力的算法,直接计算测试数据和所有训练数据的距离,对距离排序取前K个,即最近的K个邻居。这个算法缺点很明显:计算量大,效率低。对于N个D维的samples,Brute Force方法的效率是$O(DN^2)$
决定了查找算法之后我们就可以着手简单的实现k近邻算法了,此处以欧式距离和Brute Force算法距离,python实现如下:
#!/usr/bin/env python
# encoding: utf-8
import numpy as np
class KnnClassifier(object):
def __init__(self, n_neighbors=1):
self.n_neighbors = n_neighbors
def fit(self, X, y):
if X.shape[0] != y.shape[0]:
raise ValueError('train data array_shape mismatch')
self._X = X
self._y = y
def _predict(self, X):
X_train = self._X
y_train = self._y
dist_l2 = np.sum((X_train - X) ** 2, axis=1)
sorted_index = dist_l2.argsort()
y_pred = np.bincount(y_train[sorted_index[0:self.n_neighbors]]).argmax()
return y_pred
def predict(self, X):
X_train = self._X
y_train = self._y
if X_train.shape[1] != X.shape[1]:
raise ValueError('test data array_shape mismatch')
y_pred = np.empty((X.shape[0], 1), dtype=y_train.dtype)
for i in range(X.shape[0]):
y_pred[i] = self._predict(X[i])
return y_pred
3.2 k-d tree
为了解决Brute Force算法的计算效率问题,人们发明了很多基于树的数据结构。这些数据结构可以帮助我们减少要计算的距离数量。核心思想如下:如果点A距离点B很远,而点B距离点C很近,那么我们可以知道点A距离点C也会很远,因此就不必再精准地计算点A和点C的距离。运用这一点,我们可以将K近邻算法的计算效率提高到$O(DNlog(N))$甚至更好。
3.2.1 构建KD树
早期的利用这一点的算法是KD树(K-dimensional tree),本质上它就是一棵二叉树,如果你了解决策树模型的话,会发现两者构建二叉树的过程很像。在决策树模型中,构建二叉树时我们会用到训练数据的特征和labels,分支时也是根据数据的labels不纯度进行分支;然而在这里构建KD树时我们只用到了训练数据的特征,并没有用到训练数据的labels,那么我们如何进行分支呢?
对于所有训练数据,计算它们在每个特征维度上的数据方差,取方差最大的那一维作为分支参考,方差最大意味着在该维度上数据最分散,在这个维度上进行分支有较好的分辨率。
确定了分支参考维之后,接下来按照分支参考维对所有数据排序,取中间的数据作为分支节点,在分支参考维上小于分支节点的划入左子树,其余的划入右子树。
然后左右子树分别递归构建KD树,直到不能分支为止。整个过程伪代码如下:
def CreateKDTree(data):
if data.empty()
return
leftData, rightData = split(data)
CreateKDTree(leftData)
CreateKDTree(rightData)
下图是一个KD树划分二维数据的实例:
3.2.2 KD树最近邻搜索
构建好KD树后,我们看看如何在KD树上查找目标点t的最近邻居。
- 进行二叉查找,取叶子节点作为当前最近邻点。如果以目标点t为圆心,目标点t与当前最近邻点的距离为半径的圆没有超过切割超平面,则查找结束,如下图。否则进入第二步。
- 回溯查找。进入第一步中与圆相交的空间查找,看是否存在比当前最近邻点更近的点,若存在则更新为最近邻点,继续按照步骤1递归式查找,如下图: