JAX vs PyTorch: A simple transformer benchmark

Posted on September 6, 2021

I’ve been looking into deep learning libraries recently and JAX seemed interesting, but as far as I could tell no one had actually benchmarked it against PyTorch, the de facto standard. So I decided to implement the same model in both and compare. Here’s the top-level summary: PyTorch gets 1.11 iterations per second and JAX gets 1.24 it/s (12% better) on a Google Colab notebook with a P100. JAX is also more memory efficient: the PyTorch model OOMs with more than 62 examples per batch, while JAX can fit up to 79 (at 1.01 it/s, or 79.79 examples per second vs PyTorch’s 68.82 with its smaller batch size).

Meanwhile, TPUs are kind of absurd. Torch on XLA theoretically exists, but I don’t know of anyone who’s actually gotten it to work; when I tested it, my code segfaulted. TPUs work very smoothly with JAX, though. I was accepted into the TPU Research Cloud (formerly TFRC), and a TPUv3-8 can run through 2,591 examples per second with a batch size of 3,032.
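
The numbers above are tied together by simple arithmetic: examples per second is iterations per second times batch size. The snippet below just recomputes the quoted figures. It assumes PyTorch’s 1.11 it/s was measured at its maximum batch size of 62, which is consistent with the 68.82 examples/s figure; the helper name is illustrative.

```python
# Recompute the throughput figures quoted above.
# examples/second = iterations/second * batch size (examples per iteration)

def examples_per_second(iters_per_sec: float, batch_size: int) -> float:
    """Convert a training-step rate into an example rate."""
    return iters_per_sec * batch_size

# P100 step rates quoted in the post.
torch_its, jax_its = 1.11, 1.24
speedup = jax_its / torch_its - 1  # relative improvement, about 0.117 (~12%)

# Largest batch each framework fit in memory
# (assumption: PyTorch's 1.11 it/s was measured at batch size 62).
torch_best = examples_per_second(torch_its, 62)  # 68.82 examples/s
jax_best = examples_per_second(1.01, 79)         # 79.79 examples/s

# TPUv3-8: implied step rate at batch size 3,032.
tpu_its = 2591 / 3032  # about 0.85 iterations per second

print(f"JAX/PyTorch step-rate gain: {speedup:.1%}")
print(f"Max throughput: PyTorch {torch_best:.2f} ex/s, JAX {jax_best:.2f} ex/s")
print(f"TPUv3-8 implied step rate: {tpu_its:.2f} it/s")
```

So the TPU is actually taking fewer steps per second than the P100; its advantage comes entirely from fitting a batch roughly 38 times larger.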

Benchmark details

You can reproduce my GPU results using this notebook and find the model code here. The TPU code is in the pmap branch. Unfortunately, Colab TPUs are flaky, so there’s no notebook for that. The model is a simple, byte-level, autoregressive transformer language model trained on enwik9. I used the Flax neural net framework for the JAX implementation. The hyperparameters are as follows:

| Parameter             | Value |
|-----------------------|-------|
| layers                | 12    |
| d_model               | 512   |
| heads                 | 8     |
| feedforward dimension | 3072  |
| sequence length       | 256   |

It’s roughly GPT-1 with an embedding dimension of 512 instead of 768, which is quite small compared to state-of-the-art models.
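
To put "quite small" in numbers, here is a rough parameter count for these hyperparameters. It is a sketch under assumptions the post doesn’t confirm: a standard GPT-style block with biases on every dense layer, two LayerNorms per block, learned positional embeddings, a 256-entry byte vocabulary, a final LayerNorm and an untied output layer. The actual Flax model may differ in these details, so treat the result as an estimate.

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class Config:
    layers: int = 12
    d_model: int = 512
    heads: int = 8       # does not change the count: heads split d_model (64 dims each)
    d_ff: int = 3072
    seq_len: int = 256
    vocab: int = 256     # assumption: one token per possible byte value

def count_params(c: Config) -> int:
    """Rough parameter count for a GPT-style decoder (assumptions in the lead-in)."""
    d = c.d_model
    attn = 4 * (d * d + d)                 # Q, K, V and output projections, with biases
    ffn = 2 * d * c.d_ff + c.d_ff + d      # d_model -> d_ff -> d_model, with biases
    norms = 2 * (2 * d)                    # two LayerNorms per block (scale + bias)
    per_layer = attn + ffn + norms
    embeddings = c.vocab * d + c.seq_len * d   # token + learned positional embeddings
    head = 2 * d + d * c.vocab + c.vocab       # final LayerNorm + untied output layer
    return c.layers * per_layer + embeddings + head

print(f"{count_params(Config()):,} parameters")  # 50,818,304, i.e. roughly 51M
```

Almost all of that sits in the 12 blocks; with a byte-level vocabulary the embedding and output layers are tiny.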


This is, obviously, a single measurement. The comparison, including the direction of the advantage, may vary with model type, size, hardware, and other factors; I’m not making any universal statements here. Furthermore, a 12% speedup isn’t much. Competent ML engineers are expensive (and your time is valuable), so you could easily lose more in engineering time than you gain in training time. And it’s always possible I’ve made a mistake and the two models aren’t actually identical.

Observations about programming in the two systems

I haven’t done a ton of ML programming in either Torch or JAX/Flax, but I can compare what I do know of them.

  1. Torch is much more batteries-included. There’s a TransformerEncoder and a TransformerEncoderLayer in torch.nn. In Flax there’s an attention module, but I had to write the transformer block itself (attention + feedforward + layer norm + residuals).
  2. vmap is very cool. Briefly, vmap allows you to turn any function into a vectorized version, and JAX will generate efficient code for you. So I could get rid of batch dimensions everywhere except the outermost layer, and give myself one less thing to get wrong.
  3. pmap is very cool too. Analogous to vmap, it lets you parallelize code across multiple accelerators and across multiple host machines, provided the accelerators have a special cluster setup for fast interconnect. In practice I think that mostly means TPU pods, though they do mention a way to do it with Nvidia GPUs.
  4. TPUs are really really powerful. Good TPU support, especially since I have access to TRC, makes the choice easy.
  5. All the indirection that Flax introduces to let you write models in an object-oriented style makes the stack traces really bad. They’re like 80% stuff internal to Flax or to JAX’s JIT.
  6. JAX’s approach to differentiation is more powerful and less footgunny than Torch’s. It’s not possible to accidentally compute useless gradients or accidentally modify things that shouldn’t be learned parameters.
  7. Performance debugging is easier with Torch. If you use the profiler on a JAX program, everything that’s been JIT compiled shows up as “custom-call” or “fused”, and the JIT-compiled code is all of the code whose performance you care about. Apparently it works if you use the special secret profiler Google has internally.
  8. Because JAX is a much less widely used library, it’s much harder to Google error messages and the like.
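
Items 1 and 2 meet in practice: you write a piece of the transformer for a single unbatched sequence and let vmap add the batch dimension. The following is a minimal sketch, not the model code from this post: single-head causal self-attention in plain JAX (no Flax), with illustrative function names and shapes.

```python
import jax
import jax.numpy as jnp

def causal_attention(x, wq, wk, wv):
    """Single-head causal self-attention for ONE sequence x of shape (seq_len, d).

    No batch dimension anywhere: vmap adds it below.
    """
    q, k, v = x @ wq, x @ wk, x @ wv
    scores = q @ k.T / jnp.sqrt(q.shape[-1])
    seq_len = x.shape[0]
    causal = jnp.tril(jnp.ones((seq_len, seq_len), dtype=bool))  # position i sees j <= i
    scores = jnp.where(causal, scores, -1e9)
    return jax.nn.softmax(scores, axis=-1) @ v

# Vectorize over axis 0 of x only; the weight matrices are shared (in_axes=None).
batched_attention = jax.vmap(causal_attention, in_axes=(0, None, None, None))

kx, kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 4)
batch, seq_len, d = 4, 16, 32
x = jax.random.normal(kx, (batch, seq_len, d))
wq, wk, wv = (jax.random.normal(key, (d, d)) / jnp.sqrt(d) for key in (kq, kk, kv))

out = batched_attention(x, wq, wk, wv)
print(out.shape)  # (4, 16, 32)
```

The per-sequence function never mentions a batch axis, so there is no transpose or broadcasting mistake to make there; vmap decides how the batch is laid out.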

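For item 3, the usual pmap pattern is data parallelism: each device computes gradients on its own shard of the batch, and jax.lax.pmean averages them across devices. This is a minimal sketch with a toy linear-regression loss standing in for the transformer; the names are illustrative and this is not the code from the pmap branch. It also runs on a single CPU, where the device axis has length 1.

```python
from functools import partial

import jax
import jax.numpy as jnp

def loss_fn(w, x, y):
    """Toy linear-regression loss standing in for the language-model loss."""
    return jnp.mean((x @ w - y) ** 2)

@partial(jax.pmap, axis_name="devices")
def train_step(w, x, y):
    """Runs once per device on that device's shard of the batch."""
    loss, grads = jax.value_and_grad(loss_fn)(w, x, y)
    # All-reduce: average gradients and loss over every device in the pmap.
    grads = jax.lax.pmean(grads, axis_name="devices")
    loss = jax.lax.pmean(loss, axis_name="devices")
    return w - 0.1 * grads, loss

n_dev = jax.local_device_count()  # e.g. 8 on a TPUv3-8, 1 on a plain CPU
w = jnp.zeros(3)
w_rep = jnp.broadcast_to(w, (n_dev, *w.shape))  # replicate params: one copy per device
x = jnp.ones((n_dev, 16, 3))                     # shard the batch: leading axis = device
y = jnp.ones((n_dev, 16))

w_rep, loss = train_step(w_rep, x, y)
print(loss)  # one (identical) averaged loss per device
```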

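Item 6 comes from JAX’s functional design: jax.grad differentiates a function with respect to an explicit argument, and parameter updates produce new values instead of mutating state. A minimal sketch with a toy loss and illustrative names:

```python
import jax
import jax.numpy as jnp

# Parameters are an ordinary pytree (here a dict); data is a separate argument.
params = {"w": jnp.array([1.0, 2.0]), "b": jnp.array(0.5)}
x = jnp.array([[1.0, 0.0], [0.0, 1.0]])
y = jnp.array([1.0, 1.0])

def loss_fn(params, x, y):
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)

# jax.grad differentiates with respect to argument 0 (params) by default,
# so gradients are computed only for what you pass there, never for x or y.
grads = jax.grad(loss_fn)(params, x, y)
# grads has exactly the same structure as params: {"b": ..., "w": ...}

# Updates are explicit and purely functional: params itself is left unchanged.
new_params = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)
```
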
I like JAX, and I intend to use it for my next big project (a CLIP conditioned image generation model). But if TPUs and especially the TRC didn’t exist, I’m not sure it’d be worth it.


I let the model train for around four days on a TPUv3-8, and I was surprised by how well it works. Note that Wikipedia markup uses triple single quotes for bold and double single quotes for italics, and every article lede includes the subject in bold.
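
The samples below come from giving the model a bolded article title and letting it continue. The post doesn’t describe the sampling settings (temperature and so on), so this is only a generic sketch of byte-level autoregressive sampling. The `logits_fn` argument is a hypothetical stand-in for the trained model, and the toy model exists only to make the sketch runnable.

```python
import jax
import jax.numpy as jnp

def sample_bytes(logits_fn, prompt, n_new, key, temperature=1.0, context=256):
    """Extend a UTF-8 prompt one byte at a time (generic sketch, no caching or jit).

    logits_fn is a hypothetical stand-in for the trained model: it maps a 1-D
    int32 array of byte ids to logits of shape (len, 256).
    """
    tokens = list(prompt.encode("utf-8"))
    for _ in range(n_new):
        window = jnp.array(tokens[-context:], dtype=jnp.int32)  # stay within seq length
        logits = logits_fn(window)[-1]                          # prediction for next byte
        key, subkey = jax.random.split(key)
        tokens.append(int(jax.random.categorical(subkey, logits / temperature)))
    return bytes(tokens)

def toy_logits(tokens):
    """Toy model that always (overwhelmingly) predicts the byte for '!'."""
    row = jnp.full((256,), -1e9).at[ord("!")].set(0.0)
    return jnp.tile(row, (tokens.shape[0], 1))

out = sample_bytes(toy_logits, "'''Star Trek'''", 3, jax.random.PRNGKey(0))
# A byte-level model can stop mid-character, so decode leniently.
print(out.decode("utf-8", errors="replace"))
```

A real implementation would jit the model call and cache attention keys and values instead of re-running the whole window for every byte.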

Prompt: '''Star Trek'''

Sample 1: '''Star Trek''' is a fictional [[supervillain]] of a [[fictional character]], a male [[antagonist]], and a supervillain of a supervillain [[animation|animated]] [[science fiction]] [[television series]]. One time writer [[Andrew Stewart]] used Star Trek to

Sample 2: '''Star Trek''''') is a [[comic book]] series continuing as a new [[1990s]] and [[1992s]] [[comic book]] character from [[Tony Straight]]. It is one of the oldest programs in the series, played by the [[Halloween]] television series ''[[Doctor Who]]''. The

Sample 3: '''Star Trek''''' is a series of series produced by [[Wizards of the Coast]] featuring several stories and endings. These combine to form the novel ''[[What's New, Purgatory?]]'' and its musical numbers. ''[[The Whales of Magellan]]'' is a [[science fictio

Sample 4: '''Star Trek''' or '''Kazna''' which literally means "childhood canal". Star characters were either [[warp drive]]s or [[computer-generated imagery|scale control video]]s are the primary weapons in the series. The series was premiered in [[2002]]

Prompt: '''San Francisco'''

Sample 1: '''San Francisco''' is the name of many attractions situated on [[San Francisco International Airport]]. It is one of the few free airports located near [[Panama City, Florida]].
== History ==
Stanford was founded in 1918 as the home of The San Francisc

Sample 2: '''San Francisco'''. After the [[Mexican-American War]] the seaport developed into the seat of the city of [[Rio Grande, California|Rio Grande]]. Passenger service was directed to [[New York City]] by surveyor San Francisco Parks Corporation. The passenger

Sample 3: '''San Francisco''', named after the San Gabriel [[mariage]] and [[Irish Catholic]] [[eschatology]] founded in 1863 by San Gabriel (Redfern) was named in honor of ''Cestion San Francisco'' (a term which the reputed early mariage was held up by [[Native Ame

Sample 4: '''San Francisco''' (born '''Mark Antonio Baldwin''' [[September 2]], [[1945]]) is a [[Canada|Canadian]] [[Public house|pub]] owner and legend of [[Uburban Culture]] [[Public house|pubs]].

Baldwin started his own business in [[1963]] when he left to sett

Prompt: '''George Walker Bush'''

Sample 1: '''George Walker Bush''' (born [[July 13]], [[1954]]) is an [[United States|American]] [[physicist]] and [[Nobel Prize]] winner. He was born in [[Albany, New York]].

Born '''George Lauder-Freiherr Bush''' (born [[March 10]], [[1957]]) he became a member o

Sample 2: '''George Walker Bush''' (born [[July 7]], [[1961]]) is an American [[philanthropist]] who at one time secured a record of 3 works before attending the [[Carnegie Institute of Technology]] and became a full-time journalist in [[1994]].

Born in [[Frankfor

Sample 3: '''George Walker Bush''', [[United States Republican Party|Republican]] ([[Democratic Party (United States)|Democrat]])
* '''George Mills''', [[United States Democratic Party|Democrat]] ([[Democratic Party (United States)|Democrat]])
* '''[[Anthony Burrows

It doesn’t seem to know what Star Trek or San Francisco is, or who George W. Bush is, but it does associate Star Trek with nerdy entertainment, television, and warp drive. Similarly, it associates SF with SFO, Stanford, and California. It seems to know, at least sometimes, that George W. Bush is associated with US politics. And it’s learned what the ledes of biographies look like.
