10#include "server-context.cpp"
17LLMService::LLMService(
const std::string &model_path,
int num_slots,
int num_threads,
int num_GPU_layers,
bool flash_attention,
int context_size,
int batch_size,
bool embedding_only,
const std::vector<std::string> &lora_paths)
19 init(
LLM::LLM_args_to_command(model_path, num_slots, num_threads, num_GPU_layers, flash_attention, context_size, batch_size, embedding_only, lora_paths));
26 llmService->
init(argv.size(), argv.data());
33 llmService->
init(command);
40 llmService->
init(argc, argv);
46 if (ctx_server !=
nullptr)
48 if (ctx_http !=
nullptr)
60 common_params default_params;
61 common_params_context ctx = common_params_parser_init(default_params, LLAMA_EXAMPLE_SERVER);
63 std::vector<std::string> args_str = {
"llm"};
64 std::set<std::string> used_keys;
66 for (
const auto &opt : ctx.options)
68 for (
const auto &name : opt.args)
70 std::string key = name;
71 if (key.rfind(
"--", 0) == 0)
73 else if (key.rfind(
"-", 0) == 0)
76 std::string json_key = key;
77 std::replace(json_key.begin(), json_key.end(),
'-',
'_');
79 if (params_json.contains(json_key))
82 used_keys.insert(json_key);
83 const auto &value = params_json[json_key];
84 args_str.push_back(name);
86 if (opt.handler_void !=
nullptr)
90 else if (opt.handler_string !=
nullptr || opt.handler_int !=
nullptr)
92 args_str.push_back(value.is_string() ? value.get<std::string>() : value.dump());
95 else if (opt.handler_str_str !=
nullptr)
97 if (!value.is_array() || value.size() != 2)
99 std::string err =
"Expected array of 2 values for: " + json_key;
100 LOG_WRN(
"%s\n", err.c_str());
103 args_str.push_back(value[0].is_string() ? value[0].get<std::string>() : value[0].dump());
104 args_str.push_back(value[1].is_string() ? value[1].get<std::string>() : value[1].dump());
111 for (
const auto &[key, _] : params_json.items())
113 if (used_keys.find(key) == used_keys.end())
115 std::string err =
"Unused parameter in JSON: " + key;
116 LOG_WRN(
"%s\n", err.c_str());
121 std::vector<std::unique_ptr<char[]>> argv_storage;
122 std::vector<char *> argv;
123 for (
const auto &arg : args_str)
125 auto buf = std::make_unique<char[]>(arg.size() + 1);
126 std::memcpy(buf.get(), arg.c_str(), arg.size() + 1);
127 argv.push_back(buf.get());
128 argv_storage.push_back(std::move(buf));
134std::vector<std::string> LLMService::splitArguments(
const std::string &inputString)
136 std::vector<std::string> arguments;
138 unsigned counter = 0;
140 std::istringstream stream_input(inputString);
141 while (std::getline(stream_input, segment,
'"'))
144 if (counter % 2 == 0)
146 if (!segment.empty())
147 arguments.push_back(segment);
151 std::istringstream stream_segment(segment);
152 while (std::getline(stream_segment, segment,
' '))
153 if (!segment.empty())
154 arguments.push_back(segment);
162 std::vector<std::string> arguments = splitArguments(
"llm " + params_string);
165 int argc =
static_cast<int>(arguments.size());
166 char **argv =
new char *[argc];
167 for (
int i = 0; i < argc; ++i)
169 argv[i] =
new char[arguments[i].size() + 1];
170 std::strcpy(argv[i], arguments[i].c_str());
177 init(std::string(params_string));
183 if (setjmp(get_jump_point()) != 0)
187 command = args_to_command(argc, argv);
194 ctx_server =
new server_context();
197 params =
new common_params();
199 params->verbosity = common_log_verbosity_thold;
200 if (!common_params_parse(argc, argv, *params, LLAMA_EXAMPLE_SERVER))
202 throw std::runtime_error(
"Invalid parameters!");
208 if (params->embedding && params->n_batch > params->n_ubatch) {
209 LOG_WRN(
"%s: embeddings enabled with n_batch (%d) > n_ubatch (%d)\n", __func__, params->n_batch, params->n_ubatch);
210 LOG_WRN(
"%s: setting n_batch = n_ubatch = %d to avoid assertion failure\n", __func__, params->n_ubatch);
211 params->n_batch = params->n_ubatch;
214 if (params->n_parallel < 0) {
215 LOG_INF(
"%s: n_parallel is set to auto, using n_parallel = 4 and kv_unified = true\n", __func__);
217 params->n_parallel = 4;
218 params->kv_unified =
true;
222 if (params->model_alias.empty() && !params->model.name.empty()) {
223 params->model_alias.insert(params->model.name);
228 llama_backend_init();
229 llama_backend_has_init =
true;
230 llama_numa_init(params->numa);
232 LLAMALIB_INF(
"system info: n_threads = %d, n_threads_batch = %d, total_threads = %d\n", params->cpuparams.n_threads, params->cpuparams_batch.n_threads, std::thread::hardware_concurrency());
235 params->use_jinja =
true;
236 if (!ctx_server->load_model(*params))
238 throw std::runtime_error(
"Error loading the model!");
243 ctx_http =
new server_http_context();
244 routes =
new server_routes(*params, *ctx_server);
245 routes->update_meta(*ctx_server);
250 ctx_server->impl->queue_tasks.on_new_task([
this](server_task && task)
251 { this->ctx_server->impl->process_single_task(std::move(task)); });
252 ctx_server->impl->queue_tasks.on_update_slots([
this]()
253 { this->ctx_server->impl->update_slots(); });
258 handle_exception(-1);
264 if (ctx_server !=
nullptr) ctx_server->impl->chat_params.enable_thinking =
reasoning_enabled;
292 common_log_set_verbosity_thold(debug_level - 2);
297 log_callback = callback;
300void release_slot(server_slot &slot)
302 if (slot.task && slot.task->type == SERVER_TASK_TYPE_COMPLETION)
305 slot.task->params.n_predict = 0;
306 slot.stop = STOP_TYPE_LIMIT;
307 slot.has_next_token =
false;
317 if (get_status_code() < 0 || setjmp(get_jump_point()) != 0)
319 if (ctx_server->impl->slots.size() == 0)
321 return next_available_slot++ % ctx_server->impl->slots.size();
326 if (get_status_code() < 0 || setjmp(get_jump_point()) != 0)
328 return ctx_server->impl->get_slot_n_ctx();
333static server_http_context::handler_t ex_wrapper(server_http_context::handler_t func) {
334 return [func = std::move(func)](
const server_http_req & req) -> server_http_res_ptr {
339 }
catch (
const std::invalid_argument & e) {
341 error = ERROR_TYPE_INVALID_REQUEST;
343 }
catch (
const std::exception & e) {
345 error = ERROR_TYPE_SERVER;
348 error = ERROR_TYPE_SERVER;
349 message =
"unknown error";
352 auto res = std::make_unique<server_http_res>();
355 json error_data = format_error_response(message, error);
356 res->status = json_value(error_data,
"code", 500);
357 res->data = safe_json_to_str({{
"error", error_data }});
358 SRV_WRN(
"got exception: %s\n", res->data.c_str());
359 }
catch (
const std::exception & e) {
360 SRV_ERR(
"got another exception: %s | while handling exception: %s\n", e.what(), message.c_str());
361 res->data =
"Internal Server Error";
369 if (get_status_code() < 0 || setjmp(get_jump_point()) != 0)
374 params->hostname = host.empty() ?
"0.0.0.0" : host;
377 params->api_keys.clear();
378 if (!API_key.empty())
379 params->api_keys.push_back(API_key);
381 std::lock_guard<std::mutex> lock(start_stop_mutex);
383 if (!ctx_http->init(*params)) {
384 throw std::runtime_error(
"Failed to initialize HTTP server!");
388 ctx_http->post(
"/health", ex_wrapper(routes->get_health));
389 ctx_http->post(
"/v1/health", ex_wrapper(routes->get_health));
390 ctx_http->post(
"/props", ex_wrapper([
this](
const server_http_req &) {
return get_props();}));
391 ctx_http->post(
"/completion", ex_wrapper(routes->post_completions));
392 ctx_http->post(
"/completions", ex_wrapper(routes->post_completions));
393 ctx_http->post(
"/chat/completions", ex_wrapper(routes->post_chat_completions));
394 ctx_http->post(
"/v1/chat/completions", ex_wrapper(routes->post_chat_completions));
395 ctx_http->post(
"/tokenize", ex_wrapper(routes->post_tokenize));
396 ctx_http->post(
"/detokenize", ex_wrapper(routes->post_detokenize));
397 ctx_http->post(
"/apply-template", ex_wrapper(routes->post_apply_template));
398 ctx_http->post(
"/embedding", ex_wrapper(routes->post_embeddings));
399 ctx_http->post(
"/embeddings", ex_wrapper(routes->post_embeddings));
403 if (!ctx_http->start()) {
405 throw std::runtime_error(
"Exiting due to HTTP server error\n");
408 ctx_http->is_ready.store(
true);
419 if (get_status_code() < 0 || setjmp(get_jump_point()) != 0)
421 if (ctx_http ==
nullptr)
423 std::lock_guard<std::mutex> lock(start_stop_mutex);
426 if (ctx_http->thread.joinable()) ctx_http->thread.join();
427 server_stopped =
true;
428 server_stopped_cv.notify_all();
434 if (get_status_code() < 0 || setjmp(get_jump_point()) != 0)
436 std::unique_lock<std::mutex> lock(start_stop_mutex);
437 server_stopped_cv.wait(lock, [
this]
438 {
return server_stopped; });
443 if (get_status_code() < 0 || setjmp(get_jump_point()) != 0)
445 std::lock_guard<std::mutex> lock(start_stop_mutex);
446 service_thread = std::thread([&]()
449 ctx_server->impl->queue_tasks.start_loop();
454 std::this_thread::sleep_for(std::chrono::milliseconds(1));
460 if (get_status_code() < 0 || setjmp(get_jump_point()) != 0)
464 std::lock_guard<std::mutex> lock(start_stop_mutex);
470 for (server_slot &
slot : ctx_server->impl->slots)
475 if((!ctx_server->impl->queue_tasks.is_empty()))
479 while (!ctx_server->impl->queue_tasks.is_empty() && grace-- > 0)
481 std::this_thread::sleep_for(std::chrono::milliseconds(50));
487 ctx_server->terminate();
489 if (llama_backend_has_init)
490 llama_backend_free();
492 if (service_thread.joinable())
494 service_thread.join();
496 service_stopped =
true;
497 service_stopped_cv.notify_all();
510 if (get_status_code() < 0 || setjmp(get_jump_point()) != 0)
512 std::unique_lock<std::mutex> lock(start_stop_mutex);
513 service_stopped_cv.wait(lock, [
this]
514 {
return service_stopped; });
519 return ctx_server !=
nullptr && ctx_server->impl->queue_tasks.is_running();
524 params->ssl_cert = SSL_cert_str;
525 params->ssl_key = SSL_key_str;
528std::string LLMService::encapsulate_route(
const json &body, server_http_context::handler_t route_handler)
530 if (get_status_code() < 0 || setjmp(get_jump_point()) != 0)
535 server_http_req req{ {}, {},
"",
"", body.dump(), always_false };
536 return route_handler(req)->data;
547 if (get_status_code() < 0 || setjmp(get_jump_point()) != 0)
549 std::vector<raw_buffer> files;
551 json data = oaicompat_chat_params_parse(
553 ctx_server->impl->chat_params,
555 return safe_json_to_str({{
"prompt", std::move(data.at(
"prompt"))}});
560 return encapsulate_route(body, routes->post_tokenize);
565 return encapsulate_route(body, routes->post_detokenize);
570 return encapsulate_route(body, routes->post_embeddings);
575 return safe_json_to_str(encapsulate_route(body, routes->post_lora_adapters));
580 return encapsulate_route({}, routes->get_lora_adapters);
585 if (get_status_code() < 0 || setjmp(get_jump_point()) != 0)
590 bool stream = json_value(data_in,
"stream", callback !=
nullptr);
592 data[
"stream"] = stream;
594 server_http_req req{ {}, {},
"",
"", data.dump(), always_false };
595 auto result = routes->post_completions(req);
596 if (result->status != 200)
604 if (callback) concatenator.
set_callback(callback, callbackWithJSON);
607 bool has_next = result->next(chunk);
608 if (!chunk.empty()) {
611 if (!has_next)
break;
627 if (get_status_code() < 0 || setjmp(get_jump_point()) != 0)
629 std::string result_data =
"";
632 server_task_type task_type;
633 std::string action = data.at(
"action");
634 if (action ==
"save")
636 task_type = SERVER_TASK_TYPE_SLOT_SAVE;
638 else if (action ==
"restore")
640 task_type = SERVER_TASK_TYPE_SLOT_RESTORE;
642 else if (action ==
"erase")
644 task_type = SERVER_TASK_TYPE_SLOT_ERASE;
648 throw std::runtime_error(
"Invalid action" + action);
651 int id_slot = json_value(data,
"id_slot", 0);
653 server_task task(task_type);
654 task.id = ctx_server->impl->queue_tasks.get_new_id();
655 task.slot_action.id_slot = id_slot;
657 if (action ==
"save" || action ==
"restore")
659 std::string filepath = data.at(
"filepath");
660 task.slot_action.filename = filepath.substr(filepath.find_last_of(
"/\\") + 1);
661 task.slot_action.filepath = filepath;
664 ctx_server->impl->queue_results.add_waiting_task_id(task.id);
665 ctx_server->impl->queue_tasks.post(std::move(task));
667 server_task_result_ptr result = ctx_server->impl->queue_results.recv(task.id);
668 ctx_server->impl->queue_results.remove_waiting_task_id(task.id);
670 json result_json = result->to_json();
671 result_data = result_json.dump();
682 if (get_status_code() < 0 || setjmp(get_jump_point()) != 0)
686 for (
auto &
slot : ctx_server->impl->slots)
688 if (
slot.id == id_slot)
703 if (get_status_code() < 0 || setjmp(get_jump_point()) != 0)
707 if (ctx_server ==
nullptr)
return 0;
708 return ctx_server->get_meta().model_n_embd_inp;
712 if (get_status_code() < 0 || setjmp(get_jump_point()) != 0)
715 server_http_req req{ {}, {},
"",
"",
"", always_false };
716 auto result = routes->get_props(req);
718 json data = json::parse(result->data);
722 n_ctx = data.at(
"default_generation_settings").at(
"n_ctx").get<
int>();
726 result->data = safe_json_to_str(json {
727 {
"default_generation_settings", {
742bool LLMService_Supports_GPU()
744 return llama_supports_gpu_offload();
747LLMService *
LLMService_Construct(
const char *model_path,
int num_slots,
int num_threads,
int num_GPU_layers,
bool flash_attention,
int context_size,
int batch_size,
bool embedding_only,
int lora_count,
const char **lora_paths)
749 std::vector<std::string> lora_paths_vector;
750 if (lora_paths !=
nullptr && lora_count > 0)
752 for (
int i = 0; i < lora_count; ++i)
754 lora_paths_vector.push_back(std::string(lora_paths[i]));
757 LLMService* llmService =
new LLMService(model_path, num_slots, num_threads, num_GPU_layers, flash_attention, context_size, batch_size, embedding_only, lora_paths_vector);
758 if (get_status_code() != 0)
760 if (llmService !=
nullptr)
delete llmService;
769 std::string params_string(params_string_arr);
772 json j = json::parse(params_string);
775 catch (
const json::parse_error &)
780 if (get_status_code() != 0)
782 if (llmService !=
nullptr)
delete llmService;
790 return stringToCharArray(llm_service->
get_command());
793void LLMService_InjectErrorState(
ErrorState *error_state)
void ensure_error_handlers_initialized()
Ensures error handlers are properly initialized.
LLM service implementation with server capabilities.
#define LLAMALIB_INF(...)
Info-level logging macro for LLama library.
static void inject_error_state(ErrorState *state)
Inject a custom error state instance.
virtual std::string slot(int id_slot, const std::string &action, const std::string &filepath)
Perform slot operation.
Registry for managing LLM provider instances.
void unregister_instance(LLMProvider *service)
Unregister an LLM provider instance.
const int get_debug_level()
Get current debug level.
const CharArrayFn get_log_callback()
Get current log callback.
void register_instance(LLMProvider *service)
Register an LLM provider instance.
static LLMProviderRegistry & instance()
Get the singleton registry instance.
static void inject_registry(LLMProviderRegistry *instance)
Inject a custom registry instance.
bool reasoning_enabled
Whether reasoning is enabled.
virtual void enable_reasoning(bool reasoning)
Enable reasoning.
Runtime loader for LLM libraries.
void init(int argc, char **argv)
Initialize from argc/argv parameters.
void stop_server() override
Stop HTTP server (override - delegates to loaded library)
void set_SSL(const std::string &cert, const std::string &key) override
Set SSL configuration (override - delegates to loaded library)
void enable_reasoning(bool reasoning) override
Enable reasoning (override - delegates to loaded library)
std::string lora_weight_json(const json &data) override
Configure LoRA weights with HTTP response support.
static std::vector< char * > jsonToArguments(const json ¶ms_json)
Convert JSON parameters to command line arguments.
void join_service() override
Wait for service completion (override - delegates to loaded library)
void cancel(int id_slot) override
Cancel request (override - delegates to loaded library)
bool started() override
Check service status (override - delegates to loaded library)
void start() override
Start service (override - delegates to loaded library)
std::string lora_list_json() override
List available LoRA adapters.
void logging_callback(CharArrayFn callback) override
Set logging callback (override - delegates to loaded library)
std::string tokenize_json(const json &data) override
Tokenize input (override)
std::unique_ptr< server_http_res > get_props()
Return properties of server / slots.
LLMService()
Default constructor.
std::string slot_json(const json &data) override
Manage slots with HTTP response support.
std::string detokenize_json(const json &data) override
Convert tokens back to text.
std::string embeddings_json(const json &data) override
Generate embeddings with HTTP response support.
int get_next_available_slot() override
Get available slot (override - delegates to loaded library)
void debug(int debug_level) override
Set debug level (override - delegates to loaded library)
void join_server() override
Wait for server completion (override - delegates to loaded library)
static LLMService * from_params(const json ¶ms_json)
Create LLMService from JSON parameters.
std::string apply_template_json(const json &data) override
Apply a chat template to message data.
void start_server(const std::string &host="0.0.0.0", int port=-1, const std::string &API_key="") override
Start HTTP server (override - delegates to loaded library)
std::string completion_json(const json &data, CharArrayFn callback=nullptr, bool callbackWithJSON=true) override
Generate completion (override - delegates to loaded library)
std::string get_command()
Returns the construct command.
int get_slot_context_size() override
Get slot context size (override - delegates to loaded library)
void stop() override
Stop service (override - delegates to loaded library)
int embedding_size() override
Get embedding size (override - delegates to loaded library)
static LLMService * from_command(const std::string &command)
Create runtime from command line string.
static std::string LLM_args_to_command(const std::string &model_path, int num_slots=1, int num_threads=-1, int num_GPU_layers=0, bool flash_attention=false, int context_size=4096, int batch_size=2048, bool embedding_only=false, const std::vector< std::string > &lora_paths={})
Convert LLM parameters to command line arguments.
Handles concatenation of LLM response chunks (both streaming and non-streaming). Accumulates content a...
bool process_chunk(const std::string &chunk_data)
Process a single chunk and accumulate its content/tokens.
bool is_complete() const
Check if response is complete.
std::string get_result_json() const
Get the complete result as JSON string.
void set_callback(CharArrayFn callback, bool callWithJSON=false)
Set a callback to be invoked after each chunk is processed.
LLMService * LLMService_Construct(const char *model_path, int num_slots=1, int num_threads=-1, int num_GPU_layers=0, bool flash_attention=false, int context_size=4096, int batch_size=2048, bool embedding_only=false, int lora_count=0, const char **lora_paths=nullptr)
Construct LLMService instance (C API)
void LLMService_Registry(LLMProviderRegistry *existing_instance)
Set registry for LLMService (C API)
const char * LLMService_Command(LLMService *llm_service)
Returns the construct command (C API)
LLMService * LLMService_From_Command(const char *params_string)
Create LLMService from command string (C API)
Error state container for sharing between libraries.