KNN手写数字识别
本文将介绍 K-Nearest Neighbors(KNN)算法的基本原理,并通过一个实际案例展示如何使用 Python 和 OpenCV 来实现手写数字识别。从数据准备、模型训练到预测评估,全面解析 KNN 在机器学习中的应用。
K-Nearest Neighbors(KNN)算法简介
K-Nearest Neighbors(KNN)算法是一种基本的分类与回归方法。它的工作原理非常简单直观:通过测量不同特征值之间的距离来进行预测。KNN 算法不考虑数据的分布,它只是简单地根据已标记的数据集中最接近的 K 个数据点的类别,通过投票的方式来预测新数据点的类别。
算法原理
KNN 算法的核心思想是相似性原则,即相似的事物应该有相似的标签。在分类问题中,给定一个待分类的样本,KNN 算法会:
- 计算待分类样本与所有已知类别样本之间的距离(常用的距离度量包括欧氏距离、曼哈顿距离等)。
- 按照距离的远近对样本进行排序。
- 选取距离最近的 K 个样本(K 是一个正整数,通常由交叉验证来选择最佳值)。
- 根据这 K 个样本的已知类别,通过投票机制来决定待分类样本的类别。
在回归问题中,KNN 算法会根据最近的 K 个邻居样本的数值,计算待预测样本的预测值,通常采用平均值或加权平均值。
算法步骤
- 选择参数 K:K 值的选择对 KNN 算法的性能有很大影响。较小的 K 值意味着模型对噪声更敏感,而较大的 K 值则可能导致模型对数据的局部结构不够敏感。
- 距离度量:选择合适的距离度量方法来计算样本之间的距离。最常用的是欧氏距离,但在某些情况下,其他距离度量(如曼哈顿距离、余弦相似度等)可能更合适。
- 寻找最近的 K 个邻居:对于每个待分类的样本,找到训练集中与其距离最近的 K 个样本。
- 决策规则:对于分类问题,采用多数投票法来确定样本的类别;对于回归问题,计算 K 个邻居的平均值作为预测值。
优缺点
优点:
- 简单易懂,实现容易。
- 无需训练数据,对数据分布没有假设。
- 适合于多分类问题。
缺点:
- 计算成本高,尤其是在大数据集上,因为需要计算待分类样本与所有训练样本之间的距离。
- 存储成本高,需要存储全部数据集。
- 对不平衡的数据集表现不佳,可能需要进行采样来平衡数据。
- 对特征尺度敏感,需要进行特征缩放。
应用场景
KNN 算法适用于各种分类和回归问题,尤其是在数据量不是非常大的场景下。它在文本分类、图像识别、推荐系统等领域都有应用。由于其简单性和直观性,KNN 算法常被用作机器学习初学者的第一个算法。
代码实现:使用 KNN 识别手写数字
以下代码展示了如何使用 OpenCV 和 NumPy 实现手写数字识别。
import numpy as np |
代码详解
1. 导入库
import cv2 |
这两行代码导入了必要的库。cv2 是 OpenCV 的库,用于图像处理和机器学习。numpy 是一个强大的数学库,用于处理数组和矩阵。
2. 图像加载与预处理
# 使用绝对路径 |
这里定义了图像文件的路径,并使用 cv2.imread 函数读取图像。cv2.imread 会将图像加载为一个三维数组(高度、宽度、颜色通道)。
gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) |
这行代码将图像从 BGR 颜色空间转换为灰度图像。灰度图像是一个二维数组,每个像素值表示该像素的亮度。
3. 分割数字图像
cells = [np.hsplit(row, 100) for row in np.vsplit(gray, 50)] |
这行代码将灰度图像分割成多个小块(cell)。np.vsplit 将图像垂直分割成 50 行,np.hsplit 将每一行水平分割成 100 列。这样,整个图像被分割成 50×100=5000 个小块,每个小块是一个数字。
x = np.array(cells) |
将分割后的图像块列表转换为一个 NumPy 数组,方便后续操作。
4. 划分训练集和测试集
train = x[:, :50] # 前50列作为训练集 |
将 50×100 的图像块分为训练集和测试集。前 50 列(50×50=2500 个图像块)作为训练集,后 50 列(50×50=2500 个图像块)作为测试集。
5. 数据重塑与类型转换
train = train.reshape(-1, 400).astype(np.float32) |
将训练集和测试集的每个图像块重新调整为一维数组,长度为 400(因为每个图像块是 20×20 像素)。-1 表示自动计算行数,astype(np.float32) 将数据类型转换为浮点数,这是 KNN 算法的要求。
6. 创建标签
k = np.arange(10) |
定义训练集和测试集的标签。np.arange(10) 生成一个从 0 到 9 的数组,表示 10 个数字类别。np.repeat(k, 250) 将每个数字类别重复 250 次,因为每个数字类别有 250 个样本。[:, np.newaxis] 将数组从一维扩展为二维,以满足 KNN 算法的要求。
7. 训练 KNN 模型
knn = cv2.ml.KNearest_create() |
创建一个 KNN 模型,并使用训练集数据和标签进行训练。cv2.ml.ROW_SAMPLE 表示每一行是一个样本。
8. 预测与评估
ret, result, neighbours, dist = knn.findNearest(test, k=5) |
使用训练好的 KNN 模型对测试集进行预测。k=5 表示考虑最近的 5 个邻居。findNearest 返回的结果包括:
ret:返回值,通常不使用。result:预测的标签。neighbours:最近的 5 个邻居的标签。dist:最近的 5 个邻居的距离。
matches = result == test_labels |
计算模型的准确率。result == test_labels 比较预测的标签和真实的标签,生成一个布尔数组。np.count_nonzero(matches) 统计布尔数组中为 True 的个数,即正确预测的样本数。accuracy 计算准确率,公式为:正确预测的样本数 / 总样本数 × 100%。最后打印出准确率。
手写数字数据集示例
下面是我们使用的手写数字数据集示例,包含了 0-9 十个数字的各种手写形式:
常见问题解答
1. 方法链中的点号用法
问:这几句里面的点号的用法是可以一直加在后面的吗:train = train.reshape(-1,400).astype(np.float32)
答:点号(.)在这里并不是随意添加的。在 Python 中,点号用于访问对象的属性或方法。在你提到的代码行中,reshape 和 astype 是 numpy 数组对象的方法。reshape 方法用于改变数组的形状,而 astype 方法用于改变数组的数据类型。这些方法必须通过点号来调用,因为它们是数组对象的一部分。这种连续调用多个方法的方式称为"方法链",是一种常见的编程风格。
2. 标签标注的原理
问:在标签标注的时候,它咋知道每个手写对应的真实值是什么?还是说因为所给手写图片的特殊性决定的?
答:在手写数字识别任务中,通常使用的是已经标注好的数据集,比如 MNIST 数据集。这些数据集中的每个手写数字图像都有一个对应的标签,表示它代表的数字。这些标签是人为提供的,通常是数据集创建者在收集数据时就标注好的。
在你提供的代码中,标签是通过 np.repeat(k,250)[:,np.newaxis] 这行代码生成的,它假设每个数字(0-9)都有 250 个样本,并且每个数字的样本都被重复 250 次来创建标签数组。这种假设成立的前提是原始图像已经被精心排列和组织好了—前 50 列是训练集,后 50 列是测试集,并且每一行都只包含同一个数字的不同书写样本。
3. 数据重塑的原因
问:为什么要调节成一维数组?
答:将图像块调整为一维数组是为了适应机器学习模型的输入要求。大多数机器学习模型,特别是像 KNN 这样的简单模型,都期望输入数据是一维的。这样做可以将每个图像块展平成一个长向量,其中每个元素都是图像中的一个像素值。这样处理后,每个图像块都变成了一个 400 维的向量(20x20 像素的图像块),这使得模型可以更容易地处理和比较这些数据。
总结
KNN 算法是一种简单而有效的机器学习算法,特别适合初学者入门机器学习。通过本文的介绍和代码实现,你应该对 KNN 算法的原理和应用有了基本的了解。手写数字识别是一个经典的机器学习问题,通过这个问题,你可以学习到数据预处理、模型训练和评估等机器学习的基本流程。
虽然 KNN 算法简单易懂,但在处理大规模数据集时可能面临性能瓶颈。因此,在实际应用中,我们需要根据具体需求和数据特点来选择合适的算法和数据结构。




