统计学习方法读书笔记6-K近邻算法及代码实现

1147-柳同学

发表文章数:593

首页 » 算法 » 正文

1.K-近邻算法

统计学习方法读书笔记6-K近邻算法及代码实现

2.K-近邻模型(三个基本要素)

1.距离度量

统计学习方法读书笔记6-K近邻算法及代码实现
统计学习方法读书笔记6-K近邻算法及代码实现

2.K值的选择

统计学习方法读书笔记6-K近邻算法及代码实现

3.分类决策规则

统计学习方法读书笔记6-K近邻算法及代码实现

3.kd树

通过线性扫描实现k近邻算法,当训练集很大时,计算非常耗时

因此,需要考虑如何对训练数据进行快速的k近邻搜索

为了提高k近邻搜索效率,可以考虑使用特殊的结构存储训练数据,以减少计算距离次数

kd树正是这个方法

1.构造平衡kd树

统计学习方法读书笔记6-K近邻算法及代码实现
统计学习方法读书笔记6-K近邻算法及代码实现
P54-55

2.搜索kd树

统计学习方法读书笔记6-K近邻算法及代码实现
统计学习方法读书笔记6-K近邻算法及代码实现
P56-57
统计学习方法读书笔记6-K近邻算法及代码实现

4.K近邻代码实现

#!usr/bin/env python
# -*- coding:utf-8 _*-
"""
@author: liujie
@software: PyCharm
@file: KNN.py
@time: 2020/10/20 22:25
"""
# KNN没有显示的训练过程
import time
import numpy as np
from tqdm import tqdm


def loaddata(filename):
    """
    加载数据
    :param filename:文件路径
    :return: 返回数据集与标签
    """
    print('start to read file')
    # 存放数据与标签
    dataArr = []
    labelArr = []
    # 打开文件
    fr = open(filename)
    # 循环读取文件每一行
    for line in tqdm(fr.readlines()):
        # 获取当前行,并存放入列表中
        # strip:去掉每行字符串首尾指定的字符(默认空格或换行符)
        # split:按照指定的字符将字符串切割成每个字段,返回列表形式
        currentLine = line.strip().split(',')
        # 存放数据并转换成整型
        dataArr.append([int(num) for num in currentLine[1:]])
        # 存放标签并转换成整型
        labelArr.append(int(currentLine[0]))

    return dataArr, labelArr

def calDist(x1,x2):
    """
    计算欧式距离
    :param x1: 向量1
    :param x2: 向量2
    :return: 欧式距离
    """
    # 欧式距离
    return np.sqrt(np.sum(np.square(x1 - x2)))
    # 马哈顿距离
    # return np.sum(x1 - x2)

def getClosest(trainDataMat, trainLabelMat, x, topK):
    """
    预测x的标记
    多数表决
    :param trainDataMat:训练数据集
    :param trainLabelMat: 训练数据标签
    :param x: 预测样本x
    :param topK: 选择参考最邻近样本的数目
    :return: 预测的标记
    """

    # 建立一个存放向量x与每个训练集中样本距离的字典
    distDict = {}
    # 遍历训练集中所有样本点,计算与x的距离
    for i in tqdm(range(len(trainDataMat))):
        # 获取向量
        xi = trainDataMat[i]
        # 计算距离
        curDist = calDist(x,xi)
        # 将距离放入对应的字典位置
        distDict[i] = curDist

    # 对字典的value进行排序-升序排序-字典无法排序,但是可以建立有序的数据类型来表示排序后的值,该值将是一个列表-可能是一个元组列表
    # argsort:函数将数组的值从小到大排序后,并按照其相对应的索引值输出--列表
    # sorted:sorted(iterable,key,reverse),sorted一共有iterable,key,reverse这三个参数
    dist_list_topK = sorted(distDict.items(),key = lambda x:x[1],reverse=False)[:topK]
    # 转变成字典
    dist_dict_topk = dict(dist_list_topK)
    # print(dist_dict_topk)
    # dist_dict_topK中key进行循环
    labelLict = [0] * 10
    for index in dist_dict_topk:
        # 找到最近topk中的标签,并进行多数表决投票
        labelLict[int(trainLabelMat[index])] += 1

    # 找到选票箱中票数最多的票数值
    return np.argsort(np.array(labelLict))[-1]

# 定义计算正确率的函数
def model_test(trainData,trainLabel,testData,testLabel,topK):
    """
    测试正确率
    :param trainData:训练数据集
    :param trainLabel: 训练标签
    :param testData: 测试数据集
    :param testLabel: 测试标签
    :param topK: 选择多少个临近点参考
    :return: 正确率
    """
    print('start to test')
    # 将列表转化为矩阵,方便并行运算
    trainDataMat = np.mat(trainData)
    trainLabelMat = np.mat(trainLabel).T
    testDataMat = np.mat(testData)
    testLabelMat = np.mat(testLabel).T

    # 错误计数
    errorCnt = 0
    # 遍历测试集,对每个测试集样本进行测试
    # 由于计算向量与向量之间的时间耗费太大,测试集中每个样本都要计算与60000个样本的距离,所以这里人为改成了200个
    # for i in range(len(testDataMat)):
    #    print('test %d : %d'%(i,len(testDataMat)))
    for i in range(200):
        print('test %d : %d'%(i,200))
        # 获取测试向量与标签
        x = testDataMat[i]
        y = getClosest(trainDataMat,trainLabelMat,x,topK)

        # 预测标记与实际标记不符,错误计数加1
        if y != testLabelMat[i]:errorCnt += 1

    return 1 - errorCnt / 200




if __name__ == '__main__':
    start = time.time()
    # 加载数据集
    trainData, trainLabel = loaddata('data/mnist_train.csv')
    testData, testLabel = loaddata('data/mnist_test.csv')

    # 计算测试集的正确率
    accur = model_test(trainData,trainLabel,testData,testLabel,25)
    # 打印正确率
    print('accur : %d'%(accur*100),'%')

    end = time.time()
    print('time span:',end-start)
    
    
accur : 97 %
time span: 307.68515515327454

未经允许不得转载:作者:1147-柳同学, 转载或复制请以 超链接形式 并注明出处 拜师资源博客
原文地址:《统计学习方法读书笔记6-K近邻算法及代码实现》 发布于2020-10-21

分享到:
赞(0) 打赏

评论 抢沙发

评论前必须登录!

  注册



长按图片转发给朋友

觉得文章有用就打赏一下文章作者

支付宝扫一扫打赏

微信扫一扫打赏

Vieu3.3主题
专业打造轻量级个人企业风格博客主题!专注于前端开发,全站响应式布局自适应模板。

登录

忘记密码 ?

您也可以使用第三方帐号快捷登录

Q Q 登 录
微 博 登 录