LlamaLib v2.0.2
Cross-platform library for local LLMs
LLM_client.cpp
#include "LLM_client.h"

//================ Remote requests ================//

#if !(TARGET_OS_IOS || TARGET_OS_VISION)
X509_STORE *load_client_cert(const std::string &cert_str)
{
    BIO *mem = BIO_new_mem_buf(cert_str.data(), (int)cert_str.size());
    if (!mem)
    {
        return nullptr;
    }

    auto inf = PEM_X509_INFO_read_bio(mem, nullptr, nullptr, nullptr);
    if (!inf)
    {
        BIO_free(mem); // release the BIO on the early-exit path as well
        return nullptr;
    }

    auto cts = X509_STORE_new();
    if (cts)
    {
        for (auto i = 0; i < static_cast<int>(sk_X509_INFO_num(inf)); i++)
        {
            auto itmp = sk_X509_INFO_value(inf, i);
            if (!itmp)
            {
                continue;
            }

            if (itmp->x509)
            {
                X509_STORE_add_cert(cts, itmp->x509);
            }
            if (itmp->crl)
            {
                X509_STORE_add_crl(cts, itmp->crl);
            }
        }
    }

    sk_X509_INFO_pop_free(inf, X509_INFO_free);
    BIO_free(mem);
    return cts;
}
#else
struct IOSCallbackContext {
    ResponseConcatenator *concatenator;
    bool *cancel_flag;
};

// Static callback for iOS that receives context
static void ios_callback_with_context(const char *data, void *ctx) {
    auto *context = static_cast<IOSCallbackContext *>(ctx);

    if (!data || !context || !context->concatenator || !context->cancel_flag) {
        return;
    }

    if (*context->cancel_flag) {
        return;
    }

    std::string chunk_str(data);
    if (!context->concatenator->process_chunk(chunk_str)) {
        *context->cancel_flag = true;
    }
}
#endif
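
// --- Usage sketch (illustrative, not part of the original file) ---
// Minimal example of turning a PEM bundle into an OpenSSL trust store via
// load_client_cert() and checking the result; the bundle contents are a
// placeholder for one or more "-----BEGIN CERTIFICATE-----" blocks.
#if !(TARGET_OS_IOS || TARGET_OS_VISION)
static bool example_build_cert_store(const std::string &pem_bundle)
{
    X509_STORE *store = load_client_cert(pem_bundle);
    if (!store)
        return false; // bad PEM data or allocation failure

    // A store handed to set_ca_cert_store() (see set_SSL below) is owned by
    // the SSL client; a store that is not handed over must be freed here.
    X509_STORE_free(store);
    return true;
}
#endif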

bool LLMClient::is_server_alive()
{
    if (!is_remote()) return true;

    std::vector<std::pair<std::string, std::string>> headers;
    if (!API_key.empty()) {
        headers.push_back({"Authorization", "Bearer " + API_key});
    }

#if TARGET_OS_IOS || TARGET_OS_VISION
    HttpResult result = transport->post_request("health", "{}", headers);
    return result.success && result.status_code >= 200 && result.status_code < 300;
#else
    httplib::Headers Headers;
    for (const auto &h : headers) Headers.insert(h);
    auto res = use_ssl ? sslClient->Post("/health", Headers) : client->Post("/health", Headers);
    return res && res->status >= 200 && res->status < 300;
#endif
}
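
// --- Usage sketch (illustrative, not part of the original file) ---
// Polls a remote server's /health endpoint before sending work; the host,
// port and key are placeholders.
static bool example_wait_for_server()
{
    LLMClient client("https://api.example.com", 443, "my-api-key", 3);
    for (int i = 0; i < 5; i++)
    {
        if (client.is_server_alive())
            return true;
        std::this_thread::sleep_for(std::chrono::seconds(1));
    }
    return false;
}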

std::string LLMClient::post_request(
    const std::string &path,
    const json &payload,
    CharArrayFn callback,
    bool callbackWithJSON)
{
    json body = payload;
    bool stream = callback != nullptr;
    if (body.contains("stream"))
        stream = body["stream"];
    else
        body["stream"] = stream;

    bool *cancel_flag = new bool(false);
    if (stream) active_requests.push_back(cancel_flag);

    std::string response_buffer;
    ResponseConcatenator concatenator;
    if (stream && callback) concatenator.set_callback(callback, callbackWithJSON);

    std::vector<std::pair<std::string, std::string>> headers = {
        {"Content-Type", "application/json"},
        {"Accept", stream ? "text/event-stream" : "application/json"},
        {"Cache-Control", "no-cache"}
    };

    if (!API_key.empty()) {
        headers.push_back({"Authorization", "Bearer " + API_key});
    }

#if TARGET_OS_IOS || TARGET_OS_VISION
    // iOS native implementation with context
    IOSCallbackContext ios_context = {&concatenator, cancel_flag};

    HttpResult result;
    for (int attempt = 0; attempt <= max_retries; attempt++) {
        result = transport->post_request(
            path,
            body.dump(),
            headers,
            stream ? ios_callback_with_context : nullptr,
            stream ? &ios_context : nullptr,
            cancel_flag
        );

        if (result.success || *cancel_flag) break;

        // Exponential backoff, capped at 30 seconds
        int delay_seconds = std::min(30, 1 << attempt);
        std::cerr << "[LLMClient] POST failed: " << result.error_message
                  << ", retrying in " << delay_seconds << "s (attempt "
                  << attempt + 1 << "/" << max_retries + 1 << ")\n";
        std::this_thread::sleep_for(std::chrono::seconds(delay_seconds));
    }

    if (stream) {
        active_requests.erase(std::remove(active_requests.begin(), active_requests.end(), cancel_flag), active_requests.end());
    }

    if (!result.success) {
        std::cerr << "[LLMClient] POST request failed: " << result.error_message << "\n";
        delete cancel_flag;
        return "{}";
    }
    delete cancel_flag;

    return stream ? concatenator.get_result_json() : result.body;

#else
    // cpp-httplib implementation
    httplib::Headers Headers;
    for (const auto &h : headers) Headers.insert(h);

    httplib::Request req;
    req.method = "POST";
    req.path = "/" + path;
    req.headers = Headers;
    req.body = body.dump();

    req.content_receiver = [&](const char *data, size_t data_length, uint64_t /*offset*/, uint64_t /*total_length*/)
    {
        std::string chunk_str(data, data_length);
        if (stream)
        {
            if (!concatenator.process_chunk(chunk_str)) {
                // The callback asked to stop: flag cancellation so the
                // retry loop below does not re-send an aborted stream
                *cancel_flag = true;
                return false;
            }
            if (*cancel_flag)
            {
                std::cerr << "[LLMClient] Streaming cancelled\n";
                return false;
            }
        }
        else
        {
            response_buffer += chunk_str;
        }
        return true;
    };

    const int max_delay = 30;
    bool request_sent = false;
    for (int attempt = 0; attempt <= max_retries; attempt++)
    {
        request_sent = use_ssl ? sslClient->send(req) : client->send(req);
        if (request_sent || *cancel_flag) break;

        // Exponential backoff, capped at max_delay seconds
        int delay_seconds = std::min(max_delay, 1 << attempt);
        std::cerr << "[LLMClient] POST failed, retrying in " << delay_seconds
                  << "s (attempt " << attempt + 1 << "/" << max_retries + 1 << ")\n";
        std::this_thread::sleep_for(std::chrono::seconds(delay_seconds));
    }

    if (stream) active_requests.erase(std::remove(active_requests.begin(), active_requests.end(), cancel_flag), active_requests.end());

    if (!request_sent)
    {
        std::cerr << "[LLMClient] POST request failed after retries\n";
        delete cancel_flag;
        return "{}";
    }
    delete cancel_flag;

    if (stream) {
        return concatenator.get_result_json();
    } else {
        return response_buffer;
    }
#endif
}
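
// --- Usage sketch (illustrative, not part of the original file) ---
// Streams a completion and prints each chunk as it arrives. Assumes that
// CharArrayFn is compatible with a plain function taking a const char *,
// and that callbackWithJSON=false delivers plain text rather than the
// JSON-encoded chunk; the payload fields follow the llama.cpp server API.
static void example_print_chunk(const char *chunk)
{
    std::cout << chunk << std::flush;
}

static std::string example_stream_completion(LLMClient &client)
{
    json payload = {{"prompt", "Hello"}, {"n_predict", 64}, {"id_slot", -1}};
    return client.completion_json(payload, example_print_chunk, false);
}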

//================ LLMClient ================//

// Constructor for local LLM
LLMClient::LLMClient(LLMProvider *llm_) : llm(llm_) {}

// Constructor for remote LLM
LLMClient::LLMClient(const std::string &url_, const int port_, const std::string &API_key_, const int max_retries_)
    : url(url_), port(port_), API_key(API_key_), max_retries(max_retries_)
{
    std::string host;
    if (url.rfind("https://", 0) == 0)
    {
        host = url.substr(8); // strip "https://"
        use_ssl = true;
    }
    else
    {
        host = url.rfind("http://", 0) == 0 ? url.substr(7) : url;
        use_ssl = false;
    }

#if TARGET_OS_IOS || TARGET_OS_VISION
    transport = new IOSHttpTransport(host, use_ssl, port);
    transport->set_timeout(60.0);
#else
    if (use_ssl)
    {
        sslClient = new httplib::SSLClient(host.c_str(), port);
    }
    else
    {
        client = new httplib::Client(host.c_str(), port);
    }
#endif
}
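
// --- Usage sketch (illustrative, not part of the original file) ---
// The scheme prefix decides the transport: "https://" selects the SSL
// client and anything else the plain client, with the prefix stripped
// before the host reaches the HTTP layer. Hosts and keys are placeholders.
static void example_construct_clients()
{
    LLMClient secure("https://api.example.com", 443, "my-api-key", 3); // use_ssl = true
    LLMClient plain("http://192.168.1.10", 8080, "", 3);               // use_ssl = false
}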

LLMClient::~LLMClient()
{
#if TARGET_OS_IOS || TARGET_OS_VISION
    if (transport != nullptr) {
        delete transport;
    }
#else
    if (client != nullptr)
        delete client;
    if (sslClient != nullptr)
        delete sslClient;
#endif
}

void LLMClient::set_SSL(const char *SSL_cert_)
{
#if !(TARGET_OS_IOS || TARGET_OS_VISION)
    if (is_remote())
    {
        this->SSL_cert = SSL_cert_;
        if (sslClient != nullptr)
        {
            // Only install the store if the PEM data parsed successfully
            X509_STORE *store = load_client_cert(SSL_cert);
            if (store != nullptr)
                sslClient->set_ca_cert_store(store);
        }
    }
#endif
}
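
// --- Usage sketch (illustrative, not part of the original file) ---
// Reads a PEM bundle from disk and installs it as the CA store of a remote
// client; "ca-bundle.pem" is a hypothetical path.
#include <fstream>
#include <sstream>

static void example_install_ca_bundle(LLMClient &client)
{
    std::ifstream file("ca-bundle.pem");
    std::stringstream buffer;
    buffer << file.rdbuf();
    const std::string pem = buffer.str();
    client.set_SSL(pem.c_str());
}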

std::string LLMClient::tokenize_json(const json &data)
{
    if (is_remote())
    {
        return post_request("tokenize", data);
    }
    else
    {
        return llm->tokenize_json(data);
    }
}

std::string LLMClient::detokenize_json(const json &data)
{
    if (is_remote())
    {
        return post_request("detokenize", data);
    }
    else
    {
        return llm->detokenize_json(data);
    }
}

std::string LLMClient::embeddings_json(const json &data)
{
    if (is_remote())
    {
        return post_request("embeddings", data);
    }
    else
    {
        return llm->embeddings_json(data);
    }
}
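
// --- Usage sketch (illustrative, not part of the original file) ---
// tokenize_json, detokenize_json and embeddings_json all follow the same
// pattern: remote clients POST the payload, local clients call the provider
// directly. The "content" field assumes the llama.cpp server embeddings API.
static std::string example_embed(LLMClient &client, const std::string &text)
{
    json payload = {{"content", text}};
    return client.embeddings_json(payload); // JSON string containing the vector(s)
}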

std::string LLMClient::completion_json(const json &data, CharArrayFn callback, bool callbackWithJSON)
{
    if (is_remote())
    {
        json data_remote = data;
        if (data.contains("id_slot") && data["id_slot"] != -1)
        {
            std::cerr << "Remote clients can only use id_slot -1" << std::endl;
            data_remote["id_slot"] = -1;
        }
        return post_request("completion", data_remote, callback, callbackWithJSON);
    }
    else
    {
        return llm->completion_json(data, callback, callbackWithJSON);
    }
}

int LLMClient::get_next_available_slot()
{
    if (is_remote())
        return -1;
    return llm->get_next_available_slot();
}

std::string LLMClient::apply_template_json(const json &data)
{
    if (is_remote())
    {
        return post_request("apply-template", data);
    }
    else
    {
        return llm->apply_template_json(data);
    }
}

std::string LLMClient::slot_json(const json &data)
{
    if (is_remote())
    {
        std::cerr << "Slot operations are not supported in remote clients" << std::endl;
        return "{}";
    }
    else
    {
        return llm->slot_json(data);
    }
}

void LLMClient::cancel(int id_slot)
{
    if (is_remote())
    {
        for (bool *flag : active_requests) *flag = true;
    }
    else
    {
        llm->cancel(id_slot);
    }
}
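
// --- Usage sketch (illustrative, not part of the original file) ---
// For a remote client, cancel() flips every active cancel flag, so a
// streaming request running on another thread stops at its next chunk.
// The slot id is ignored remotely; -1 matches the only valid remote slot.
#include <thread>

static void example_cancel_stream(LLMClient &client)
{
    std::thread worker([&client]() {
        client.completion_json({{"prompt", "Write a long story"}, {"stream", true}});
    });
    std::this_thread::sleep_for(std::chrono::seconds(2));
    client.cancel(-1); // request the running stream to stop
    worker.join();
}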

//================ API ================//

bool LLMClient_Is_Server_Alive(LLMClient *llm)
{
    return llm->is_server_alive();
}

void LLMClient_Set_SSL(LLMClient *llm, const char *SSL_cert)
{
    llm->set_SSL(SSL_cert);
}

LLMClient *LLMClient_Construct(LLMProvider *llm)
{
    return new LLMClient(llm);
}

LLMClient *LLMClient_Construct_Remote(const char *url, const int port, const char *API_key)
{
    return new LLMClient(url, port, API_key);
}
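
// --- Usage sketch (illustrative, not part of the original file) ---
// Typical call sequence for a host application driving the C API; host,
// port and key are placeholders. The client is created with new, so plain
// delete releases it here; a dedicated destruction entry point, if any,
// lives elsewhere in the API.
static void example_c_api()
{
    LLMClient *client = LLMClient_Construct_Remote("https://api.example.com", 443, "my-api-key");
    if (LLMClient_Is_Server_Alive(client))
    {
        // ... issue tokenize/completion requests via the class methods ...
    }
    delete client;
}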