FlashAttention hecho a mano: una guía práctica y en español para entender cómo acelerar la autoatención exacta con conciencia de memoria
Introducción. FlashAttention parte de la idea de softmax en línea para crear un algoritmo de pasada única que calcula directamente O = softmax(QK^T)V sin materializar la matriz de atención A de tamaño L×L. La clave es fusionar en un mismo kernel de GPU los productos de matrices QK^T, la aplicación de softmax en línea y la acumulación con V. Al evitar leer y escribir A en memoria global (HBM/DRAM), se reduce drásticamente el tráfico de memoria, que es el verdadero cuello de botella, consiguiendo más velocidad y menor uso de memoria, sobre todo en secuencias largas.
Autoatención clásica sin FlashAttention. El flujo estándar es: 1) Logits X = QK^T; 2) A = softmax(X); 3) O = A·V. Ejemplo 1×6: supongamos X = [1, 2, 3, 6, 2, 1] y V con 6 filas y dimensión 2, donde v1=[1,1], v2=[2,2], …, v6=[6,6]. Con el softmax estable restamos el máximo m=6: X-m = [-5, -4, -3, 0, -4, -5]; exponenciamos e^(X-m) ˜ [0.0067, 0.0183, 0.0498, 1, 0.0183, 0.0067], sumamos d ˜ 1.0998 y normalizamos A ˜ [0.0061, 0.0167, 0.0453, 0.9092, 0.0167, 0.0061]. Finalmente, O = A·V ˜ [3.932, 3.932]. El problema: hay que construir A completo antes de multiplicar por V, lo que implica costosas escrituras y lecturas a HBM.
Visión conceptual de FlashAttention. El algoritmo procesa Q, K y V por bloques o teselas en SRAM, manteniendo tres estadísticas por fila de salida: m_running (máximo acumulado), d_running (denominador acumulado del softmax) y o_running (salida acumulada no normalizada). Para cada bloque i: 1) se calcula x_i = q·K_i^T, 2) se obtiene el máximo local y se actualiza m_running, 3) se computa el aporte local del softmax en línea y se reescala la contribución anterior con e^(m_old - m_new), 4) se actualizan d_running y o_running de inmediato con V_i. Así, no se materializa A y se evita el viaje completo a HBM.
FlashAttention a mano con un ejemplo. Usamos el mismo X y V anteriores y tamaño de tesela 3. Tesela T1: logits [1, 2, 3], máximo local 3. P1 = e^[1-3, 2-3, 3-3] ˜ [0.1353, 0.3679, 1], d_T1 ˜ 1.5032 y o_T1 ˜ 0.1353·[1,1] + 0.3679·[2,2] + 1·[3,3] = [3.8711, 3.8711]. Inicialmente m_running = -8, d_running = 0, o_running = [0,0]. Tras T1: m_running = 3, d_running = 1.5032, o_running = [3.8711, 3.8711]. Tesela T2: logits [6, 2, 1], máximo local 6 y m_new = 6. Reescalamos con e^(3-6) ˜ 0.04979: d_running = 1.5032·0.04979 + 1.025 ˜ 1.0998; o_running = [3.8711, 3.8711]·0.04979 + [4,4] + 0.0183·[5,5] + 0.0067·[6,6] ˜ [4.3244, 4.3244]. Normalización final: O = o_running/d_running ˜ [4.3244, 4.3244]/1.0998 ˜ [3.932, 3.932]. El mismo resultado que el método clásico, pero sin A.
Diagrama y mapeo mental. Piense en dos bucles anidados: el interno recorre filas de Q (una por una en nuestro ejemplo) y el externo recorre bloques de K y V. En cada iteración del bucle externo, se copian a SRAM solo los bloques necesarios desde HBM, se calcula el aporte local con reescalado y se actualiza o_running y d_running en SRAM. Al final, se escribe la salida normalizada a HBM. El beneficio es IO-aware: mover lo mínimo, calcular lo máximo en SRAM.
Apéndice A. De pesos a activaciones: W_Q, W_K, W_V son matrices de pesos entrenables que proyectan los embeddings X a las activaciones Q, K y V: Q = X·W_Q, K = X·W_K, V = X·W_V. Estas activaciones cambian por secuencia y son la entrada real del kernel de atención que FlashAttention reemplaza. En el ejemplo, la primera fila q de Q genera los logits q·K^T = [1, 2, 3, 6, 2, 1] y V es la matriz de valores usada en la acumulación.
Apéndice B. Equivalencia entre normalizar en cada paso y normalizar al final. El pseudocódigo original mantiene una salida normalizada o_prime en cada iteración i usando d_prime como denominador acumulado. La versión manual mantiene o_running y d_running sin normalizar, y divide una sola vez al final. Son equivalentes: multiplicar las fórmulas del paper por d_prime_i muestra que o_running_i = o_prime_i·d_prime_i y que la actualización por bloques con el reescalado e^(m_{i-1} - m_i) preserva la razón correcta en cada paso. Por inducción, o_prime_final = O_final/d_prime_final, por lo que normalizar al final produce exactamente la misma salida.
Por qué funciona mejor. FlashAttention optimiza el patrón de acceso a memoria en GPU: 1) fusiona operaciones en un único kernel, 2) reduce IO con teselado en SRAM, 3) garantiza estabilidad numérica con el máximo en línea y reescalados, 4) evita crear A completa. El resultado práctico: menos latencia, menor uso de memoria y más throughput, especialmente en L grandes y con cabezas múltiples.
Casos de uso y buenas prácticas. 1) Use tamaños de bloque que llenen eficientemente la SRAM sin provocar derrames a HBM. 2) Mantenga el softmax en línea con reescalado por máximo para estabilidad. 3) Integre QK^T, softmax y AV en un kernel para minimizar sincronizaciones. 4) Verifique la equivalencia numérica con la ruta clásica en conjuntos de prueba antes de poner en producción.
Sobre Q2BSTUDIO. Somos una empresa de desarrollo de software con foco en aplicaciones a medida y software a medida, especialistas en inteligencia artificial, ciberseguridad, servicios cloud aws y azure, servicios inteligencia de negocio y power bi, automatización de procesos y agentes IA. Ayudamos a implantar IA para empresas, desde casos de uso con modelos generativos hasta integración MLOps en cloud y gobierno del dato. Si buscas acelerar tus soluciones de NLP, RAG o LLMs con técnicas como FlashAttention, podemos diseñar e implementar el stack óptimo en producción.
Si quieres explorar cómo la inteligencia artificial puede transformar tus productos digitales, descubre nuestros servicios de inteligencia artificial e IA para empresas. Y si necesitas una plataforma robusta y escalable, también podemos ayudarte con servicios cloud AWS y Azure para desplegar y monitorizar tus cargas de trabajo de datos y modelos a gran escala.
Palabras clave para orientar tu estrategia: aplicaciones a medida, software a medida, inteligencia artificial, ciberseguridad, servicios cloud aws y azure, servicios inteligencia de negocio, power bi, ia para empresas, agentes IA, automatización de procesos, FlashAttention, softmax en línea, atención exacta, GPU kernel fusion, optimización IO-aware.