compute max_new_tokens for dcp to support bigger batch size #19
base: yjh/dcp-dev
Conversation
Signed-off-by: augusto.yjh <[email protected]>
Reviewer's Guide

This PR introduces a helper to compute per-rank max_new_tokens for DCP and wires it into the scheduling policy so token budgets, prefill budgets, and preemption calculations are based on the distributed world size, enabling larger effective batch sizes per node.

Updated class diagram for DCP-aware scheduling policy

```mermaid
classDiagram
class SchedulePolicy {
float new_token_ratio
int rem_total_tokens
int rem_chunk_tokens
int _get_running_request_total_token_offset(req: Req)
void add_chunked_req(req: Req)
void add_req_state(r: Req, insert_sort: bool)
void add_one_req(req: Req, has_chunked_req: bool)
bool preempt_to_schedule(req: Req, server_args: ServerArgs)
}
class Req {
SamplingParams sampling_params
list[int] output_ids
int extend_input_len
list[int] origin_input_ids
}
class SamplingParams {
int max_new_tokens
bool ignore_eos
}
class ServerArgs {
}
%% ...
class DCPParallelState {
int get_dcp_world_size()
}
class ScheduleHelpers {
int compute_dcp_local_max_new_tokens(tokens: int)
}
SchedulePolicy --> Req : uses
Req --> SamplingParams : has
SchedulePolicy --> ServerArgs : uses
SchedulePolicy ..> ScheduleHelpers : calls
ScheduleHelpers ..> DCPParallelState : calls
```

Flow diagram for DCP-local max_new_tokens computation in scheduling

```mermaid
flowchart TD
A["Receive request with global max_new_tokens from SamplingParams"]
B["Call compute_dcp_local_max_new_tokens(max_new_tokens)"]
C["Inside compute_dcp_local_max_new_tokens: world_size = get_dcp_world_size()"]
D["Compute local_max = (tokens + world_size - 1) // world_size"]
E["Return local_max to scheduling policy"]
F["Clip local_max with CLIP_MAX_NEW_TOKENS"]
G["Use clipped local_max in token budget, prefill budget, and preemption calculations"]
A --> B
B --> C
C --> D
D --> E
E --> F
F --> G
```
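For reference, here is a minimal, self-contained sketch of the computation the diagrams describe. The stubbed get_dcp_world_size and the CLIP_MAX_NEW_TOKENS value are illustrative stand-ins; in the PR, the helper lives in schedule_policy.py and imports the real get_dcp_world_size from sglang.srt.distributed.parallel_state.

```python
def get_dcp_world_size() -> int:
    # Stub for sglang.srt.distributed.parallel_state.get_dcp_world_size;
    # the value below is an illustrative stand-in for the DCP rank count.
    return 4


CLIP_MAX_NEW_TOKENS = 4096  # illustrative; the real constant is defined in schedule_policy.py


def compute_dcp_local_max_new_tokens(tokens: int) -> int:
    # Ceiling division: split the global new-token budget across DCP ranks.
    world_size = get_dcp_world_size()
    return (tokens + world_size - 1) // world_size


# As in the flow above: compute the per-rank budget, then clip it.
local_max = min(compute_dcp_local_max_new_tokens(100), CLIP_MAX_NEW_TOKENS)
print(local_max)  # ceil(100 / 4) = 25
```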
Summary of Changes

Hello @staugust, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request enhances the SGLang scheduling policy to better support distributed context parallelism (DCP) by adjusting how per-request token budgets are computed across DCP ranks.
Hey there - I've reviewed your changes - here's some feedback:
- Consider clamping non-positive values in compute_dcp_local_max_new_tokens (e.g., max(tokens, 0)) before applying the division to avoid surprising negative results when max_new_tokens - len(output_ids) is negative in edge cases.
- compute_dcp_local_max_new_tokens is called multiple times in hot scheduling paths; you might want to cache the DCP world size (or the local max_new_tokens) instead of recomputing it on every call to reduce overhead (a possible caching approach is sketched after this list).
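One way to realize the caching the second point suggests — a sketch only, assuming the DCP world size is fixed once distributed initialization has completed; _dcp_world_size_cached is a hypothetical helper name:

```python
from functools import lru_cache

from sglang.srt.distributed.parallel_state import get_dcp_world_size


@lru_cache(maxsize=1)
def _dcp_world_size_cached() -> int:
    # Safe to memoize only if the world size never changes after init.
    return get_dcp_world_size()


def compute_dcp_local_max_new_tokens(tokens: int) -> int:
    world_size = _dcp_world_size_cached()
    return (tokens + world_size - 1) // world_size
```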
Prompt for AI Agents
Please address the comments from this code review:
## Overall Comments
- Consider clamping non-positive values in compute_dcp_local_max_new_tokens (e.g., max(tokens, 0)) before applying the division to avoid surprising negative results when max_new_tokens - len(output_ids) is negative in edge cases.
- compute_dcp_local_max_new_tokens is called multiple times in hot scheduling paths; you might want to cache the DCP world size (or the local max_new_tokens) instead of recomputing it on every call to reduce overhead.
## Individual Comments
### Comment 1
<location> `python/sglang/srt/managers/schedule_policy.py:34-35` </location>
<code_context>
from sglang.srt.server_args import ServerArgs
+from sglang.srt.distributed.parallel_state import get_dcp_world_size
+
+def compute_dcp_local_max_new_tokens(tokens: int):
+ return (tokens + get_dcp_world_size() -1) // get_dcp_world_size()
+
</code_context>
<issue_to_address>
**issue (bug_risk):** Guard against negative token counts passed into compute_dcp_local_max_new_tokens.
Some callers can pass a negative value here (e.g., `req.sampling_params.max_new_tokens - len(req.output_ids)` once `len(req.output_ids)` exceeds `max_new_tokens`). That makes `(tokens + world_size - 1) // world_size` negative and propagates bad values into scheduling/offset calculations. Consider clamping `tokens` to a minimum of 0 at the start of this helper (e.g., `tokens = max(tokens, 0)`) to avoid treating over-generated requests as having negative capacity.
</issue_to_address>
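A sketch of the guarded helper this comment proposes (the clamp is the reviewer's suggestion, not code already in the PR):

```python
from sglang.srt.distributed.parallel_state import get_dcp_world_size


def compute_dcp_local_max_new_tokens(tokens: int) -> int:
    # Clamp first: an over-generated request should read as zero remaining
    # capacity, not as a negative per-rank budget.
    tokens = max(tokens, 0)
    world_size = get_dcp_world_size()
    return (tokens + world_size - 1) // world_size
```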
### Comment 2
<location> `python/sglang/srt/managers/schedule_policy.py:452-454` </location>
<code_context>
min(compute_dcp_local_max_new_tokens(req.sampling_params.max_new_tokens), CLIP_MAX_NEW_TOKENS)
if not truncated
else 0
</code_context>
<issue_to_address>
**suggestion (code-quality):** Swap if/else branches of if expression to remove negation ([`swap-if-expression`](https://docs.sourcery.ai/Reference/Rules-and-In-Line-Suggestions/Python/Default-Rules/swap-if-expression))
```suggestion
0 if truncated else min(compute_dcp_local_max_new_tokens(req.sampling_params.max_new_tokens), CLIP_MAX_NEW_TOKENS)
```
Explanation: Negated conditions are more difficult to read than positive ones, so it is best to avoid them where we can. By swapping the if and else conditions around we can invert the condition and make it positive.
</issue_to_address>
Code Review
This pull request introduces a helper function to compute per-rank max_new_tokens for Distributed Context Parallelism (DCP) and applies this logic to the token budgeting in the scheduler. The overall approach is sound and the changes are mostly correct. However, I've identified a critical bug in the calculation of tokens_left within the add_req_state method, which could lead to incorrect memory estimations. I have also provided a suggestion to improve the maintainability of the new helper function. Addressing these points will ensure the stability and correctness of the token budgeting logic.
In add_req_state:

```python
tokens_left = compute_dcp_local_max_new_tokens(r.sampling_params.max_new_tokens) * new_token_ratio - len(
    r.output_ids
)
```
The calculation for tokens_left appears to be incorrect. The original logic is budgeted_tokens - generated_tokens. With DCP, this should be local_budgeted_tokens - local_generated_tokens, or more accurately ceil((budgeted_tokens - generated_tokens) / dcp_world_size). The current implementation, ceil(budgeted_tokens / dcp_ws) * ratio - global_generated_tokens, mixes local and global token counts, which will lead to incorrect memory estimation. This could cause out-of-memory errors or underutilization of resources.
Suggested change:

```python
tokens_left = compute_dcp_local_max_new_tokens(
    int(r.sampling_params.max_new_tokens * new_token_ratio) - len(r.output_ids)
)
```
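To make the mix-up concrete, a worked example with illustrative numbers (global budget 100, 4 DCP ranks, new_token_ratio 1.0, 40 tokens already generated):

```python
def ceil_div(a: int, b: int) -> int:
    return (a + b - 1) // b


max_new_tokens, dcp_ws, new_token_ratio, generated = 100, 4, 1.0, 40

# As merged: a per-rank budget minus a *global* generated count.
mixed = ceil_div(max_new_tokens, dcp_ws) * new_token_ratio - generated
print(mixed)  # 25.0 - 40 = -15.0, a spurious negative budget

# As suggested: subtract in global terms first, then convert to per-rank.
local = ceil_div(int(max_new_tokens * new_token_ratio) - generated, dcp_ws)
print(local)  # ceil(60 / 4) = 15
```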
On the helper itself:

```python
from sglang.srt.distributed.parallel_state import get_dcp_world_size


def compute_dcp_local_max_new_tokens(tokens: int):
    return (tokens + get_dcp_world_size() -1) // get_dcp_world_size()
```
For better readability and to avoid potential side effects if get_dcp_world_size() were to become more complex in the future, it's good practice to call it only once and store its result in a local variable.
Suggested change:

```python
dcp_world_size = get_dcp_world_size()
return (tokens + dcp_world_size - 1) // dcp_world_size
```
With tp8 DeepSeek-V3.1, the maximum batch size was 19, while DCP can reach 64; no stress test with a higher max_concurrency was run. When the batch size was pushed to 64, QPS improved noticeably, and overall throughput improved by (0.77 / 0.47 - 1) = 63.83%.
Launch command:
Motivation
Modifications
Accuracy Tests
Benchmarking and Profiling
Checklist
Summary by Sourcery
Enhancements: