> 技术文档 > 机器学习基础-k 近邻算法(从辨别水果开始)

机器学习基础-k 近邻算法(从辨别水果开始)

一、生活中的 \"分类难题\" 与 k 近邻的灵感

你有没有这样的经历:在超市看到一种从没见过的水果,表皮黄黄的,拳头大小,形状圆滚滚。正当你犹豫要不要买时,突然想起外婆家的橘子好像就是这个样子 —— 黄色、圆形、大小和拳头差不多。于是你推断:\"这应该是橘子吧!\"

其实,这个看似平常的判断过程,竟然藏着机器学习中最经典的分类算法 ——k 近邻(k-Nearest Neighbors,简称 kNN)的核心思想!

1.1 现实中的解法拆解

当我们判断未知水果时,大脑会自动完成三个步骤:

  1. 收集特征:观察颜色(黄色)、形状(圆形)、大小(拳头大)
  1. 匹配经验:调动记忆中 \"橘子\" 的特征库(黄色、圆形、拳头大)
  1. 做出判断:因为新水果的特征和记忆中的橘子最像,所以归类为橘子

这和 k 近邻算法的工作流程惊人地相似!唯一的区别是:计算机需要我们把这些 \"看得到的特征\" 变成 \"算得出的数据\"。

1.2 k 近邻算法的有趣灵魂

k 近邻算法的有趣之处在于它的 \"懒惰\" 和 \"实在\":

  • 懒惰:它不像其他算法那样先总结规律(比如 \"黄色圆形水果都是橘子\"),而是等到需要判断时才去比对已知数据
  • 实在:它的判断逻辑简单粗暴 ——\"少数服从多数\",看新样本周围最像的 k 个样本里哪种类型占多数

就像你纠结新水果是橘子还是苹果时,会找 5 个见过这两种水果的人投票,哪种意见多就信哪种。

二、从生活到代码:k 近邻算法的实现之路

我们用一个具体案例来实现:根据 \"颜色深度\"(0-10,数值越大越黄)和 \"大小\"(0-10,数值越大越大)两个特征,判断水果是橘子(标签 1)还是苹果(标签 0)。

2.1 准备数据:把生活观察变成数字

首先,我们需要把已知的水果数据整理成计算机能理解的格式:

# 导入必要的库

import numpy as np # 用于数值计算

import matplotlib.pyplot as plt # 用于画图

# 已知水果数据:[颜色深度, 大小],标签:0=苹果,1=橘子

# 想象这些数据来自我们之前见过的水果:

# 苹果通常偏红(颜色深度小),大小不一;橘子偏黄(颜色深度大)

known_fruits = np.array([

[2, 3], # 苹果:颜色偏红(2),小个(3)

[3, 4], # 苹果:颜色较红(3),中个(4)

[1, 5], # 苹果:颜色很红(1),大个(5)

[7, 6], # 橘子:颜色较黄(7),中个(6)

[8, 5], # 橘子:颜色很黄(8),中个(5)

[9, 4] # 橘子:颜色极黄(9),小个(4)

])

# 对应的标签:0代表苹果,1代表橘子

labels = np.array([0, 0, 0, 1, 1, 1])

# 未知水果:颜色深度6,大小5(就是我们在超市看到的那个)

unknown_fruit = np.array([6, 5])

2.2 数据可视化:让计算机 \"看见\" 差异

我们用散点图把数据画出来,直观感受苹果和橘子的特征差异:

# 绘制已知水果

plt.scatter(known_fruits[labels==0, 0], known_fruits[labels==0, 1],

color=\'red\', marker=\'o\', label=\'苹果\') # 苹果标为红色圆点

plt.scatter(known_fruits[labels==1, 0], known_fruits[labels==1, 1],

color=\'orange\', marker=\'o\', label=\'橘子\') # 橘子标为橙色圆点

# 绘制未知水果(用五角星标记)

plt.scatter(unknown_fruit[0], unknown_fruit[1],

color=\'purple\', marker=\'*\', s=200, label=\'未知水果\') # 紫色五角星,放大显示

# 加上坐标轴标签和标题

plt.xlabel(\'颜色深度(0-10,数值越大越黄)\')

plt.ylabel(\'大小(0-10,数值越大越大)\')

plt.title(\'水果特征分布图\')

plt.legend() # 显示图例

plt.show() # 展示图像

运行这段代码,你会看到:红色圆点(苹果)集中在左侧(颜色偏红),橙色圆点(橘子)集中在右侧(颜色偏黄),而紫色五角星(未知水果)刚好在橘子群附近 —— 这就是我们肉眼判断的依据!

三、k 近邻算法的核心步骤:用数学实现 \"投票选举\"

计算机怎么判断未知水果的类别呢?它会执行四个关键步骤,我们一步步用代码实现:

3.1 第一步:计算距离(谁离我最近?)

生活中我们靠 \"感觉\" 判断相似,计算机则靠 \"距离\" 计算。最常用的是欧氏距离(就像直尺测量两点距离):

\\(distance = \\sqrt{(x_1-x_2)^2 + (y_1-y_2)^2}\\)

用代码实现这个计算:

def calculate_distance(known_point, unknown_point):

\"\"\"

计算两个点之间的欧氏距离

参数:

known_point:已知点的特征(如[2,3])

unknown_point:未知点的特征(如[6,5])

返回:

两点之间的距离

\"\"\"

# 计算每个特征的差值平方,再求和,最后开平方

squared_diff = (known_point[0] - unknown_point[0])**2 + (known_point[1] - unknown_point[1])** 2

distance = np.sqrt(squared_diff)

return distance

# 计算未知水果与每个已知水果的距离

distances = []

for fruit in known_fruits:

dist = calculate_distance(fruit, unknown_fruit)

distances.append(dist)

# 打印计算过程,方便理解

print(f\"已知水果特征{fruit}与未知水果的距离:{dist:.2f}\")

运行后会得到类似这样的结果:

已知水果特征[2 3]与未知水果的距离:4.47

已知水果特征[3 4]与未知水果的距离:3.16

已知水果特征[1 5]与未知水果的距离:5.10

已知水果特征[7 6]与未知水果的距离:1.41 # 这个最近!

已知水果特征[8 5]与未知水果的距离:2.00

已知水果特征[9 4]与未知水果的距离:3.61

3.2 第二步:找邻居(选 k 个最像的)

k 近邻算法中的 \"k\" 就是要选的邻居数量。比如 k=3,就是找距离最近的 3 个已知水果:


# 把距离和对应的标签组合起来,方便排序

distance_with_label = list(zip(distances, labels))

# 按距离从小到大排序

sorted_distance = sorted(distance_with_label, key=lambda x: x[0])

# 选择k=3个最近的邻居

k = 3

nearest_neighbors = sorted_distance[:k]

print(f\"\\n距离最近的{k}个邻居是:\")

for dist, label in nearest_neighbors:

fruit_type = \"橘子\" if label == 1 else \"苹果\"

print(f\"距离{dist:.2f},类别:{fruit_type}\")

此时会输出:

距离最近的3个邻居是:

距离1.41,类别:橘子

距离2.00,类别:橘子

距离3.16,类别:苹果

3.3 第三步:投票表决(少数服从多数)

看看这 3 个邻居里哪种水果占多数:

# 提取邻居的标签

neighbor_labels = [label for (dist, label) in nearest_neighbors]

# 统计每个标签出现的次数

label_counts = np.bincount(neighbor_labels)

# 找到出现次数最多的标签

predicted_label = np.argmax(label_counts)

# 输出结果

if predicted_label == 1:

print(\"\\n根据k近邻算法判断,这个未知水果是:橘子!\")

else:

print(\"\\n根据k近邻算法判断,这个未知水果是:苹果!\")

最终结果会显示 \"橘子\",和我们的直觉判断完全一致!

四、完整代码:可直接运行的 k 近邻分类器

把上面的步骤整合起来,再加上一些优化,就得到了一个完整的 k 近邻分类器:

import numpy as np

import matplotlib.pyplot as plt

class SimpleKNN:

\"\"\"简单的k近邻分类器\"\"\"

def __init__(self, k=3):

\"\"\"

初始化分类器

参数:

k:要选择的邻居数量,默认3个

\"\"\"

self.k = k

self.known_data = None # 用于存储已知数据

self.known_labels = None # 用于存储已知标签

def fit(self, X, y):

\"\"\"

训练模型(其实就是记住已知数据)

参数:

X:已知样本的特征数据,形状为[样本数, 特征数]

y:已知样本的标签,形状为[样本数]

\"\"\"

self.known_data = X

self.known_labels = y

print(f\"模型训练完成,记住了{len(X)}个样本\")

def predict(self, X):

\"\"\"

预测新样本的类别

参数:

X:新样本的特征数据,形状为[特征数]

返回:

预测的标签

\"\"\"

# 计算与所有已知样本的距离

distances = []

for data in self.known_data:

# 计算欧氏距离

dist = np.sqrt(np.sum((data - X) **2))

distances.append(dist)

# 把距离和标签绑定,按距离排序

distance_with_label = list(zip(distances, self.known_labels))

sorted_distance = sorted(distance_with_label, key=lambda x: x[0])

# 取前k个邻居的标签

nearest_labels = [label for (dist, label) in sorted_distance[:self.k]]

# 少数服从多数

return np.argmax(np.bincount(nearest_labels))

# ----------------------

# 用水果数据测试我们的分类器

# ----------------------

if __name__ == \"__main__\":

# 已知水果特征:[颜色深度, 大小]

fruits = np.array([

[2, 3], [3, 4], [1, 5], # 苹果(标签0)

[7, 6], [8, 5], [9, 4] # 橘子(标签1)

])

labels = np.array([0, 0, 0, 1, 1, 1])

# 创建分类器,选择5个邻居(试试把k改成1或5,看结果会不会变)

knn = SimpleKNN(k=5)

# 训练模型(其实就是记住数据)

knn.fit(fruits, labels)

# 要预测的未知水果:颜色深度6,大小5

unknown_fruit = np.array([6, 5])

prediction = knn.predict(unknown_fruit)

# 输出结果

fruit_names = {0: \"苹果\", 1: \"橘子\"}

print(f\"\\n未知水果的特征:颜色深度{unknown_fruit[0]},大小{unknown_fruit[1]}\")

print(f\"预测结果:这是一个{fruit_names[prediction]}!\")

# 画图展示

plt.scatter(fruits[labels==0, 0], fruits[labels==0, 1],

color=\'red\', marker=\'o\', label=\'苹果\')

plt.scatter(fruits[labels==1, 0], fruits[labels==1, 1],

color=\'orange\', marker=\'o\', label=\'橘子\')

plt.scatter(unknown_fruit[0], unknown_fruit[1],

color=\'purple\', marker=\'*\', s=200, label=\'未知水果\')

plt.xlabel(\'颜色深度(0-10,越大越黄)\')

plt.ylabel(\'大小(0-10,越大越大)\')

plt.title(f\'k={knn.k}的k近邻分类结果\')

plt.legend()

plt.show()

五、k 近邻算法的关键知识点

5.1 如何选择最佳的 k 值?

k 值是 k 近邻算法中最重要的参数:

  • k 太小:容易被噪声干扰(比如刚好有个奇怪的苹果长得像橘子)
  • k 太大:会把不相关的样本也算进来(比如远在天边的苹果也参与投票)

一个简单的方法是:从 k=3 开始尝试,逐渐增大,看哪个 k 值的预测效果最好。

5.2 特征需要 \"标准化\"

生活中如果特征的单位不一样(比如一个特征是厘米,一个是千克),会影响距离计算。解决办法是标准化:

# 标准化特征:让每个特征的平均值为0,标准差为1

def standardize(X):

return (X - np.mean(X, axis=0)) / np.std(X, axis=0)

5.3 k 近邻的优缺点

优点

  • 简单易懂,几乎不用数学基础就能理解
  • 不需要提前训练模型,拿到新数据可以直接用
  • 可以处理多种类型的数据

缺点

  • 数据量大的时候,计算距离会很慢
  • 对特征的数量敏感(特征太多时会 \"迷路\")

六、动手实践:用 scikit-learn 实现更专业的 k 近邻

真实项目中,我们会用成熟的库来实现 k 近邻。试试用 scikit-learn(Python 最流行的机器学习库)重写上面的水果分类:

# 安装scikit-learn(如果没安装的话)

# !pip install scikit-learn

from sklearn.neighbors import KNeighborsClassifier

import numpy as np

# 数据准备(和之前一样)

fruits = np.array([

[2, 3], [3, 4], [1, 5], # 苹果

[7, 6], [8, 5], [9, 4] # 橘子

])

labels = np.array([0, 0, 0, 1, 1, 1])

# 创建k近邻分类器,k=3

knn = KNeighborsClassifier(n_neighbors=3)

# 训练模型

knn.fit(fruits, labels)

# 预测未知水果

unknown_fruit = np.array([[6, 5]]) # 注意这里要写成二维数组

prediction = knn.predict(unknown_fruit)

print(\"scikit-learn预测结果:\", \"橘子\" if prediction[0]==1 else \"苹果\") # 输出\"橘子\"

是不是更简单了?这就是专业库的力量!

七、总结:一篇博客掌握 k 近邻

通过辨别水果的例子,我们学会了:

  1. k 近邻算法的核心思想:\"看邻居投票\"
  1. 实现步骤:计算距离→找邻居→投票表决
  1. 关键参数 k 的选择方法
  1. 如何用代码实现(从手写简单版本到专业库)

k 近邻就像机器学习世界的 \"Hello World\",它简单却蕴含了机器学习的基本思想 ——从数据中找规律。下一次当你在超市辨别水果时,不妨想想:\"这个过程如果写成代码,应该怎么实现呢?\"

现在就动手修改代码里的参数(比如 k 值、水果特征),看看会得到什么有趣的结果吧!

祝你的机器学习之旅,从这个甜甜的 \"橘子分类器\" 开始,越来越精彩!

  还想看更多,来啦!!!

1,大数据比赛篇全国职业院校技能大赛-大数据比赛心得体会_全国职业职业技能比赛 大数据-CSDN博客

2,求职简历篇(超实用)大学生简历写作指南:让你的简历脱颖而出-CSDN博客

3,AIGC心得篇aigc时代,普通人需要知道的-CSDN博客

4,数据分析思维篇学习数据分析思维的共鸣-CSDN博客

5,中年危机篇“中年危机”如何转变为“中年机遇”-CSDN博客