4 changes: 3 additions & 1 deletion custom_ops/gpu_ops/cpp_extensions.cc
@@ -419,6 +419,7 @@ void GetStopFlagsMulti(const paddle::Tensor& topk_ids,
const paddle::Tensor& step_idx,
const paddle::Tensor& stop_seqs,
const paddle::Tensor& stop_seqs_len,
const paddle::Tensor& min_tokens,
const bool beam_search);

void UpdateInputs(const paddle::Tensor& stop_flags,
@@ -763,7 +764,8 @@ void SpecGetStopFlagsMultiSeqs(const paddle::Tensor& accept_tokens,
const paddle::Tensor& seq_lens,
const paddle::Tensor& stop_seqs,
const paddle::Tensor& stop_seqs_len,
const paddle::Tensor& end_ids);
const paddle::Tensor& end_ids,
const paddle::Tensor& min_tokens);

void SpeculateVerify(const paddle::Tensor& sampled_token_ids,
const paddle::Tensor& accept_tokens,
@@ -32,6 +32,7 @@ __global__ void spec_set_value_by_stop_seqs(bool *stop_flags,
const int accept_tokens_len,
const int stop_seqs_bs,
const int stop_seqs_max_len,
const int64_t *min_tokens,
[Collaborator comment] If you change the parameters of these two operators, remember to update ernie5_serving in sync.

const int pre_ids_len) {
const int bid = blockIdx.x;
const int tid = threadIdx.x;
@@ -46,6 +47,10 @@
int64_t *accept_tokens_now = accept_tokens + bid * accept_tokens_len;
const int accept_num = accept_nums[bid];
const int64_t step_idx_now = step_idx[bid];
const int64_t min_token_limit = min_tokens[bid];

const bool can_stop = (step_idx_now >= min_token_limit);
if (!can_stop) return;
if (!stop_flags[bid]) {
int accept_idx = 0;
bool is_end = false;
@@ -138,7 +143,8 @@ void SpecGetStopFlagsMultiSeqs(const paddle::Tensor &accept_tokens,
const paddle::Tensor &seq_lens,
const paddle::Tensor &stop_seqs,
const paddle::Tensor &stop_seqs_len,
const paddle::Tensor &end_ids) {
const paddle::Tensor &end_ids,
const paddle::Tensor &min_tokens) {
PD_CHECK(accept_tokens.dtype() == paddle::DataType::INT64);
PD_CHECK(stop_flags.dtype() == paddle::DataType::BOOL);

@@ -166,6 +172,7 @@ void SpecGetStopFlagsMultiSeqs(const paddle::Tensor &accept_tokens,
accept_tokens_len,
stop_seqs_bs,
stop_seqs_max_len,
min_tokens.data<int64_t>(),
pre_ids_len);
}

@@ -178,7 +185,8 @@ PD_BUILD_STATIC_OP(speculate_set_stop_value_multi_seqs)
"seq_lens",
"stop_seqs",
"stop_seqs_len",
"end_ids"})
"end_ids",
"min_tokens"})
.Outputs({"accept_tokens_out", "stop_flags_out"})
.SetInplaceMap({{"accept_tokens", "accept_tokens_out"},
{"stop_flags", "stop_flags_out"}})
185 changes: 103 additions & 82 deletions custom_ops/gpu_ops/stop_generation_multi_ends.cu
@@ -37,59 +37,67 @@ __global__ void set_value_by_flags(bool *stop_flags,
const int *stop_seqs_len,
const int stop_seqs_bs,
const int stop_seqs_max_len,
const int64_t *min_tokens,
bool beam_search,
bool prefill_one_step_stop) {
int tid = threadIdx.x;
int bid = blockIdx.x;
if (tid >= stop_seqs_bs) return;
if (bid < bs) {
if(tid == 0){
if (prefill_one_step_stop) {
stop_flags[bid] = true;
if (seq_lens[bid] == 0) {
topk_ids[bid] = -1;
}
next_tokens[bid] = topk_ids[bid];
} else {
if (stop_flags[bid]) {
if (seq_lens[bid] == 0) {
topk_ids[bid] = -1;
} else {
topk_ids[bid] = end_ids[0];
next_tokens[bid] = end_ids[0];
}
} else {
next_tokens[bid] = topk_ids[bid];
}
}
if (!beam_search && is_in_end(topk_ids[bid], end_ids, end_length)) {
stop_flags[bid] = true;
topk_ids[bid] = end_ids[0];
next_tokens[bid] = end_ids[0];
}
int tid = threadIdx.x;
int bid = blockIdx.x;
if (tid >= stop_seqs_bs) return;
if (bid < bs) {
const int64_t current_step = step_idx[bid];
const int64_t min_token_limit = min_tokens[bid];
const bool can_stop = (current_step >= min_token_limit);
if (tid == 0) {
if (prefill_one_step_stop) {
stop_flags[bid] = true;
if (seq_lens[bid] == 0) {
topk_ids[bid] = -1;
}
// dealing stop_seqs
const int stop_seq_len = (stop_seqs_len + bid * stop_seqs_bs)[tid];
if (stop_seq_len <= 0) return;
const int64_t *stop_seq_now = stop_seqs + bid * stop_seqs_bs + tid * stop_seqs_max_len;
const int64_t *pre_ids_now = pre_ids + bid * pre_ids_len;
const int64_t step_idx_now = step_idx[bid];

bool is_end = true;
int count = 1;
for (int i = stop_seq_len - 1; i >= 0; --i) {
if ((step_idx_now - count) < 0 ||
pre_ids_now[step_idx_now - count++] != stop_seq_now[i]) {
is_end = false;
break;
}
}
if (is_end) {
next_tokens[bid] = end_ids[0];
stop_flags[bid] = true;
next_tokens[bid] = topk_ids[bid];
} else {
if (stop_flags[bid]) {
if (seq_lens[bid] == 0) {
topk_ids[bid] = -1;
} else {
topk_ids[bid] = end_ids[0];
next_tokens[bid] = end_ids[0];
}
} else {
next_tokens[bid] = topk_ids[bid];
}
}
if (!beam_search && can_stop &&
is_in_end(topk_ids[bid], end_ids, end_length)) {
stop_flags[bid] = true;
topk_ids[bid] = end_ids[0];
next_tokens[bid] = end_ids[0];
}
}

if (!can_stop) return;
// dealing stop_seqs
const int stop_seq_len = (stop_seqs_len + bid * stop_seqs_bs)[tid];
if (stop_seq_len <= 0) return;
const int64_t *stop_seq_now =
stop_seqs + bid * stop_seqs_bs + tid * stop_seqs_max_len;
const int64_t *pre_ids_now = pre_ids + bid * pre_ids_len;
const int64_t step_idx_now = step_idx[bid];

bool is_end = true;
int count = 1;
for (int i = stop_seq_len - 1; i >= 0; --i) {
if ((step_idx_now - count) < 0 ||
pre_ids_now[step_idx_now - count++] != stop_seq_now[i]) {
is_end = false;
break;
}
}
if (is_end) {
next_tokens[bid] = end_ids[0];
stop_flags[bid] = true;
topk_ids[bid] = end_ids[0];
}
}
}

void GetStopFlagsMulti(const paddle::Tensor &topk_ids,
@@ -101,50 +109,63 @@ void GetStopFlagsMulti(const paddle::Tensor &topk_ids,
const paddle::Tensor &step_idx,
const paddle::Tensor &stop_seqs,
const paddle::Tensor &stop_seqs_len,
const paddle::Tensor &min_tokens,
[Collaborator comment] What is the purpose of introducing min_tokens?

[Author reply] It sets the minimum number of tokens to generate: if the number of tokens generated so far is smaller than min_tokens, generation will not stop even when stop_token_ids is set.
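A minimal Python sketch of that gating rule (illustrative only; the function name is hypothetical, and the real check runs per request inside the CUDA kernels changed in this PR):

```python
def may_apply_stop_conditions(step_idx: int, min_tokens: int) -> bool:
    """Stop conditions (stop_token_ids, stop sequences, end ids) are
    honored only after at least `min_tokens` tokens have been generated."""
    return step_idx >= min_tokens

# With min_tokens=10, a stop token emitted at step 5 is ignored,
# while the same token at step 12 is allowed to end generation.
assert not may_apply_stop_conditions(step_idx=5, min_tokens=10)
assert may_apply_stop_conditions(step_idx=12, min_tokens=10)
```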

const bool beam_search) {
PD_CHECK(topk_ids.dtype() == paddle::DataType::INT64);
PD_CHECK(stop_flags.dtype() == paddle::DataType::BOOL);
bool prefill_one_step_stop = false;
if (const char *env_p = std::getenv("PREFILL_NODE_ONE_STEP_STOP")) {
// std::cout << "Your PATH is: " << env_p << '\n';
if (env_p[0] == '1') {
prefill_one_step_stop = true;
}
PD_CHECK(topk_ids.dtype() == paddle::DataType::INT64);
PD_CHECK(stop_flags.dtype() == paddle::DataType::BOOL);
bool prefill_one_step_stop = false;
if (const char *env_p = std::getenv("PREFILL_NODE_ONE_STEP_STOP")) {
// std::cout << "Your PATH is: " << env_p << '\n';
if (env_p[0] == '1') {
prefill_one_step_stop = true;
}
}

#ifdef PADDLE_WITH_CUSTOM_DEVICE
auto dev_ctx = static_cast<const phi::CustomContext*>(paddle::experimental::DeviceContextPool::Instance().Get(topk_ids.place()));
auto cu_stream = dev_ctx->stream();
auto dev_ctx = static_cast<const phi::CustomContext *>(
paddle::experimental::DeviceContextPool::Instance().Get(
topk_ids.place()));
auto cu_stream = dev_ctx->stream();
#else
auto cu_stream = topk_ids.stream();
auto cu_stream = topk_ids.stream();
#endif
std::vector<int64_t> shape = topk_ids.shape();
int64_t bs_now = shape[0];
int64_t end_length = end_ids.shape()[0];
int stop_seqs_bs = stop_seqs.shape()[1];
int stop_seqs_max_len = stop_seqs.shape()[2];
int block_size = (stop_seqs_bs + WARP_SIZE - 1) / WARP_SIZE * WARP_SIZE;
set_value_by_flags<<<bs_now, block_size, 0, cu_stream>>>(
const_cast<bool *>(stop_flags.data<bool>()),
const_cast<int64_t *>(topk_ids.data<int64_t>()),
const_cast<int64_t *>(next_tokens.data<int64_t>()),
end_ids.data<int64_t>(),
seq_lens.data<int>(),
bs_now,
end_length,
pre_ids.data<int64_t>(),
pre_ids.shape()[1],
step_idx.data<int64_t>(),
stop_seqs.data<int64_t>(),
stop_seqs_len.data<int>(),
stop_seqs_bs,
stop_seqs_max_len,
beam_search,
prefill_one_step_stop);
std::vector<int64_t> shape = topk_ids.shape();
int64_t bs_now = shape[0];
int64_t end_length = end_ids.shape()[0];
int stop_seqs_bs = stop_seqs.shape()[1];
int stop_seqs_max_len = stop_seqs.shape()[2];
int block_size = (stop_seqs_bs + WARP_SIZE - 1) / WARP_SIZE * WARP_SIZE;
set_value_by_flags<<<bs_now, block_size, 0, cu_stream>>>(
const_cast<bool *>(stop_flags.data<bool>()),
const_cast<int64_t *>(topk_ids.data<int64_t>()),
const_cast<int64_t *>(next_tokens.data<int64_t>()),
end_ids.data<int64_t>(),
seq_lens.data<int>(),
bs_now,
end_length,
pre_ids.data<int64_t>(),
pre_ids.shape()[1],
step_idx.data<int64_t>(),
stop_seqs.data<int64_t>(),
stop_seqs_len.data<int>(),
stop_seqs_bs,
stop_seqs_max_len,
min_tokens.data<int64_t>(),
beam_search,
prefill_one_step_stop);
}

PD_BUILD_STATIC_OP(set_stop_value_multi_ends)
.Inputs({"topk_ids", "stop_flags", "seq_lens", "end_ids", "next_tokens", "pre_ids", "step_idx", "stop_seqs", "stop_seqs_len"})
.Inputs({"topk_ids",
"stop_flags",
"seq_lens",
"end_ids",
"next_tokens",
"pre_ids",
"step_idx",
"stop_seqs",
"stop_seqs_len",
"min_tokens"})
.Attrs({"beam_search: bool"})
.Outputs({"topk_ids_out", "stop_flags_out", "next_tokens_out"})
.SetInplaceMap({{"topk_ids", "topk_ids_out"},
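Since the Inputs list of set_stop_value_multi_ends now includes min_tokens, every call site must pass the extra tensor. A hedged sketch of what an updated Python-side invocation could look like — the import path, tensor shapes, and dtypes here are assumptions for illustration, not confirmed by this diff; the argument order follows the Inputs list plus the beam_search attr:

```python
import paddle

# Assumed import path for the compiled custom op; the real FastDeploy
# wrapper may live elsewhere.
from fastdeploy.model_executor.ops.gpu import set_stop_value_multi_ends

bs, max_len, stop_bs, stop_len = 2, 64, 4, 8  # illustrative sizes
topk_ids = paddle.zeros([bs, 1], dtype="int64")
stop_flags = paddle.zeros([bs, 1], dtype="bool")
seq_lens = paddle.ones([bs], dtype="int32")
end_ids = paddle.to_tensor([2], dtype="int64")
next_tokens = paddle.zeros([bs, 1], dtype="int64")
pre_ids = paddle.full([bs, max_len], -1, dtype="int64")
step_idx = paddle.zeros([bs], dtype="int64")
stop_seqs = paddle.zeros([bs, stop_bs, stop_len], dtype="int64")
stop_seqs_len = paddle.zeros([bs, stop_bs], dtype="int32")
# One entry per request, matching the kernel's min_tokens[bid] indexing.
min_tokens = paddle.full([bs], 8, dtype="int64")

set_stop_value_multi_ends(
    topk_ids, stop_flags, seq_lens, end_ids, next_tokens,
    pre_ids, step_idx, stop_seqs, stop_seqs_len,
    min_tokens,  # new input added by this PR
    False,       # beam_search attribute
)
```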
41 changes: 40 additions & 1 deletion docs/features/early_stop.md
@@ -2,7 +2,7 @@

# Early Stopping

The early stopping is used to prematurely terminate the token generation of the model. Specifically, the early stopping uses different strategies to determine whether the currently generated token sequence meets the early stopping criteria. If so, token generation is terminated prematurely. FastDeploy currently supports the repetition strategy and stop sequence.
The early stopping is used to prematurely terminate the token generation of the model. Specifically, the early stopping uses different strategies to determine whether the currently generated token sequence meets the early stopping criteria. If so, token generation is terminated prematurely. FastDeploy currently supports the repetition strategy, stop sequences, and stop_token_ids.

## 1. Repetition Strategy
* The repetition strategy determines whether to trigger early stopping by checking how many times a high-probability token has been generated; a sketch of the idea follows below.
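A minimal Python sketch of the idea, assuming two knobs with hypothetical names (`window_size` and `threshold`) that do not appear in this diff: generation stops once the sampled token's probability has stayed above the threshold for `window_size` consecutive steps.

```python
class RepetitionEarlyStopper:
    """Sketch only: stop generation when high-probability tokens persist."""

    def __init__(self, window_size: int = 3000, threshold: float = 0.99):
        self.window_size = window_size  # consecutive steps needed to stop
        self.threshold = threshold      # probability cutoff for "high"
        self._streak = 0

    def should_stop(self, token_prob: float) -> bool:
        # Extend the streak on a high-probability token; otherwise reset it.
        self._streak = self._streak + 1 if token_prob > self.threshold else 0
        return self._streak >= self.window_size

stopper = RepetitionEarlyStopper(window_size=3, threshold=0.9)
probs = [0.95, 0.97, 0.99]  # three consecutive high-probability steps
print(any(stopper.should_stop(p) for p in probs))  # True
```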
@@ -121,3 +121,42 @@ output = llm.chat(messages=[{"role": "user", "content": "今天天气真好"}],
print(output)
```

## 3. Stop_token_ids
* The stop_token_ids strategy determines whether to trigger early stopping by checking whether the generated token sequence contains any user-specified stop token ID.

* Specifically, if the token sequence generated by a batch contains one of the user-specified stop_token_ids, token generation for that batch is terminated prematurely.

### Usage Instructions

Include the `stop_token_ids` parameter in the request; it accepts a `List[int]`.

* Online serving: set the `stop_token_ids` parameter in the request
```
# create a chat request with "stop_token_ids" parameter
curl -X POST "http://0.0.0.0:13312/v1/chat/completions" \
-H "Content-Type: application/json" \
-d '{
"model": "default",
"messages": [
{
"role": "user",
"content": "北京天安门在哪里?"
}
],
"temperature": 0.7,
"stream": false,
"seed": 1,
"stop_token_ids":[104208]
}'
```

* Offline LLM: set the `stop_token_ids` parameter in `SamplingParams`
```
from fastdeploy.engine.sampling_params import SamplingParams
from fastdeploy.entrypoints.llm import LLM
model_name_or_path = "/Qwen/Qwen3-0.6B"
sampling_params = SamplingParams(temperature=1, seed=1, stop_token_ids=[104208])
llm = LLM(model=model_name_or_path, tensor_parallel_size=1)
output = llm.chat(messages=[{"role": "user", "content": "北京天安门在哪里?"}], use_tqdm=True, sampling_params=sampling_params)
print(output)
```
42 changes: 41 additions & 1 deletion docs/zh/features/early_stop.md
@@ -2,7 +2,7 @@

# Early Stopping

The early-stop feature is used to end the model's token generation early. Specifically, it applies different strategies to determine whether the currently generated token sequence meets the early-stop criteria; if so, token generation ends early. FastDeploy currently supports the `Repetition` and `Stop Sequence` strategies.
The early-stop feature is used to end the model's token generation early. Specifically, it applies different strategies to determine whether the currently generated token sequence meets the early-stop criteria; if so, token generation ends early. FastDeploy currently supports the `Repetition`, `Stop Sequence`, and `Stop_token_ids` strategies.

## 1. Repetition Strategy
* The Repetition strategy decides whether to trigger early stopping by checking how many times a high-probability token has been generated.
@@ -116,3 +116,43 @@ output = llm.chat(messages=[{"role": "user", "content": "今天天气真好"}],
print(output)

```
## 3. Stop_token_ids Strategy
* The Stop_token_ids strategy decides whether to trigger early stopping by checking whether the generated token sequence contains any user-specified stop token ID.
* Specifically, when the token sequence generated by a batch contains a user-specified stop token ID, token generation for that batch ends early.
### Usage Instructions
When sending a request, include the `stop_token_ids` field; it is a `List[int]`.
* Online inference example: add the `stop_token_ids` parameter to the request
```
# create a chat request with "stop_token_ids" parameter

curl -X POST "http://0.0.0.0:13312/v1/chat/completions" \
-H "Content-Type: application/json" \
-d '{
"model": "default",
"messages": [
{
"role": "user",
"content": "北京天安门在哪里?"
}
],
"temperature": 0.7,
"stream": false,
"seed": 1,
"stop_token_ids":[104208]
}'
```
* Offline inference: add the `stop_token_ids` parameter to `SamplingParams`
```
from fastdeploy.engine.sampling_params import SamplingParams
from fastdeploy.entrypoints.llm import LLM

model_name_or_path = "/root/paddlejob/workspace/env_run/output/models/paddle/Qwen/Qwen3-0.6B"

# sampling hyperparameters
sampling_params = SamplingParams(temperature=1, seed=1, stop_token_ids=[104208])
llm = LLM(model=model_name_or_path, tensor_parallel_size=1)
output = llm.chat(messages=[{"role": "user", "content": "北京天安门在哪里?"}], use_tqdm=True, sampling_params=sampling_params)

print(output)
```