Par Jean-Eric Campagne.
JAX est une bibliothèque Python conçue pour l'apprentissage automatique haute performance, mêlant NumPy, Autograd (différenciation automatique) et XLA (Accelerated Linear Algebra).
Il existe bon nombre de tutoriels pour se mettre à JAX. Le parti pris de celui-ci est de faire un apprentissage progressif à travers des exemples tirés de la pratique de l'orateur durant les 3 dernières années. Nous verrons notamment des codes qui "crash"ent pour savoir comment y remédier.
Les notebooks sont jouables sur Colab avec une version à jour de JAX 0.4.24 (ou plus), d'abord sur CPU puis dans certain cas nous utiliserons 1 GPU, voir plusieurs sur Jean-Zay. Les notebooks sont disponibles sur un dépôt Github.
Nb : ce n'est pas un cours de Machine Learning.