阿里云「飞天发布时刻」2024来啦!新产品、新特性、新能力、新方案,等你来探~ 了解详情
写点什么

使用开源概率编程语言 Pyro 对截尾时间 - 事件数据进行建模

  • 2019-06-08
  • 本文字数:5588 字

    阅读完需:约 18 分钟

使用开源概率编程语言Pyro对截尾时间-事件数据进行建模

在 Uber,我们有兴趣调查乘客在平台上完成首次乘坐到第 2 次乘坐之间的时间跨度。我们的很多乘客是通过推荐或促销活动首次与 Uber 进行互动的。他们的第 2 次乘坐是个关键指标,表明乘客在使用平台的过程中发现价值并愿意长期使用我们服务。然而,对第 2 次乘坐时间建模是件棘手的事。例如,一些乘客不经常乘车。在分析这类乘客的第 2 次乘坐之前的时间-事件数据时,我们认为他们的数据就是截尾数据。


在其他公司和行业中都存在类似的情况。例如,假设某个电商网站对客户经常性购买模式感兴趣。但是,由于客户行为模式的多样性,该公司也许无法观察到所有客户的所有经常性购买行为,从而导致截尾数据的产生。


在另一个例子中,假设某个广告公司对其用户的重复点击行为感兴趣。由于每个用户的兴趣不同,该公司无法观察到其用户的所有点击行为。用户也许在研究结束后才点击广告。这样就会产生到下一次点击数据的截尾时间。


在截尾的时间-事件数据建模中,对用 索引的每个感兴趣的个体,我们都可以以下面的形式观察数据:


其中,是截尾标识。如果观察到感兴趣的事件,那么;如果感兴趣的事件截尾,那么。当时, 表示感兴趣的时间-事件。当,那么代表截尾发生之前的时间长度。


我们继续讲 Uber 的第 2 次乘坐时间的例子:如果某个乘客在其首次乘坐 12 天后才进行第 2 次乘坐,那么该观察就记录为(12,1)。在另一种情况下,如果某个乘客在首次乘坐后过去了 60 天,并且在给定的截止日期前还没返回到应用程序进行第 2 次乘坐,那么该观察就记录为(60,0)。这种情形如下图所示:



在该领域有大量的分析文献,并且研究时间已经有一个多世纪之久;其中大部分可以用统计编程框架进行简化。在本文中,我们将介绍如何使用Pyro概率编程语言来为截尾的时间-事件数据建模。

与流失建模之间的关系

在我们继续之前,值得一提的是,很多行业从业者通过人为设置“流失”标签的方式来规避截尾的时间-事件数据的挑战。例如,如果一家电商的客户在过去 40 天中没有回到网站进行另一次购买,那么该电商可以把该客户定位为“流失”。


流失建模使得从业者把观察转换为经典的二元分类模式。因此,流失建模就会像使用 scikit-learn 和 XGBoost 这样的现成工具那么简单。例如,上述的两位乘客将分别被标注为“未流失”和“流失”。


尽管流失模型在特定情形下是可行的,但其不一定适用于 Uber。例如,某些乘客只在出差时使用 Uber。如果该假设的乘客每 6 个月出一次差,那么我们最终就会把该商务乘客误标注成“流失”。因此,我们从流失模型中提取的结论可能产生误导。


我们也有兴趣从这些模型中进行解释,以阐明不同因素对观察到的用户行为的影响。因此,模型不应该是个黑匣子。我们希望能够开放该模型并用它做出更明智的业务决策。


为了实现这一点,我们可以将 Pyro 这一灵活且富有表现力的开源工具用于概率编程。

用于统计建模的 Pyro

创建于 Uber 的 Pyro 是用 Python 编写的通用概率编程语言,构建于 PyTorch 张量计算库的基础之上。


如果你具有最小贝叶斯建模知识的统计背景,或是你一直在用 TensorFlow 或 PyTorch 这样的深度学习工具,那么你的运气很好。


下表总结了一些最受欢迎的概率编程项目:



下面,我们将重点介绍这些不同软件项目的一些关键特性:


  1. BUGS/JAGS 是概率编程早期的例子。在统计领域,它们已经被积极开发和使用了 20 多年。

  2. 但是,BUGS/JAGS 主要是从头设计和开发的。因此,模型规范是用它们特定于域的语言完成的。此外,概率程序开发人员需要从 R 和 MATLAB 中的包装器中调用 BUGS/JAGS。用户必须在编码语言和文件之间来回切换,不太方便。

  3. PyMC 依赖于 Theano 后端。但是,Theano 项目最近停止了。

  4. TensorFlow Probability(TFP)最初作为一个名为 Edward 的项目启动。该 Edward 项目已纳入 TFP 项目。

  5. TFP 使用 TensorFlow 作为其计算引擎。因此,其仅支持静态计算图。

  6. Pyro 使用 PyTorch 作为计算引擎,因此支持动态计算图。这使得用户能够在数据流方面指定不同的模型,非常灵活。


简而言之,Pyro 基于最强大的深度学习工具链(PyTorch),同时具有数十年统计研究的支持。因而它是一种非常简洁和强大、但又灵活的概率建模语言。

对截尾的时间-事件数据建模

现在,让我们深入研究如何为时间-事件数据建模。感谢谷歌 Colab,用户得以无需安装 Pyro 和 PyTorch 就可以查看大量代码示例并开始为数据建模。我们甚至可以复制工作簿并在其上进行各种尝试。

模型定义

鉴于本文的目的,我们把时间-事件数据定义为,其中表示时间-事件,表示二进制截尾标签。我们把实际的时间-事件定义为,它可以是没有观察到的。为了简单起见,我们把截尾时间定义为, 并假设它是个已知的固定数字。综上所述,我们可以把这关系建模为:



我们假设遵循带有尺度参数的指数分布,变量与感兴趣的预测因子存在以下线性关系:



其中,是个 softplus 函数,从而确保保持为正。最后,我们假设和遵循正态分布作为先验分布。鉴于本文的目的,我们感兴趣的是评估和的后验分布。

生成人工数据

首先,我们导入所有必要的 Python 包:


import pyro import torch import seaborn as sns import pyro.distributions as dist from pyro import infer, optimfrom pyro.infer.mcmc import HMC, MCMCfrom pyro.infer import EmpiricalMarginal
assert pyro.__version__.startswith('0.3')
复制代码


为了生成实验数据,我们运行以下几行脚本:


n = 500a = 2b = 4c = 8
x = dist.Normal(0, 0.34).sample((n,)) # Note [1]
link = torch.nn.functional.softplus(torch.tensor(a*x + b))# note below, param is rate, not meany = dist.Exponential(rate=1 / link).sample()
truncation_label = (y > c).float()
y_obs = y.clamp(max=c)
sns.regplot(x.numpy(), y.numpy())sns.regplot(x.numpy(), y_obs.numpy()) ## Note [2]
复制代码


恭喜你!你刚刚在 Note[1]所在的行运行了你的第一个 Pyro 函数。在这里,我们从正态分布中采了样。细心的用户也许已经注意到,这种直观的操作和我们在 Numpy 中的工作流程非常相似。


在上述代码段的末尾(Note 2),我们分别生成了一个(绿色)和(蓝色)对的回归图。如果我们不考虑数据截尾,那么就低估了模型的斜率。



图 1. 这个散点图描述了实际的底层事件时间和相对于预测器的观察到的事件时间。

构建模型

借助这些新鲜但截尾的数据,我们可以开始构建更精确的模型。让我们从下面的模型函数开始:


def model(x, y, truncation_label): ## Note [1]   a_model = pyro.sample("a_model", dist.Normal(0, 10)) ## Note [2]   b_model = pyro.sample("b_model", dist.Normal(0, 10))    link = torch.nn.functional.softplus(a_model * x + b_model) ## Note [3]     for i in range(len(x)):    y_hidden_dist = dist.Exponential(1 / link[i]) ## Note [4]         if truncation_label[i] == 0:       ## Note [5]       y_real = pyro.sample("obs_{}".format(i),                            y_hidden_dist,                           obs = y[i])    else:      ## Note [6]      truncation_prob = 1 - y_hidden_dist.cdf(y[i])      pyro.sample("truncation_label_{}".format(i),                   dist.Bernoulli(truncation_prob),                   obs = truncation_label[i])
复制代码


在上面的代码段中,我们重点解释以下注释,以更好地阐明我们的示例:


  • Note 1:总的来说,模型函数描述的是数据生成的过程。这个示例模型函数告诉我们如何从输入的矢量 x 生成 y 或 truncation_label。

  • Note 2:我们指定这里和的先验分布,并利用 pyro.sample 函数对它们采样。Pyro 在 PyTorch 项目和 Pyro 项目中都有大量的随机分布。

  • Note 3: 在这里,我们把输入,和接入用变量 link 表示的矢量。

  • Note 4:我们利用带有尺度参数矢量链接的指数分布来指定真实时间-事件的分布。

  • Note 5:对于观察 i,如果我们观察到时间-事件数据,那么我们把它和实际观察 y[i]进行对比。

  • Note 6:如果对于观察,数据是截尾的,那么截断标签(这里等于 1)遵循伯努利分布。在点,观察到截断数据的概率是的 CDF。我们从伯努利分布中采样,并将其与 truncation_label[i]的实际观察结果进行对比。


有关贝叶斯建模和使用 Pyro 的更多信息,请参考我们的入门教程

用哈密顿•蒙特•卡罗方法(Hamiltonian Monte Carlo,简称 HMC)计算推理

在计算贝叶斯推理时,哈密顿•蒙特•卡罗方法是一种常用的方法。我们用 HMC 来估计 a 和 b,如下所示:


pyro.clear_param_store()
# note [1] hmc_kernel = HMC(model, step_size = 0.1, num_steps = 4)

# Note [2] mcmc_run = MCMC(hmc_kernel, num_samples=5, warmup_steps=1).run(x, y, truncation_label)

# Note [3] marginal_a = EmpiricalMarginal(mcmc_run, sites="a_model")

# Note [4] posterior_a = [marginal_a.sample() for i in range(50)]
sns.distplot(posterior_a)
复制代码


上述过程可能需要很长时间来运行。这么慢的主要原因是,我们需要通过依次观察来评估模型。为了加速该模型,我们可以用pyro.platepyro.mask进行矢量化,如下所示:


def model(x, y, truncation_label):  a_model = pyro.sample("a_model", dist.Normal(0, 10))   b_model = pyro.sample("b_model", dist.Normal(0, 10))    link = torch.nn.functional.softplus(a_model * x + b_model)     with pyro.plate("data"):    y_hidden_dist = dist.Exponential(1 / link)         with pyro.poutine.mask(mask = (truncation_label == 0)):       pyro.sample("obs", y_hidden_dist,                  obs = y)          with pyro.poutine.mask(mask = (truncation_label == 1)):      truncation_prob = 1 - y_hidden_dist.cdf(y)      pyro.sample("truncation_label",                   dist.Bernoulli(truncation_prob),                   obs = torch.tensor(1.))

复制代码


在上面的代码段中,我们首先使用指定的模型来指定 HMC 内核。然后,我们对 x,y 和 truncation_label 执行 MCMC。接着,将 MCMC 采样的结果对象转换为 EmpiricalMarginal 对象,以帮助我们根据 a_model 参数进行推理。最终,我们从后验分布采样,并利用我们的数据绘制出一张图,如下所示:



图 2:a 的采样值直方图。


我们可以看到,这些样本集中在实际值 2.0 附近。

利用变分推理加速估计

随机变分推理(Stochastic variational inference,简称SVI)是利用大量数据加速贝叶斯推理的好方法。现在,我们只需要知道导函数是期望后验分布的近似即可。导函数的指定可以大大加快参数的估计。为了实现随机变分推理,我们定义导函数为:


guide = AutoMultivariateNormal(model)

复制代码


通过使用导函数,我们可以把参数 a 和 b 的后验分布近似为正态分布,其中它们的位置和尺度参数分别由内部参数指定。

训练模型并推断结果

用 Pyro 训练模型的过程和深度学习中的标准迭代优化类似。下面,我们指定 SVI 训练器并通过优化步骤进行迭代:


pyro.clear_param_store()  adam_params = {"lr": 0.01, "betas": (0.90, 0.999)}optimizer = optim.Adam(adam_params)
svi = infer.SVI(model, guide, optimizer, loss=infer.Trace_ELBO())
losses = []for i in range(5000): loss = svi.step(x, y_obs, truncation_label) losses.append(loss)
if i % 1000 == 0: print(', '.join(['{} = {}'.format(*kv) for kv in guide.median().items()]))
print('final result:')for kv in sorted(guide.median().items()): print('median {} = {}'.format(*kv))
复制代码


如果一切如计划所愿,那么我们可以看到上述代码的执行结果。在本例中,我们得到的结果如下,其均值与实际的值及指定的值非常接近:


a_model = 0.009999999776482582, b_model = 0.009999999776482582a_model = 0.8184720873832703, b_model = 2.8127853870391846a_model = 1.3366154432296753, b_model = 3.5597035884857178a_model = 1.7028049230575562, b_model = 3.860581874847412a_model = 1.9031578302383423, b_model = 3.9552347660064697final result:median a_model = 1.9155923128128052median b_model = 3.9299516677856445
复制代码


我们还可以检查模型是否通过下面的代码聚合,并得到图 3,如下所示:


sns.plt.plot(losses)
复制代码



图 3:针对迭代次数绘制的模型损失


我们可以使用guide.quantiles()函数来绘制近似后验分布:


N = 1000for name, quantiles in guide.quantiles(torch.arange(0., N) / N).items():  quantiles = np.array(quantiles)  pdf = 1 / (quantiles[1:] - quantiles[:-1]) / N  x = (quantiles[1:] + quantiles[:-1]) / 2  sns.plt.plot(x, pdf, label=name)  sns.plt.legend()sns.plt.ylabel('density')
复制代码


我们可以看到,导函数分别集中于和的实际值附近,如下所示:


其他

我们希望读者在自己的截尾时间-事件数据建模上试试 Pyro。关于如何开始使用该开源软件,请参考Pyro的官方网站,以获得其它示例,包括入门教程沙箱库


阅读英文原文:Modeling Censored Time-to-Event Data Using Pyro, an Open Source Probabilistic Programming Language,


https://eng.uber.com/modeling-censored-time-to-event-data-using-pyro/


2019-06-08 08:006796
用户头像

发布了 199 篇内容, 共 81.7 次阅读, 收获喜欢 293 次。

关注

评论

发布
暂无评论
发现更多内容

彻底告别传统FTP,新的替代FTP产品比你想象的好的多

镭速

传输协议 FTP传输替代方案

活动预告 | 中国数据库联盟(ACDU)中国行第三站定档成都,邀您探讨数据库前沿技术

墨天轮

MySQL 数据库 oracle postgresql zabbix

Generative AI 新世界 | 扩散模型原理的代码实践之采样篇

亚马逊云科技 (Amazon Web Services)

机器学习 #人工智能 生成式人工智能 Amazon SageMaker 大语言模型

轻量应用服务器,助力个人开发者最低成本创业

YG科技

上新啦!腾讯云云原生数据湖产品DLC 2.2.5版本发布,来看特性详解

腾讯云大数据

数据湖

从技术创新到应用实践,百度智能云发起大模型平台应用开发挑战赛!

不叫猫先生

百度智能云 千帆大模型平台

企业内部通讯,WorkPlus助您打造高效沟通平台

WorkPlus

Github上线即遭狂转!上百人通过这份算法手抄本成功上岸字节

程序员万金游

#java java 架构 #算法 #数据结构 #java编程

华为云耀云服务器L实例:带你探索轻量应用服务器的魅力

YG科技

用友系列之 YonBuilder 低代码平台概论和基本使用

YonBuilder低代码开发平台

低代码 可视化

WorkPlus私有化部署IM即时通讯平台,构建高效安全的局域网办公环境

WorkPlus

rabbitMQ到底是个啥东西?

程序员万金游

Java 开发 #java Rabbit MQ

Redis内存碎片:深度解析与优化策略

Java随想录

Java redis

福利贴|这是一个程序员不看一定会后悔的问题

Zilliz

非结构化数据 Milvus Zilliz 向量数据库

华为智慧屏,吹尽狂沙始到金

脑极体

AI智慧屏

腾讯Java后端社招三面,差点就挂了!

程序员小毕

Java 程序员 面试 程序人生 架构师

HarmonyOS使用多线程并发能力开发

HarmonyOS开发者

HarmonyOS

使用Optional优雅避免空指针异常

Java随想录

Java 异常

IoTDB 在国际数据库性能测试排行榜中位居第一?测试环境复现与流程详解第一弹!

Apache IoTDB

TinyEngine 低代码引擎到底是什么?

英勇无比的消炎药

开源 前端 低代码

华为阅读“鲁迅专栏”已上线,读国内名家作品就上华为阅读

最新动态

火山引擎边缘云:数智化项目管理助力下的业务增长引擎

火山引擎边缘云

数字化 飞书 数智化 #项目管理

WorkPlus Meet 视频会议,自主可控,支持私有化部署

WorkPlus

Redis类型(Type)与编码(Encoding)

Java随想录

redis

为什么要使用zookeeper

Jerry Tse

zookeeper 分布式锁 分布式系统 共识算法 数据强一致性

close()关闭文件方法

智趣匠

大道总是孤独的——查理芒格如是说

少油少糖八分饱

投资 长期主义 能力圈 查理芒格 股东大会

ICCV 2023|小红书 4 篇入选论文亮点解读,「开集视频目标分割」获得 Oral

小红书技术REDtech

算法 ICCV

跨网传输文件时,如何通过日志记录来审计追溯?

镭速

跨网文件传输

Java训练营毕业总结

jjn0703

Python 中的数字类型与转换技巧

小万哥

Python 程序员 软件 后端 开发

使用开源概率编程语言Pyro对截尾时间-事件数据进行建模_文化 & 方法_Hesen Peng_InfoQ精选文章