-
Notifications
You must be signed in to change notification settings - Fork 663
[Feature] support stop_token_ids #5399
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
lizexu123
wants to merge
6
commits into
PaddlePaddle:develop
Choose a base branch
from
lizexu123:support_token_ids_3
base: develop
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from 4 commits
Commits
Show all changes
6 commits
Select commit
Hold shift + click to select a range
4d054ce
support stop_token_ids
lizexu123 a66ecf0
fix
lizexu123 5810bb6
delete chinese
lizexu123 48e74d6
Merge branch 'develop' of https://github.com/PaddlePaddle/FastDeploy …
lizexu123 5a213e1
support both
lizexu123 93e3eaf
delete print
lizexu123 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -37,59 +37,67 @@ __global__ void set_value_by_flags(bool *stop_flags, | |
| const int *stop_seqs_len, | ||
| const int stop_seqs_bs, | ||
| const int stop_seqs_max_len, | ||
| const int64_t *min_tokens, | ||
| bool beam_search, | ||
| bool prefill_one_step_stop) { | ||
| int tid = threadIdx.x; | ||
| int bid = blockIdx.x; | ||
| if (tid >= stop_seqs_bs) return; | ||
| if (bid < bs) { | ||
| if(tid == 0){ | ||
| if (prefill_one_step_stop) { | ||
| stop_flags[bid] = true; | ||
| if (seq_lens[bid] == 0) { | ||
| topk_ids[bid] = -1; | ||
| } | ||
| next_tokens[bid] = topk_ids[bid]; | ||
| } else { | ||
| if (stop_flags[bid]) { | ||
| if (seq_lens[bid] == 0) { | ||
| topk_ids[bid] = -1; | ||
| } else { | ||
| topk_ids[bid] = end_ids[0]; | ||
| next_tokens[bid] = end_ids[0]; | ||
| } | ||
| } else { | ||
| next_tokens[bid] = topk_ids[bid]; | ||
| } | ||
| } | ||
| if (!beam_search && is_in_end(topk_ids[bid], end_ids, end_length)) { | ||
| stop_flags[bid] = true; | ||
| topk_ids[bid] = end_ids[0]; | ||
| next_tokens[bid] = end_ids[0]; | ||
| } | ||
| int tid = threadIdx.x; | ||
| int bid = blockIdx.x; | ||
| if (tid >= stop_seqs_bs) return; | ||
| if (bid < bs) { | ||
| const int64_t current_step = step_idx[bid]; | ||
| const int64_t min_token_limit = min_tokens[bid]; | ||
| const bool can_stop = (current_step >= min_token_limit); | ||
| if (tid == 0) { | ||
| if (prefill_one_step_stop) { | ||
| stop_flags[bid] = true; | ||
| if (seq_lens[bid] == 0) { | ||
| topk_ids[bid] = -1; | ||
| } | ||
| // dealing stop_seqs | ||
| const int stop_seq_len = (stop_seqs_len + bid * stop_seqs_bs)[tid]; | ||
| if (stop_seq_len <= 0) return; | ||
| const int64_t *stop_seq_now = stop_seqs + bid * stop_seqs_bs + tid * stop_seqs_max_len; | ||
| const int64_t *pre_ids_now = pre_ids + bid * pre_ids_len; | ||
| const int64_t step_idx_now = step_idx[bid]; | ||
|
|
||
| bool is_end = true; | ||
| int count = 1; | ||
| for (int i = stop_seq_len - 1; i >= 0; --i) { | ||
| if ((step_idx_now - count) < 0 || | ||
| pre_ids_now[step_idx_now - count++] != stop_seq_now[i]) { | ||
| is_end = false; | ||
| break; | ||
| } | ||
| } | ||
| if (is_end) { | ||
| next_tokens[bid] = end_ids[0]; | ||
| stop_flags[bid] = true; | ||
| next_tokens[bid] = topk_ids[bid]; | ||
| } else { | ||
| if (stop_flags[bid]) { | ||
| if (seq_lens[bid] == 0) { | ||
| topk_ids[bid] = -1; | ||
| } else { | ||
| topk_ids[bid] = end_ids[0]; | ||
| next_tokens[bid] = end_ids[0]; | ||
| } | ||
| } else { | ||
| next_tokens[bid] = topk_ids[bid]; | ||
| } | ||
| } | ||
| if (!beam_search && can_stop && | ||
| is_in_end(topk_ids[bid], end_ids, end_length)) { | ||
| stop_flags[bid] = true; | ||
| topk_ids[bid] = end_ids[0]; | ||
| next_tokens[bid] = end_ids[0]; | ||
| } | ||
| } | ||
|
|
||
| if (!can_stop) return; | ||
| // dealing stop_seqs | ||
| const int stop_seq_len = (stop_seqs_len + bid * stop_seqs_bs)[tid]; | ||
| if (stop_seq_len <= 0) return; | ||
| const int64_t *stop_seq_now = | ||
| stop_seqs + bid * stop_seqs_bs + tid * stop_seqs_max_len; | ||
| const int64_t *pre_ids_now = pre_ids + bid * pre_ids_len; | ||
| const int64_t step_idx_now = step_idx[bid]; | ||
|
|
||
| bool is_end = true; | ||
| int count = 1; | ||
| for (int i = stop_seq_len - 1; i >= 0; --i) { | ||
| if ((step_idx_now - count) < 0 || | ||
| pre_ids_now[step_idx_now - count++] != stop_seq_now[i]) { | ||
| is_end = false; | ||
| break; | ||
| } | ||
| } | ||
| if (is_end) { | ||
| next_tokens[bid] = end_ids[0]; | ||
| stop_flags[bid] = true; | ||
| topk_ids[bid] = end_ids[0]; | ||
| } | ||
| } | ||
| } | ||
|
|
||
| void GetStopFlagsMulti(const paddle::Tensor &topk_ids, | ||
|
|
@@ -101,50 +109,63 @@ void GetStopFlagsMulti(const paddle::Tensor &topk_ids, | |
| const paddle::Tensor &step_idx, | ||
| const paddle::Tensor &stop_seqs, | ||
| const paddle::Tensor &stop_seqs_len, | ||
| const paddle::Tensor &min_tokens, | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. 引入min_tokens的作用是什么? (What is the purpose of introducing min_tokens?)
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. 设置最小生成的token数量，如果当前生成的token数量小于min_tokens，即使设置了stop_token_ids，也不会停止 (It sets the minimum number of generated tokens: if the number of tokens generated so far is less than min_tokens, generation will not stop even when stop_token_ids is set.) |
||
| const bool beam_search) { | ||
| PD_CHECK(topk_ids.dtype() == paddle::DataType::INT64); | ||
| PD_CHECK(stop_flags.dtype() == paddle::DataType::BOOL); | ||
| bool prefill_one_step_stop = false; | ||
| if (const char *env_p = std::getenv("PREFILL_NODE_ONE_STEP_STOP")) { | ||
| // std::cout << "Your PATH is: " << env_p << '\n'; | ||
| if (env_p[0] == '1') { | ||
| prefill_one_step_stop = true; | ||
| } | ||
| PD_CHECK(topk_ids.dtype() == paddle::DataType::INT64); | ||
| PD_CHECK(stop_flags.dtype() == paddle::DataType::BOOL); | ||
| bool prefill_one_step_stop = false; | ||
| if (const char *env_p = std::getenv("PREFILL_NODE_ONE_STEP_STOP")) { | ||
| // std::cout << "Your PATH is: " << env_p << '\n'; | ||
| if (env_p[0] == '1') { | ||
| prefill_one_step_stop = true; | ||
| } | ||
| } | ||
|
|
||
| #ifdef PADDLE_WITH_CUSTOM_DEVICE | ||
| auto dev_ctx = static_cast<const phi::CustomContext*>(paddle::experimental::DeviceContextPool::Instance().Get(topk_ids.place())); | ||
| auto cu_stream = dev_ctx->stream(); | ||
| auto dev_ctx = static_cast<const phi::CustomContext *>( | ||
| paddle::experimental::DeviceContextPool::Instance().Get( | ||
| topk_ids.place())); | ||
| auto cu_stream = dev_ctx->stream(); | ||
| #else | ||
| auto cu_stream = topk_ids.stream(); | ||
| auto cu_stream = topk_ids.stream(); | ||
| #endif | ||
| std::vector<int64_t> shape = topk_ids.shape(); | ||
| int64_t bs_now = shape[0]; | ||
| int64_t end_length = end_ids.shape()[0]; | ||
| int stop_seqs_bs = stop_seqs.shape()[1]; | ||
| int stop_seqs_max_len = stop_seqs.shape()[2]; | ||
| int block_size = (stop_seqs_bs + WARP_SIZE - 1) / WARP_SIZE * WARP_SIZE; | ||
| set_value_by_flags<<<bs_now, block_size, 0, cu_stream>>>( | ||
| const_cast<bool *>(stop_flags.data<bool>()), | ||
| const_cast<int64_t *>(topk_ids.data<int64_t>()), | ||
| const_cast<int64_t *>(next_tokens.data<int64_t>()), | ||
| end_ids.data<int64_t>(), | ||
| seq_lens.data<int>(), | ||
| bs_now, | ||
| end_length, | ||
| pre_ids.data<int64_t>(), | ||
| pre_ids.shape()[1], | ||
| step_idx.data<int64_t>(), | ||
| stop_seqs.data<int64_t>(), | ||
| stop_seqs_len.data<int>(), | ||
| stop_seqs_bs, | ||
| stop_seqs_max_len, | ||
| beam_search, | ||
| prefill_one_step_stop); | ||
| std::vector<int64_t> shape = topk_ids.shape(); | ||
| int64_t bs_now = shape[0]; | ||
| int64_t end_length = end_ids.shape()[0]; | ||
| int stop_seqs_bs = stop_seqs.shape()[1]; | ||
| int stop_seqs_max_len = stop_seqs.shape()[2]; | ||
| int block_size = (stop_seqs_bs + WARP_SIZE - 1) / WARP_SIZE * WARP_SIZE; | ||
| set_value_by_flags<<<bs_now, block_size, 0, cu_stream>>>( | ||
| const_cast<bool *>(stop_flags.data<bool>()), | ||
| const_cast<int64_t *>(topk_ids.data<int64_t>()), | ||
| const_cast<int64_t *>(next_tokens.data<int64_t>()), | ||
| end_ids.data<int64_t>(), | ||
| seq_lens.data<int>(), | ||
| bs_now, | ||
| end_length, | ||
| pre_ids.data<int64_t>(), | ||
| pre_ids.shape()[1], | ||
| step_idx.data<int64_t>(), | ||
| stop_seqs.data<int64_t>(), | ||
| stop_seqs_len.data<int>(), | ||
| stop_seqs_bs, | ||
| stop_seqs_max_len, | ||
| min_tokens.data<int64_t>(), | ||
| beam_search, | ||
| prefill_one_step_stop); | ||
| } | ||
|
|
||
| PD_BUILD_STATIC_OP(set_stop_value_multi_ends) | ||
| .Inputs({"topk_ids", "stop_flags", "seq_lens", "end_ids", "next_tokens", "pre_ids", "step_idx", "stop_seqs", "stop_seqs_len"}) | ||
| .Inputs({"topk_ids", | ||
| "stop_flags", | ||
| "seq_lens", | ||
| "end_ids", | ||
| "next_tokens", | ||
| "pre_ids", | ||
| "step_idx", | ||
| "stop_seqs", | ||
| "stop_seqs_len", | ||
| "min_tokens"}) | ||
| .Attrs({"beam_search: bool"}) | ||
| .Outputs({"topk_ids_out", "stop_flags_out", "next_tokens_out"}) | ||
| .SetInplaceMap({{"topk_ids", "topk_ids_out"}, | ||
|
|
||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
两个算子改参数的话，记得把 ernie5_serving 同步改了 (If the parameters of these two operators are changed, remember to update ernie5_serving in sync.)