Is your feature request related to a problem? Please describe.
Yes, the current implementations of the DilatedAttention and FlashAttention modules in the Zeta repository do not support multi-GPU configurations effectively; in particular, they lack both model parallelism and data parallelism. FlashAttention is optimized for A100 GPUs, but my machine has 8 A10 GPUs, and I would like to use all of them efficiently. This limitation restricts the scalability and speed of my deep learning tasks, especially large-scale sequence processing and attention mechanisms.
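For context, here is a minimal sketch of how the available hardware can be enumerated before choosing an attention backend. These are all standard PyTorch calls, nothing Zeta-specific; the compute-capability values in the comment are from NVIDIA's published specs.

```python
# Sketch: list the visible GPUs and their compute capabilities.
import torch

if torch.cuda.is_available():
    for i in range(torch.cuda.device_count()):
        name = torch.cuda.get_device_name(i)
        major, minor = torch.cuda.get_device_capability(i)
        # A10s report compute capability 8.6; A100s report 8.0.
        print(f"cuda:{i}: {name} (sm_{major}{minor})")
```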
Describe the solution you'd like
I propose enhancing the DilatedAttention and FlashAttention classes to support both model parallelism and data parallelism. At a minimum, this update should include (a rough sketch of the data-parallel usage follows this list):

- Data parallelism: compatibility with torch.nn.parallel.DistributedDataParallel, so each GPU processes its own shard of the batch.
- Model parallelism: an option to shard attention heads or sequence segments across devices for models too large for a single GPU.
- A tuned fallback path so FlashAttention performs well on non-A100 hardware such as the A10.
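As a rough illustration of the data-parallel half of the request, the sketch below wraps a Zeta attention module in PyTorch's DistributedDataParallel across 8 GPUs. The DilatedAttention import path and constructor arguments shown here are assumptions based on the class name, not a confirmed signature; only the torch.distributed wiring is standard PyTorch.

```python
# Minimal data-parallel sketch. The DilatedAttention import path and
# constructor arguments are assumptions about the Zeta API, not a confirmed
# signature; the DDP wiring itself is standard PyTorch.
# Launch with: torchrun --nproc_per_node=8 this_script.py
import os

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

from zeta.nn.attention import DilatedAttention  # assumed import path


def main() -> None:
    dist.init_process_group(backend="nccl")  # NCCL for GPU-to-GPU comms
    local_rank = int(os.environ["LOCAL_RANK"])  # set by torchrun
    torch.cuda.set_device(local_rank)

    # Hypothetical constructor arguments; adjust to the real signature.
    attn = DilatedAttention(dim=512, heads=8, dilation_rate=2, segment_size=64)
    attn = DDP(attn.cuda(local_rank), device_ids=[local_rank])

    # Each rank processes its own shard of the batch.
    x = torch.randn(4, 1024, 512, device=f"cuda:{local_rank}")
    out = attn(x)
    print(f"rank {dist.get_rank()}: output shape {tuple(out.shape)}")

    dist.destroy_process_group()


if __name__ == "__main__":
    main()
```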
Describe alternatives you've considered
An alternative would be manually partitioning the work and managing CUDA devices at the application level, but that approach is less efficient and harder to scale (a sketch of what it looks like follows). Building directly on NVIDIA's NCCL for inter-GPU communication could also be considered if native support in the framework proves too complex to implement in the initial stages.
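For reference, the manual alternative looks roughly like the following: pinning two attention blocks to different devices and moving activations between them by hand. The AttentionBlock here is a stand-in module built on PyTorch's MultiheadAttention, not part of Zeta; the point is only that all the device bookkeeping ends up in application code.

```python
# Sketch of the manual model-parallel alternative: the application code owns
# all device placement and transfers. AttentionBlock is a stand-in, not Zeta.
import torch
import torch.nn as nn


class AttentionBlock(nn.Module):
    def __init__(self, dim: int, heads: int) -> None:
        super().__init__()
        self.attn = nn.MultiheadAttention(dim, heads, batch_first=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out, _ = self.attn(x, x, x)
        return out


# Pin each block to its own GPU and shuttle activations manually.
block0 = AttentionBlock(512, 8).to("cuda:0")
block1 = AttentionBlock(512, 8).to("cuda:1")

x = torch.randn(4, 1024, 512, device="cuda:0")
h = block0(x)
out = block1(h.to("cuda:1"))  # explicit transfer: easy to get wrong at scale
print(out.shape)
```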