2021腾讯数字生态大会直播预约通道开启!技术内容大爆发,开发者必看! 了解详情
写点什么

工程师要理解机器学习算法的复杂性

2019 年 12 月 31 日

工程师要理解机器学习算法的复杂性

本文阐述了理解机器学习算法内部工作原理的重要性,以及它在实现和评估方面的不同之处。


机器学习工程师经常会发现,他们总是需要为手头的问题选择正确的算法。通常情况下是这样做的:首先,他们需要了解他们提供解决方案的问题的结构。然后,他们研究手头的数据集。经过初步观察后得到关键结论,最终,他们为任务选择正确的算法。


确定适用于手头数据集的最佳算法似乎是一项普通的工作。通常在这种情况下,工程师们会试图走捷径。如果任务有 0-1 标签,那么只需应用逻辑回归就可以了,对吗?不对!我们应该意识到这些捷径,并时刻提醒自己,尽管某些算法可以很好地解决特定问题,但是,为解决问题选择最佳算法时,却没有什么诀窍。然而,工程师们应该始终讨论并考虑算法的复杂性和运行时分析。


算法的运行时分析不仅对理解算法的内部工作至关重要,而且还可以产生更为成功的实现。


在本文中,我将描述一种情况,在这一情况下,忽略算法的运行时分析,K 表示聚类,这些都是会让工程师浪费大量的时间和精力的东西。


什么是 K 均值聚类

K 均值聚类(K-means clustering)是目前最流行、最容易实现的无监督机器学习算法之一。它也是一种易于理解的机器学习算法。


通常情况下,无监督机器学习算法仅使用特征向量对输入数据集中做出推理(Inferences)。因此,这些算法适用于没有标签数据的数据集。当人们想从大量结构化和非结构化数据中提取价值或洞见时,它们也非常有用。K 均值聚类是其中一种探索性数据分析技术,其目标是提取数据点的子群(subgroup),以使同一聚类中的数据点在定义的特征方面非常相似。


K 均值聚类的工作原理

K 均值聚类算法是从随机选择的第一组数据点开始,这些数据点被用作质心(Centroid)的初始种子。然后,该算法执行迭代计算,将其余的数据点分配给最近邻的聚类。当根据定义的距离函数执行这些计算时,质心的位置将会更新。在以下任一情况下,它都会停止对聚类中心的优化:


  • 质心的位置是稳定的,即,它们的值的变化不超过预定义的阈值。

  • 该算法超过了最大迭代次数。


因此,该算法的复杂度为:


O(n * K * I * d)n : number of pointsK : number of clustersI : number of iterationsd : number of attributes
复制代码


K 均值算法示例问题

我将分享一段 K 均值聚类算法任务的代码段。我唯一目的在于为读者演示一个示例,在这个实例中,如果不能理解运行时的复杂性,那么将会导致对算法的评估很差劲。需要说明的是,我所采取的步骤并没有针对算法进行优化,也就是说,为了取得更好的结果,你可以对数据进行预处理并获得更好的聚类。所涉及的步骤概述如下:


  1. 导入库,并读取数据集。在这个示例中,我导入相关的库并读书数据集,这些数据集已经下载到本地文件夹中。

  2. 预处理。在这一步中,我丢弃了字符串类型的列,而只关注数值特性。由于 K 均值聚类算法计算数据点之间的距离,所以它适用于数值列。

  3. 应用主成分分析法进行降维。在应用 K 均值聚类算法之前,最好先对数据集进行降维,因为在高维空间中,距离度量的效果并不是很好。

  4. 计算轮廓分数。K 均值聚类算法并不能直接应用。它涉及到寻找聚类最佳数量的问题。轮廓分数(Silhouette score)是可以用来确定聚类最佳数量的技术之一。如果不能理解轮廓分数分析所涉及的计算的复杂性,将会得到较差的实现效果。

  5. 其他解决方案。在本文中,我列出了一些可供选择的解决方案,以找到聚类的最佳数量。在运行时复杂性方面,它们与轮廓分数相比,更具优势。


你可以重现这个问题来亲自尝试。数据集的地址为:https://www.kaggle.com/sobhanmoosavi/us-accidents。


第一步:导入库并读取数据库

import pandas as pdimport numpy as npimport matplotlib.pyplot as pltfrom matplotlib.pyplot import figurefrom sklearn.cluster import KMeansfrom sklearn import metricsdf = pd.read_csv(file)
复制代码


第二步:预处理

#Remove columns that have almost all of its value as Nonedf.drop(['End_Lat','End_Lng','Number','Wind_Chill(F)','Precipitation(in)'],inplace = True,axis=1)#Change type of boolean columns to integerfor column in df.columns.values:    if df[column].dtype == bool:        df[column] = df[column].astype(int)def handleMissingData(data):   for column in data.columns.values:    if column not in ['ID'] and data[column].isna().sum():      if data[column].dtype == int or data[column].dtype == float:        data[column].fillna(data[column].mean(),inplace=True)      else:        data[column].fillna(data[column].mode()[0],inplace=True)handleMissingData(df)df = df[['Amenity','Crossing','Junction','No_Exit','Railway','Station','Stop','Traffic_Signal']]
复制代码


handleMissingData(df)df = df[['Amenity','Crossing','Junction','No_Exit','Railway','Station','Stop','Traffic_Signal']]
复制代码


我们只根据与路径有关的特征对数据点进行聚类,以便进行说明。


第三步:应用主成分分析进行降维

import matplotlib.pyplot as pltfrom matplotlib.pyplot import figurefrom sklearn.decomposition import PCAdata_points = df[df.columns.values].valuesnumberOfPCAComponent = 3pcaComponents = []for i in range(numberOfPCAComponent):    pcaComponents.append("PCA"+"_"+str(i+1))component_pos = np.arange(len(pcaComponents))sklearn_pca = PCA(n_components=numberOfPCAComponent)sklearn_pca1 = PCA()sklearn_pca1.fit(data_points)data_points = sklearn_pca.fit_transform(data_points)explained_variance = sklearn_pca.explained_variance_ratio_  print(explained_variance,"explained_variance")plt.bar(component_pos, explained_variance, align='center', alpha=0.5)plt.xticks(component_pos, pcaComponents)plt.ylabel('Explained Variance')plt.title('Principle Components')plt.show()plt.savefig('explained_variance.png')plt.plot(range(0,8), sklearn_pca1.explained_variance_, 'bx-')plt.xlabel('pca component')plt.ylabel('explained variances')plt.title('Elbow Method For Optimal Number of PCA Component')plt.show()plt.savefig('elbow_method_optimal_number_pca.png')pickle.dump(sklearn_pca,open("pca_stage_semi-supervised.pkl","wb"))
复制代码




看起来 3 是最佳的。


第四步:计算轮廓得分

确定最佳聚类数目有许多指标和方法。但我会集中讨论其中的几个。轮廓得分就是这些指标度量之一。它使用每个实例的平均聚类内距离和平均最近邻聚类距离来计算的。它计算每个样本和相应聚类中其余样本之间的距离。因此,它的运行时复杂度为 O(n²)。如果无法执行运行时分析的话,你可能需要等待数小时(如果不是数天)才能完成对大型数据集的分析。由于当前数据集有数百万行,解决方法可以是使用更简单的指标度量,如惯性或对数据集应用随机采样。我将阐述这两种方法。


备选解决方案

  • 肘部法则


该方法使用惯性或聚类内平方和作为输入。它描述了惯性值随聚类数量的增加而减小的情况。“肘部”(Elbow,曲线上的拐点)就是一个很好的指示点,在该点上惯性值的减小并不会发生明显的变化。使用这种技术的优点是,聚类内平方和在计算上不像轮廓得分那样昂贵,并且已经作为度量包含在算法中。


%%timeit from sklearn.cluster import KMeansfrom sklearn import metricsSum_of_squared_distances = []K = range(2,6)for k in K:    km = KMeans(n_clusters=k,random_state=5)    km = km.fit(data_points)    Sum_of_squared_distances.append(km.inertia_)
复制代码



上面这段代码的挂钟时间是:


27.9 s ± 247 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
复制代码


  • 随机降采样


降采样允许你可以处理更小的数据集。这样做的好处是,算法完成所需的时间大为减少。如此一来,分析人员就能够更快地进行工作。缺点是降采样如果随机进行的话,可能无法代表原始数据集。因此,任何涉及降采样数据集的分析都可能导致不准确的结果。但是,你始终可以采取预防措施来确保降采样的数据集代表原始数据集。


%%timeitfrom sklearn.cluster import KMeansfrom sklearn import metrics#apply k means clustering and measure silhouette score kmeans = KMeans(n_clusters=4,random_state=5).fit(data_points)metrics.silhouette_score(data_points,kmeans.labels_,metric='euclidean',sample_size=100000)
复制代码


上面代码段的挂钟时间为:


3min 25s ± 640 ms per loop (mean ± std. dev. of 2 runs, 1 loop each)
复制代码


结语

在本文中,我试图强调理解机器学习算法复杂性的重要性。算法的运行时分析不仅对于特定任务中的算法选择至关重要,而且对算法的成功实现也很重要。这也是大多数雇主在数据科学领域中寻找的关键技能之一。因此,进行运行时分析并理解算法的复杂性,一直就是很好的实践。


作者介绍:


Baran Köseoğlu,软件开发人员,数据科学家。对机器学习、人工智能和数据科学感兴趣。著有《Towards Data Science》一书。


原文链接:


https://towardsdatascience.com/importance-of-understanding-the-complexity-of-a-machine-learning-algorithm-9d0532685982


2019 年 12 月 31 日 21:071252

评论

发布
暂无评论
发现更多内容

架构实战营模块二作业

老猎人

架构实战营

性能测试误差分析文字版-上

FunTester

性能测试 自动化测试 接口测试 测试框架 测试开发

【硬刚Kylin】Kylin入门/原理/调优/OLAP解决方案和行业典型应用

王知无

【LeetCode】变位词组Java题解

HQ数字卡

算法 LeetCode 7月日更

2.2如何设计高性能架构

Lemon

高性能架构

在线XML转JSON工具

入门小站

Spring源码解析 -- SpringWeb请求参数获取解析

Java spring 源码解析

性能测试误差统计实践

FunTester

软件测试 测试 性能测试 测试开发

面试算法之螺旋数组查找问题

泽睿

面试 二分查找

架构实战营 - 模块二作业: 分析微信朋友圈的高性能复杂度

Julian Chu

#架构实战营

【Flutter 专题】83 解自定义 ACEWave 波浪 Widget (一)

阿策小和尚

Flutter 小菜 0 基础学习 Flutter Android 小菜鸟 7月日更

基于P4的SCION -- 构建太比特的未来互联网

俞凡

网络

Selenium4 Alpha-7升级体验

FunTester

maven 自动化测试 Gradle 测试开发 selenium

【Java特性专题】JDK(8-11)特性分布变化简介

浩宇天尚

Java Java新特性 7月日更 Java11

生产环境踩坑系列::Hive on Spark的connection timeout 问题

dclar

spark hive hive on spark

架构实战营模块二作业

maybe

external-resizer源码分析-pvc扩容分析

良凯尔

Kubernetes 源码分析 Ceph CSI Kubernetes Plugin

[架构实战营][模块二作业]

KK_TTN

架构实战营

[架构实战营一期] 模块二作业

trymorewang

架构实战营

智能运维系列之五:总结

micklongen

微信朋友圈架构设计

summer

极客时间 极客时间架构师一期

架构实战营模块二作业

袁小芬

架构实战营

Go语言:sync包控制并发详解!

微客鸟窝

Go

2.3如何设计高可用架构

Lemon

存储高可用

Vue进阶(幺叁贰):ES数组操作:数组合并

No Silver Bullet

Vue 7月日更 数组合并

Vue进阶(幺幺捌):CSS3 - 选择器first-child、last-child、nth-child、nth-last-child、nth-of-type

No Silver Bullet

Vue 7月日更

你真的了解 Session 和 Cookie 吗?

陈皮的JavaLib

Java HTTP session Cookie

【架构设计模块二】:微信朋友圈的高性能复杂度

Ryoma

架构实战营

性能测试误差分析文字版-下

FunTester

软件测试 性能测试 接口测试 测试框架 测试开发

架构实战营 - 模块二作业

思梦乐

一种简单可落地的分布式事务实践方案,面试问起来也不慌了

JAVA前线

Java 数据库 分布式事务

英特尔On技术创新峰会

英特尔On技术创新峰会

工程师要理解机器学习算法的复杂性-InfoQ