|
|
@@ -86,6 +86,7 @@ enum error_type {
|
|
|
ERROR_TYPE_PERMISSION,
|
|
|
ERROR_TYPE_UNAVAILABLE, // custom error
|
|
|
ERROR_TYPE_NOT_SUPPORTED, // custom error
|
|
|
+ ERROR_TYPE_EXCEED_CONTEXT_SIZE, // custom error
|
|
|
};
|
|
|
|
|
|
static bool server_task_type_need_embd(server_task_type task_type) {
|
|
|
@@ -1224,6 +1225,10 @@ static json format_error_response(const std::string & message, const enum error_
|
|
|
type_str = "unavailable_error";
|
|
|
code = 503;
|
|
|
break;
|
|
|
+ case ERROR_TYPE_EXCEED_CONTEXT_SIZE:
|
|
|
+ type_str = "exceed_context_size_error";
|
|
|
+ code = 400;
|
|
|
+ break;
|
|
|
}
|
|
|
return json {
|
|
|
{"code", code},
|
|
|
@@ -1237,12 +1242,21 @@ struct server_task_result_error : server_task_result {
|
|
|
// Error category; defaults to ERROR_TYPE_SERVER and is passed to format_error_response() by to_json().
error_type err_type = ERROR_TYPE_SERVER;
|
|
|
// Error message text; passed to format_error_response() by to_json().
std::string err_msg;
|
|
|
|
|
|
+ // for ERROR_TYPE_EXCEED_CONTEXT_SIZE
|
|
|
// The two counters below default to 0; to_json() writes them to the response
// only when err_type == ERROR_TYPE_EXCEED_CONTEXT_SIZE.
+ int32_t n_prompt_tokens = 0;
|
|
|
+ int32_t n_ctx = 0;
|
|
|
+
|
|
|
// Always returns true for this result type.
virtual bool is_error() override {
|
|
|
return true;
|
|
|
}
|
|
|
|
|
|
// Serializes this error result by calling format_error_response(err_msg, err_type).
// After this change, a result whose err_type is ERROR_TYPE_EXCEED_CONTEXT_SIZE
// additionally carries the "n_prompt_tokens" and "n_ctx" fields in the returned JSON;
// every other error type produces the same JSON as before.
virtual json to_json() override {
|
|
|
- return format_error_response(err_msg, err_type);
|
|
|
+ json res = format_error_response(err_msg, err_type);
|
|
|
+ if (err_type == ERROR_TYPE_EXCEED_CONTEXT_SIZE) {
|
|
|
+ res["n_prompt_tokens"] = n_prompt_tokens;
|
|
|
+ res["n_ctx"] = n_ctx;
|
|
|
+ }
|
|
|
+ return res;
|
|
|
}
|
|
|
};
|
|
|
|
|
|
@@ -2605,16 +2619,22 @@ struct server_context {
|
|
|
}
|
|
|
|
|
|
// Slot-based overload: forwards to the id-based send_error() using slot.id_task.
// After this change it also forwards slot.n_prompt_tokens and slot.n_ctx, for every
// error type (the id-based overload only validates them for ERROR_TYPE_EXCEED_CONTEXT_SIZE).
void send_error(const server_slot & slot, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER) {
|
|
|
- send_error(slot.id_task, error, type);
|
|
|
+ send_error(slot.id_task, error, type, slot.n_prompt_tokens, slot.n_ctx);
|
|
|
}
|
|
|
|
|
|
// Id-based overload: logs the error with SRV_ERR, fills a server_task_result_error
// (id, err_type, err_msg and, after this change, n_prompt_tokens / n_ctx) and hands it
// to queue_results.send().
// The two new trailing parameters default to 0, so existing three-argument calls keep
// compiling unchanged.
- void send_error(const int id_task, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER) {
|
|
|
+ void send_error(const int id_task, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER, const int32_t n_prompt_tokens = 0, const int32_t n_ctx = 0) {
|
|
|
SRV_ERR("task id = %d, error: %s\n", id_task, error.c_str());
|
|
|
|
|
|
// A context-size error must be reported with real values: both counters are
// checked to be > 0, so calling with the defaulted zeros trips the assertion.
+ if (type == ERROR_TYPE_EXCEED_CONTEXT_SIZE) {
|
|
|
+ GGML_ASSERT(n_ctx > 0 && n_prompt_tokens > 0);
|
|
|
+ }
|
|
|
+
|
|
|
auto res = std::make_unique<server_task_result_error>();
|
|
|
- res->id = id_task;
|
|
|
- res->err_type = type;
|
|
|
- res->err_msg = error;
|
|
|
+ res->id = id_task;
|
|
|
+ res->err_type = type;
|
|
|
+ res->err_msg = error;
|
|
|
+ res->n_prompt_tokens = n_prompt_tokens;
|
|
|
+ res->n_ctx = n_ctx;
|
|
|
|
|
|
queue_results.send(std::move(res));
|
|
|
}
|
|
|
@@ -3286,7 +3306,7 @@ struct server_context {
|
|
|
|
|
|
if (slot.n_prompt_tokens > slot.n_ctx) {
|
|
|
slot.release();
|
|
|
- send_error(slot, "input is larger than the max context size. skipping", ERROR_TYPE_SERVER);
|
|
|
+ send_error(slot, "input is larger than the max context size. skipping", ERROR_TYPE_EXCEED_CONTEXT_SIZE);
|
|
|
continue;
|
|
|
}
|
|
|
} else {
|
|
|
@@ -3296,7 +3316,7 @@ struct server_context {
|
|
|
// context shift should be applied only during the generation phase
|
|
|
if (slot.n_prompt_tokens >= slot.n_ctx) {
|
|
|
slot.release();
|
|
|
- send_error(slot, "the request exceeds the available context size. try increasing the context size or enable context shift", ERROR_TYPE_INVALID_REQUEST);
|
|
|
+ send_error(slot, "the request exceeds the available context size. try increasing the context size or enable context shift", ERROR_TYPE_EXCEED_CONTEXT_SIZE);
|
|
|
continue;
|
|
|
}
|
|
|
}
|