Introducing Metrax: performant, efficient, and robust model evaluation metrics in JAX
NOV. 13, 2025 · Yufeng Guo, Developer Advocate, Core ML Frameworks · Jiwon Shin, Software Engineer, Core ML Frameworks · Jeff Carpenter, Software Engineer, Core ML Frameworks
At Google, as teams migrated from TensorFlow to JAX, they found themselves manually reimplementing metrics that TensorFlow had previously provided, because JAX did not have a built-in metrics library. Each team using JAX was implementing its own version of accuracy, F1, RMS error, and so on. While creating metrics may seem, to some, like a fairly simple and straightforward task, it becomes considerably less trivial once you consider large-scale training and evaluation across datacenter-sized distributed compute environments.
And thus the idea for Metrax was born: to provide a high-performance library of efficient and robust model evaluation metrics for JAX. Metrax currently provides predefined metrics for evaluating various types of machine learning models (classification, regression, recommendation, vision, audio, and language), and offers compatibility and consistency in distributed, large-scale training environments. This lets you focus on the model evaluation results, rather than (re)implementing various metric definitions. Metrax adds to the ever-evolving ecosystem of JAX-based tooling, integrating well with the JAX AI Stack, a suite of tools designed to work together to power your AI tooling needs. Today, Metrax is already used by some of the largest software stacks at Google, including teams in Google Search, YouTube, and Google's own post-training library, Tunix.
Strengths of Metrax
Particularly noteworthy is the ability to compute "at K" metrics for multiple values of K in parallel, which lets you evaluate model performance more comprehensively and more quickly. For example, you can use PrecisionAtK to determine your model's precision at several values of K (say, K=1, K=8, and K=20) in a single pass, rather than calling PrecisionAtK once per value. Several other "at K" metrics are available for you to try, including RecallAtK and NDCGAtK. All the metrics, along with their definitions, can be found in the documentation.
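As a rough sketch, here is what multi-K evaluation looks like on toy data. The exact argument name for the list of K values (shown here as ks) and the expected input shapes are assumptions based on the pattern above, so check the documentation for the precise signature.

import jax.numpy as jnp
import metrax

# Toy ranking data: one query with four candidate items.
scores = jnp.array([[0.9, 0.2, 0.8, 0.1]])   # hypothetical model ranking scores
relevance = jnp.array([[1, 0, 1, 0]])        # hypothetical ground-truth relevance

# Evaluate precision at several cutoffs in a single pass.
metric_state = metrax.PrecisionAtK.from_model_output(
    predictions=scores,
    labels=relevance,
    ks=jnp.array([1, 2, 4]),
)
metric_state.compute()  # one precision value per K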
The last thing you want to worry about in a machine learning research project is whether your metrics are implemented correctly across your system, so a well-tested metrics library helps the community write less error-prone code and model evaluations.
Performance
Metrax leverages core strengths of JAX, including vmap and jit, to perform operations like multi-K evaluation in a highly performant manner. While not every metric is "jit-able" due to the nature of the metric, the goal is to ensure all metrics are well written and demonstrate best practices. Beyond classic metrics such as accuracy, precision, and recall, the library also features a robust set of NLP-related metrics, including Perplexity, BLEU, and ROUGE, as well as metrics for vision models, such as Intersection over Union (IoU), Signal-to-Noise Ratio (SNR), and Structural Similarity Index (SSIM). There's no need to vibe code your metrics implementations anymore; just use Metrax!
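To make that concrete, a metric state behaves like a JAX pytree of arrays, so metric computation can typically be folded straight into a compiled evaluation step. Here is a minimal sketch, assuming Precision is among the jit-compatible metrics; the eval_step function and the toy arrays are our own illustration, not part of the library.

import jax
import jax.numpy as jnp
import metrax

@jax.jit
def eval_step(predictions, labels):
    # Build the metric state and reduce it to a value inside one compiled step.
    state = metrax.Precision.from_model_output(
        predictions=predictions,
        labels=labels,
        threshold=0.5,
    )
    return state.compute()

predictions = jnp.array([0.8, 0.3, 0.6, 0.1])
labels = jnp.array([1.0, 0.0, 1.0, 0.0])
print(eval_step(predictions, labels))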
Metrax in action
Let's see how to use Metrax in your code. This is what it looks like to compute a precision metric from your model's output. Notice that we pass in the predictions and labels along with a threshold value, and then call compute() to get the metric's value.
import metrax

# Directly compute the metric state.
metric_state = metrax.Precision.from_model_output(
    predictions=predictions,
    labels=labels,
    threshold=0.5,
)

# The result is then readily available by calling compute().
result = metric_state.compute()
result

Oftentimes, we do evaluations in batches, so we want to be able to iteratively add more information to our collection of metrics. Metrax supports this workflow with a function called merge(). This is a great function to use inside your evaluation loop as you're aggregating your metrics over the course of your training run. Notice we still call compute() when we're ready to get a final value.
# Iteratively merging precision metrics.
# metric_state carries over from the previous snippet; the sample weights are
# unpacked but unused in this minimal example.
for labels_b, predictions_b, weights_b in zip(
    labels_batched, predictions_batched, sample_weights_batched
):
    batch_metric_state = metrax.Precision.from_model_output(
        predictions=predictions_b,
        labels=labels_b,
    )
    metric_state = metric_state.merge(batch_metric_state)

result = metric_state.compute()
result

For a full set of examples, check out this notebook, which demonstrates more ways you can use Metrax, including scaling to multiple devices and integrations with Flax NNX, a modeling library that abstracts away some of the implementation details of building AI models.
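The same merge() pattern extends naturally to distributed evaluation: compute one metric state per shard, then fold the states together before the final compute(). Here is a minimal sketch; per_shard_outputs is a hypothetical stand-in for outputs gathered from multiple devices or hosts, and the notebook linked above covers the real multi-device setup.

import functools

import jax.numpy as jnp
import metrax

# Hypothetical per-shard (predictions, labels) pairs gathered from devices.
per_shard_outputs = [
    (jnp.array([0.9, 0.2]), jnp.array([1.0, 0.0])),
    (jnp.array([0.7, 0.4]), jnp.array([1.0, 1.0])),
]

# One metric state per shard, folded together with merge().
states = [
    metrax.Precision.from_model_output(predictions=p, labels=l, threshold=0.5)
    for p, l in per_shard_outputs
]
merged = functools.reduce(lambda a, b: a.merge(b), states)
print(merged.compute())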
Contribute
Metrax is developed on GitHub and happily accepts community contributions. Some of the metrics available today were in fact added by community contributors; a big shout-out to GitHub users @nikolasavic3 and @Mrigankkh for their efforts! So if there are more metrics you'd like to see, submit a pull request and work with the development team to get them included in Metrax. You can learn more at github.com/google/metrax.
Also, be sure to check out the other libraries in the JAX ecosystem at jaxstack.ai, where you can find more libraries that integrate well with Metrax, along with additional content about building machine learning models.