谷歌 Metrax 为 JAX 引入了预定义的模型评估指标

  • 2025-12-24
    北京
  • 本文字数:1011 字

    阅读完需:约 3 分钟

Metrax是一个JAX库,最近由谷歌开源,为分类、回归、自然语言处理(NLP)、视觉和音频模型提供了标准化的性能指标实现。

 

谷歌解释说,Metrax解决了 JAX 生态系统中的一个空白,这个空白迫使许多团队从 TensorFlow 迁移到 JAX,以实现他们自己的通用评估指标版本,如准确性、F1 分数、RMS 误差等:

 

虽然在某些人看来,创建指标似乎是一个相当简单和直接的话题,但当考虑到跨数据中心规模的分布式计算环境中的大规模训练和评估时,它就变得不那么简单了。

 

Metrax 为一系列机器学习模型提供了预定义的评估度量指标,包括分类、回归、推荐、视觉和音频,特别支持分布式和大规模的训练环境。对于视觉模型,该库包括诸如交并比(IoU)、信噪比(SNR)和结构相似性指数(SSIM)等指标,Metrax 还包括鲁棒的 NLP 相关度量指标,包括困惑度(Perplexity)、BLEU 和 ROUGE。

 

谷歌指出,Metrax 的目标之一是确保所有度量指标都得到很好的实施并遵循最佳实践。在度量指标定义支持的地方,Metrax 使用 JAX 的高级功能,如 vmap 和 jit 来提高性能。例如,这些特性用于实现新的“at K”指标,以支持并行计算多个 K 值。这使我们能够更全面、更快地评估模型。

 

你可以使用 PrecisionAtK 来确定多个 K 值(比如 K=1、K=8 和 K=20)下模型的精度,所有这些都是在模型的一次前向传递中进行的,而不需要对每个参数多次调用 PrecisionAtK。

 

名为Neural Foundry的DevOps工程师在Substack上写道

 

Metrax 支持在单次传递中计算多个 K 值,这对排名系统来说是一个巨大的胜利。我每次切换项目时都需要重写度量工具,这种标准化早就应该实现了。API 看起来也很干净。好奇他们是否针对特定用例(如大规模推荐管道)的自定义实现进行了基准测试。

 

下面的代码片段展示了如何根据预测结果和标签计算精度度量指标。可以指定一个可选的阈值,将概率预测转换为二元预测:

 

import metrax  # 直接计算度量状态。metric_state = metrax.Precision.from_model_output(    predictions=predictions,    labels=labels,    threshold=0.5)# 然后通过调用compute()即可获得结果。result = metric_state.compute()result
复制代码

 

谷歌还发布了一个笔记本,包含了一系列综合示例,包括多设备扩展和与Flax NNX的集成,Flax NNX 是一个简化的 API,使得在 JAX 中创建、检查、调试和分析神经网络变得更加容易。

 

JAX是一个开源的 Python 库,用于高性能数值计算和机器学习。

 

原文链接:

https://www.infoq.com/news/2025/12/metrax-jax-evaluation-metrics/