1,7-szeres gyorsulást hoz a Lighthouse Attention a hosszú kontextusú LLM-ek betanításában
A Lighthouse Attention a betanítási időt csökkenti, miközben a modellek végső betanítási veszteségét változatlanul hagyja, vagy akár javítja is.

A nagyméretű nyelvi modellek (LLM) hosszú szekvenciákon történő betanítása jelentős kihívást jelent, mivel a transzformátorok alapját képező skálázott dot-product figyelem (SDPA) számítási és memóriaigénye is kvadratikusan növekszik a szekvenciahosszal. Ezt a szűk keresztmetszetet oldja meg a Nous Research kutatócsoportjának új módszere, a Lighthouse Attention, amely 1,40-szerestől 1,69-szeresig terjedő végpontok közötti sebességnövekedést biztosít a betanítás során egy cuDNN-alapú SDPA referenciamodellhez képest — írja a MarkTechPost.
A Lighthouse Attention a betanítási időszakban aktív, utána eltávolítható, így az inferencia során a modell továbbra is sűrű figyelemmel működik. A Nous Research szerint a módszer kezeli azt a kritikus kérdést, hogy a betanítás után az eredményül kapott súlyok továbbra is kompetens sűrű figyelemmodellt fognak-e produkálni az inferencia során.
A figyelem fókuszában
A legtöbb korábbi ritka figyelemmechanizmus, mint a NSA, HISA, DSA vagy MoBA, aszimmetrikusan tömöríti a kulcs- és értékoldalt, miközben a lekérdezéseket teljes felbontásban hagyja. Emellett a szelekciós logikájuk egyedi figyelem kernelen belül helyezkedik el, ami megakadályozza az optimalizált sűrű figyelemkernelek újrahasznosítását.
A Lighthouse Attention ezzel szemben szimmetrikusan kezeli a lekérdezéseket, kulcsokat és értékeket egy többszintű piramisban. A szelekció teljes egészében a figyelem kernelen kívül történik, ami lehetővé teszi a meglévő, optimalizált FlashAttention kernelek használatát. A kiválasztott bejegyzéseket egy összefüggő, sűrű al-szekvenciába gyűjti a rendszer, majd ezen futtatja a standard FlashAttentiont.
Sebesség és hatékonyság
A Lighthouse figyelemréteg négy szakaszból álló folyamatban működik. Először átlagoló poololással épít fel egy L-szintű piramist a Q, K és V értékekből. Második lépésben egy paramétermentes pontozó két skalár pontszámot rendel minden piramisbejegyzéshez, fej-specifikus ℓ₂ normák alapján. A harmadik szakaszban a kiválasztott bejegyzéseket egy összefüggő al-szekvenciába gyűjtik, majd ezen futtatják a FlashAttentiont. Végül a negyedik szakaszban minden kimeneti bejegyzést visszaszórnak az általa reprezentált alap pozíciókba egy determinisztikus, egész szám-atomikus szórási kernellel.
Ez a szimmetrikus poololási megközelítés O(N Sd) helyett O(S² d) számítási komplexitást eredményez a betanítás során. Mivel hosszú kontextusok esetén S jóval kisebb, mint N, ez adja a késleltetési előnyt. Egyetlen NVIDIA B200 GPU-n, 512K kontextussal (bfloat16, B=1, H=8, fejdimenzió 128, L=3, p=4, ritkaság ≈ 1:64) a Lighthouse 21-szer gyorsabb az előremenő és 17,3-szer gyorsabb az előre- és hátramenő passz kombinációjában a cuDNN-alapú SDPA-hoz képest. A kutatás az arXiv:2605.06554 számon jelent meg 2023. május 12-én.