Scaling Deep Learning Models: Parallelism Strategies in JAX with Flax
Description
En esta sesión de 30 minutos en PyCon Colombia, exploraremos estrategias clave de paralelismo en JAX con Flax para el entrenamiento eficiente de modelos de aprendizaje profundo a gran escala. Comenzaremos examinando la necesidad y la importancia del escalado en el aprendizaje profundo, así como las ventajas del paralelismo para aprovechar eficazmente el hardware disponible. Luego, nos sumergiremos en las técnicas de paralelismo, desde la optimización en una sola GPU hasta la escalabilidad a múltiples GPU o TPU. A lo largo de la charla, presentaré ejemplos prácticos y demostraciones en vivo, guiando a los asistentes a través de la implementación de estrategias de paralelismo utilizando shard_map en JAX con Flax. Desde la comprensión de los conceptos básicos hasta la aplicación práctica en casos reales, los participantes aprenderán a optimizar el entrenamiento de modelos de aprendizaje profundo de manera eficiente y efectiva. Esta charla está diseñada para ser accesible para una amplia audiencia, desde aquellos nuevos en el campo hasta expertos en aprendizaje profundo. Al final de la sesión, los asistentes habrán adquirido un conocimiento práctico que podrán aplicar en sus propios proyectos de forma inmediata.