The need for data parallel attention emerged with DeepSeek's MLA attention mechanism.
Because MLA compresses keys and values into a single low-rank latent KV cache shared by all heads, that cache cannot be sharded along the head dimension the way tensor parallelism normally splits KV caches. Every TP rank therefore has to keep a full copy of it, which wastes memory.
To overcome this problem, SGLang implemented data parallelism for MLA attention.
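To get a feel for why the replication hurts, here is a rough back-of-the-envelope sketch. The latent and RoPE dimensions below follow the published DeepSeek-V2/V3 MLA sizes, but the token count and TP degree are made up purely for illustration:

```python
# Rough memory sketch (illustrative numbers, not exact DeepSeek figures).
# MLA caches one compressed latent KV vector per token instead of per-head K/V.
kv_lora_rank = 512     # compressed KV dimension used by DeepSeek-V2/V3
rope_head_dim = 64     # decoupled RoPE key dimension
bytes_per_elem = 2     # fp16 / bf16
tokens = 1_000_000     # total cached tokens across all running requests
tp_size = 8

per_token_bytes = (kv_lora_rank + rope_head_dim) * bytes_per_elem

# Tensor parallelism: the latent cache cannot be split across heads,
# so every TP rank stores a full copy.
tp_total = per_token_bytes * tokens * tp_size

# DP attention: each DP rank owns a disjoint slice of the requests,
# so each token's cache entry is stored exactly once overall.
dp_total = per_token_bytes * tokens

print(f"TP-replicated cache: {tp_total / 1e9:.2f} GB")
print(f"DP attention cache:  {dp_total / 1e9:.2f} GB")
```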
flowchart TB
Client[Client] --> HTTP[HTTP Server '/generate']
HTTP -->|GenerateReqInput| TMgr[TokenizerManager]
subgraph "Engine"
TMgr -->|TokenizedGenerateReqInput| DPC[Data Parallel Controller]
DPC -->|Round-robin dispatch| Sched[Scheduler]
subgraph "TP/DP Workers"
Sched -->|Forward batch| Model[Model Inference]
Model -->|Generation result| Sched
end
Sched -->|BatchTokenIDOut| DTok[Detokenizer]
DTok -->|BatchStrOut| TMgr
end
TMgr -->|Response| HTTP
HTTP -->|Text response| Client
To follow what's going on with data-parallel attention in SGLang, let's trace the code flow of DP attention and its scheduling.
Let's start with the entry points of SGLang. SGLang conveniently keeps them in one place under python/sglang/srt/entrypoints/*. There are two implementations here: http_server.py and engine.py.
http_server.py is, as the name suggests, an HTTP server that routes requests to the engine; it is built with FastAPI.
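For orientation, here is a stripped-down sketch of what such a FastAPI route looks like. This is not SGLang's actual code: `handle_generate` is a hypothetical stand-in for the call into the TokenizerManager, and the real handler also deals with streaming, validation, and many more endpoints.

```python
from fastapi import FastAPI, Request

app = FastAPI()

async def handle_generate(payload: dict) -> dict:
    # Hypothetical stand-in for the TokenizerManager call in the real server.
    return {"text": "..."}

@app.post("/generate")
async def generate(request: Request):
    payload = await request.json()   # becomes a GenerateReqInput in the real server
    return await handle_generate(payload)
```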
engine.py contains the top-level implementation of the inference engine. The Engine class in engine.py has the following description:
class Engine:
    """
    The entry point to the inference engine.

    - The engine consists of three components:
        1. TokenizerManager: Tokenizes the requests and sends them to the scheduler.
        2. Scheduler (subprocess): Receives requests from the Tokenizer Manager, schedules batches, forwards them, and sends the output tokens to the Detokenizer Manager.
        3. DetokenizerManager (subprocess): Detokenizes the output tokens and sends the result back to the Tokenizer Manager.

    Note:
    1. The HTTP server, Engine, and TokenizerManager both run in the main process.
    2. Inter-process communication is done through ICP (each process uses a different port) via the ZMQ library.
    """
The engine parses the kwargs, then directly launches subprocesses on initialization.
class Engine:
    def __init__(self, **kwargs):
        server_args = ServerArgs(**kwargs)

        # Launch subprocesses
        tokenizer_manager, scheduler_info = _launch_subprocesses(
            server_args=server_args
        )
        self.tokenizer_manager = tokenizer_manager
        self.scheduler_info = scheduler_info
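For reference, constructing the engine for offline use looks roughly like this; the kwargs map onto ServerArgs fields, and a dp_size greater than 1 (together with the DP-attention flag) is what sends us down the data parallel path below. Treat the exact model and values as illustrative:

```python
import sglang as sgl

# kwargs here become ServerArgs fields; enable_dp_attention mirrors the
# --enable-dp-attention server argument (flag name assumed from ServerArgs).
llm = sgl.Engine(
    model_path="deepseek-ai/DeepSeek-V2-Lite",
    tp_size=2,
    dp_size=2,
    enable_dp_attention=True,
)
outputs = llm.generate(["Hello, my name is"], {"max_new_tokens": 16})
print(outputs)
llm.shutdown()
```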
def _launch_subprocesses(server_args: ServerArgs) -> Tuple[TokenizerManager, Dict]:
    """
    Launch the TokenizerManager in the main process, the Scheduler in a subprocess, and the DetokenizerManager in another subprocess.
    """
Note that this function is the first place where run_data_parallel_controller_process
is called.
    if server_args.dp_size == 1:
        # Launch tensor parallel scheduler processes
        ...
    else:
        # Launch the data parallel controller
        reader, writer = mp.Pipe(duplex=False)
        scheduler_pipe_readers = [reader]
        proc = mp.Process(
            target=run_data_parallel_controller_process,
            args=(server_args, port_args, writer),
        )
        proc.start()
        scheduler_procs.append(proc)
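Stripped of the SGLang specifics, the launch above follows a simple handshake pattern: the parent creates a one-way pipe, the controller subprocess reports back through the writer end once its schedulers are up, and the parent blocks on the reader end for that status. A minimal sketch of the pattern, with illustrative names rather than SGLang's own:

```python
import multiprocessing as mp

def controller_main(writer):
    # ... a real controller would set up one scheduler per DP rank here ...
    writer.send({"status": "ready"})  # report readiness back to the parent

if __name__ == "__main__":
    reader, writer = mp.Pipe(duplex=False)
    proc = mp.Process(target=controller_main, args=(writer,))
    proc.start()
    info = reader.recv()  # plays the role of scheduler_info in _launch_subprocesses
    print(info)
    proc.join()
```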