LLM for Unity v3.0.0
Create characters in Unity with LLMs!
LLMClient.cs
Go to the documentation of this file.
using System;
using System.Collections.Generic;
using System.IO;
using System.Threading.Tasks;
using UndreamAI.LlamaLib;
using UnityEngine;
using Newtonsoft.Json.Linq;
using System.Threading;

namespace LLMUnity
{
    /// <summary>
    /// Unity MonoBehaviour base class for LLM client functionality.
    /// Handles both local and remote LLM connections.
    /// </summary>
    public class LLMClient : MonoBehaviour
    {
        #region Inspector Fields
        [Tooltip("Show/hide advanced options in the inspector")]
        [HideInInspector] public bool advancedOptions = false;

        [Tooltip("Use remote LLM server instead of local instance")]
        [LocalRemote, SerializeField] protected bool _remote;

        [Tooltip("Local LLM GameObject to connect to")]
        [Local, SerializeField] protected LLM _llm;

        [Tooltip("API key for remote server authentication")]
        [Remote, SerializeField] protected string _APIKey;

        [Tooltip("Hostname or IP address of remote LLM server")]
        [Remote, SerializeField] protected string _host = "localhost";

        [Tooltip("Port number of remote LLM server")]
        [Remote, SerializeField] protected int _port = 13333;

        [Tooltip("Number of connection retries for the remote LLM server")]
        [Remote, SerializeField] protected int numRetries = 5;

        [Tooltip("Grammar constraints for output formatting (GBNF or JSON schema format)")]
        [ModelAdvanced, TextArea(1, 10), SerializeField] protected string _grammar = "";

        // Completion Parameters
        [Tooltip("Maximum tokens to generate (-1 = unlimited)")]
        [Model] public int numPredict = -1;

        [Tooltip("Cache processed prompts to speed up subsequent requests")]
        [ModelAdvanced] public bool cachePrompt = true;

        [Tooltip("Random seed for reproducible generation (0 = random)")]
        [ModelAdvanced] public int seed = 0;

        [Tooltip("Sampling temperature (0.0 = deterministic, higher = more creative)")]
        [ModelAdvanced, Range(0f, 2f)] public float temperature = 0.2f;

        [Tooltip("Top-k sampling: limit to k most likely tokens (0 = disabled)")]
        [ModelAdvanced, Range(0, 100)] public int topK = 40;

        [Tooltip("Top-p (nucleus) sampling: cumulative probability threshold (1.0 = disabled)")]
        [ModelAdvanced, Range(0f, 1f)] public float topP = 0.9f;

        [Tooltip("Minimum probability threshold for token selection")]
        [ModelAdvanced, Range(0f, 1f)] public float minP = 0.05f;

        [Tooltip("Penalty for repeated tokens (1.0 = no penalty)")]
        [ModelAdvanced, Range(0f, 2f)] public float repeatPenalty = 1.1f;

        [Tooltip("Presence penalty: reduce likelihood of any repeated token (0.0 = disabled)")]
        [ModelAdvanced, Range(0f, 1f)] public float presencePenalty = 0f;

        [Tooltip("Frequency penalty: reduce likelihood based on token frequency (0.0 = disabled)")]
        [ModelAdvanced, Range(0f, 1f)] public float frequencyPenalty = 0f;

        [Tooltip("Locally typical sampling strength (1.0 = disabled)")]
        [ModelAdvanced, Range(0f, 1f)] public float typicalP = 1f;

        [Tooltip("Number of recent tokens to consider for repetition penalty (0 = disabled, -1 = context size)")]
        [ModelAdvanced, Range(0, 2048)] public int repeatLastN = 64;

        [Tooltip("Mirostat sampling mode (0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0)")]
        [ModelAdvanced, Range(0, 2)] public int mirostat = 0;

        [Tooltip("Mirostat target entropy (tau) - balance between coherence and diversity")]
        [ModelAdvanced, Range(0f, 10f)] public float mirostatTau = 5f;

        [Tooltip("Mirostat learning rate (eta) - adaptation speed")]
        [ModelAdvanced, Range(0f, 1f)] public float mirostatEta = 0.1f;

        [Tooltip("Include top N token probabilities in response (0 = disabled)")]
        [ModelAdvanced, Range(0, 10)] public int nProbs = 0;

        [Tooltip("Ignore end-of-stream token and continue generating")]
        [ModelAdvanced] public bool ignoreEos = false;
        #endregion

        #region Public Properties
        /// <summary>Whether this client uses a remote server connection.</summary>
        public bool remote
        {
            get => _remote;
            set
            {
                if (_remote != value)
                {
                    _remote = value;
                    if (started) _ = SetupCaller();
                }
            }
        }

        /// <summary>The local LLM instance (null if using remote).</summary>
        public LLM llm
        {
            get => _llm;
            set => _ = SetLLM(value);
        }

        /// <summary>API key for remote server authentication.</summary>
        public string APIKey
        {
            get => _APIKey;
            set
            {
                if (_APIKey != value)
                {
                    _APIKey = value;
                    if (started) _ = SetupCaller();
                }
            }
        }

        /// <summary>Remote server hostname or IP address.</summary>
        public string host
        {
            get => _host;
            set
            {
                if (_host != value)
                {
                    _host = value;
                    if (started) _ = SetupCaller();
                }
            }
        }

        /// <summary>Remote server port number.</summary>
        public int port
        {
            get => _port;
            set
            {
                if (_port != value)
                {
                    _port = value;
                    if (started) _ = SetupCaller();
                }
            }
        }

        /// <summary>Current grammar constraints for output formatting.</summary>
        public string grammar
        {
            get => _grammar;
            set => SetGrammar(value);
        }
        #endregion

        #region Private Fields
        protected UndreamAI.LlamaLib.LLMClient llmClient;
        private bool started = false;
        // JSON of the last parameters sent, used to skip redundant updates
        private string completionParametersCache = "";
        // Guards caller setup so only one initialization runs at a time
        private readonly SemaphoreSlim startSemaphore = new SemaphoreSlim(1, 1);
        #endregion

        #region Unity Lifecycle
        /// <summary>
        /// Unity Awake method that validates configuration and assigns a local LLM if needed.
        /// </summary>
        public virtual void Awake()
        {
            if (!enabled) return;

            if (!remote)
            {
                AssignLLM();
                if (llm == null) LLMUnitySetup.LogError($"No LLM assigned or detected for {GetType().Name} '{name}'!", true);
            }
        }

        /// <summary>
        /// Unity Start method that initializes the LLM client connection.
        /// </summary>
        public virtual async void Start()
        {
            if (!enabled) return;
            await SetupCaller();
            started = true;
        }

        protected virtual void OnValidate()
        {
            AssignLLM();
        }

        protected virtual void Reset()
        {
            AssignLLM();
        }

        #endregion

        #region Initialization
        /// <summary>
        /// Ensures the caller is initialized and, for remote clients, that the server is reachable.
        /// </summary>
        protected virtual async Task CheckCaller(bool checkConnection = true)
        {
            // Wait until any in-flight setup has finished
            await startSemaphore.WaitAsync();
            startSemaphore.Release();
            if (GetCaller() == null) LLMUnitySetup.LogError("LLM caller not initialized", true);
            if (remote && checkConnection)
            {
                for (int attempt = 0; attempt <= numRetries; attempt++)
                {
                    if (llmClient.IsServerAlive()) break;
                    await Task.Yield();
                }
            }
        }

        /// <summary>
        /// Creates the underlying caller object and runs any post-setup steps.
        /// </summary>
        protected virtual async Task SetupCaller()
        {
            await SetupCallerObject();
            await PostSetupCallerObject();
        }

        /// <summary>
        /// Instantiates the LlamaLib client, either against the local LLM service or the remote server.
        /// </summary>
        protected virtual async Task SetupCallerObject()
        {
            await startSemaphore.WaitAsync();

            string exceptionMessage = "";
            try
            {
                if (!remote)
                {
                    if (llm != null) await llm.WaitUntilReady();
                    if (llm?.llmService == null) LLMUnitySetup.LogError("Local LLM service is not available", true);
                    llmClient = new UndreamAI.LlamaLib.LLMClient(llm.llmService);
                }
                else
                {
                    llmClient = new UndreamAI.LlamaLib.LLMClient(host, port, APIKey, numRetries);
                }
            }
            catch (Exception ex)
            {
                LLMUnitySetup.LogError(ex.Message);
                exceptionMessage = ex.Message;
            }
            finally
            {
                startSemaphore.Release();
            }

            if (llmClient == null || exceptionMessage != "")
            {
                string error = "llmClient not initialized";
                if (exceptionMessage != "") error += ", error: " + exceptionMessage;
                LLMUnitySetup.LogError(error, true);
            }
        }

        /// <summary>
        /// Hook for derived classes to run additional setup after the caller object is created.
        /// </summary>
        protected virtual async Task PostSetupCallerObject()
        {
            completionParametersCache = "";
            await Task.Yield();
        }

        /// <summary>
        /// Returns the underlying caller used for LLM operations.
        /// </summary>
        protected virtual LLMLocal GetCaller()
        {
            return llmClient;
        }

        /// <summary>
        /// Sets the local LLM instance and re-initializes the caller if already started.
        /// </summary>
        protected virtual async Task SetLLM(LLM llmInstance)
        {
            if (llmInstance == _llm) return;

            if (remote)
            {
                LLMUnitySetup.LogError("Cannot set LLM when client is in remote mode");
                return;
            }

            _llm = llmInstance;
            if (started) await SetupCaller();
        }

        #endregion

        #region LLM Assignment
        /// <summary>
        /// Determines if an LLM instance can be auto-assigned to this client.
        /// Override in derived classes to implement custom assignment criteria.
        /// </summary>
        public virtual bool IsAutoAssignableLLM(LLM llmInstance)
        {
            return true;
        }

        /// <summary>
        /// Auto-assigns the best-matching LLM in the loaded scenes if none is set.
        /// </summary>
        protected virtual void AssignLLM()
        {
            if (remote || llm != null) return;

            var validLLMs = new List<LLM>();

#if UNITY_6000_0_OR_NEWER
            foreach (LLM foundLlm in FindObjectsByType<LLM>(FindObjectsSortMode.None))
#else
            foreach (LLM foundLlm in FindObjectsOfType<LLM>())
#endif
            {
                if (IsAutoAssignableLLM(foundLlm))
                {
                    validLLMs.Add(foundLlm);
                }
            }

            if (validLLMs.Count == 0) return;

            llm = SortLLMsByBestMatch(validLLMs.ToArray())[0];

            string message = $"Auto-assigned LLM '{llm.name}' to {GetType().Name} '{name}'";
            if (llm.gameObject.scene != gameObject.scene)
            {
                message += $" (from scene '{llm.gameObject.scene.name}')";
            }
            LLMUnitySetup.Log(message);
        }

        /// <summary>
        /// Sorts candidate LLMs by best match: same-scene LLMs first, then by sibling index
        /// within a scene (bubble sort with early exit when already ordered).
        /// </summary>
        protected virtual LLM[] SortLLMsByBestMatch(LLM[] llmArray)
        {
            LLM[] array = (LLM[])llmArray.Clone();
            for (int i = 0; i < array.Length - 1; i++)
            {
                bool swapped = false;
                for (int j = 0; j < array.Length - i - 1; j++)
                {
                    bool sameScene = array[j].gameObject.scene == array[j + 1].gameObject.scene;
                    bool swap = (
                        (!sameScene && array[j + 1].gameObject.scene == gameObject.scene) ||
                        (sameScene && array[j].transform.GetSiblingIndex() > array[j + 1].transform.GetSiblingIndex())
                    );
                    if (swap)
                    {
                        LLM temp = array[j];
                        array[j] = array[j + 1];
                        array[j + 1] = temp;
                        swapped = true;
                    }
                }
                if (!swapped) break;
            }
            return array;
        }

        #endregion

        #region Grammar Management
        /// <summary>
        /// Sets grammar constraints for structured output generation.
        /// </summary>
        public virtual void SetGrammar(string grammarString)
        {
            _grammar = grammarString ?? "";
            GetCaller()?.SetGrammar(_grammar);
        }

        /// <summary>
        /// Loads grammar constraints from a file.
        /// </summary>
        public virtual void LoadGrammar(string path)
        {
            if (string.IsNullOrEmpty(path)) return;

            if (!File.Exists(path))
            {
                LLMUnitySetup.LogError($"Grammar file not found: {path}");
                return;
            }

            try
            {
                string grammarContent = File.ReadAllText(path);
                SetGrammar(grammarContent);
                LLMUnitySetup.Log($"Loaded grammar from: {path}");
            }
            catch (Exception ex)
            {
                LLMUnitySetup.LogError($"Failed to load grammar file '{path}': {ex.Message}");
            }
        }

        #endregion

        #region Completion Parameters
        /// <summary>
        /// Pushes the current sampling parameters to the caller, skipping the call if unchanged.
        /// </summary>
        protected virtual void SetCompletionParameters()
        {
            if (llm != null && llm.embeddingsOnly)
            {
                string error = "LLM can't be used for completion, it is an embeddings-only model!";
                LLMUnitySetup.LogError(error, true);
            }

            var parameters = new JObject
            {
                ["temperature"] = temperature,
                ["top_k"] = topK,
                ["top_p"] = topP,
                ["min_p"] = minP,
                ["n_predict"] = numPredict,
                ["typical_p"] = typicalP,
                ["repeat_penalty"] = repeatPenalty,
                ["repeat_last_n"] = repeatLastN,
                ["presence_penalty"] = presencePenalty,
                ["frequency_penalty"] = frequencyPenalty,
                ["mirostat"] = mirostat,
                ["mirostat_tau"] = mirostatTau,
                ["mirostat_eta"] = mirostatEta,
                ["seed"] = seed,
                ["ignore_eos"] = ignoreEos,
                ["n_probs"] = nProbs,
                ["cache_prompt"] = cachePrompt
            };

            string parametersJson = parameters.ToString();
            if (parametersJson != completionParametersCache)
            {
                GetCaller()?.SetCompletionParameters(parameters);
                completionParametersCache = parametersJson;
            }
        }

        #endregion

        #region Core LLM Operations
        /// <summary>
        /// Converts text into a list of token IDs.
        /// </summary>
        public virtual async Task<List<int>> Tokenize(string query, Action<List<int>> callback = null)
        {
            if (string.IsNullOrEmpty(query))
            {
                LLMUnitySetup.LogError("query is null or empty", true);
            }
            await CheckCaller();

            List<int> tokens = llmClient.Tokenize(query);
            callback?.Invoke(tokens);
            return tokens;
        }

        /// <summary>
        /// Converts token IDs back to text.
        /// </summary>
        public virtual async Task<string> Detokenize(List<int> tokens, Action<string> callback = null)
        {
            if (tokens == null)
            {
                LLMUnitySetup.LogError("tokens is null", true);
            }
            await CheckCaller();

            string text = llmClient.Detokenize(tokens);
            callback?.Invoke(text);
            return text;
        }

        /// <summary>
        /// Generates embedding vectors for the input text.
        /// </summary>
        public virtual async Task<List<float>> Embeddings(string query, Action<List<float>> callback = null)
        {
            if (string.IsNullOrEmpty(query))
            {
                LLMUnitySetup.LogError("query is null or empty", true);
            }
            await CheckCaller();

            List<float> embeddings;
            // Null check guards against a missing local LLM reference (e.g. remote mode)
            if (llm != null && !llm.embeddingsOnly)
            {
                LLMUnitySetup.LogError("You need to use an embedding model for embeddings (see \"RAG models\" in \"Download model\")");
                embeddings = new List<float>();
            }
            else
            {
                embeddings = llmClient.Embeddings(query);
            }
            if (embeddings.Count == 0) LLMUnitySetup.LogError("embeddings are empty!");
            callback?.Invoke(embeddings);
            return embeddings;
        }

        /// <summary>
        /// Generates a text completion, optionally streaming results through the callback.
        /// </summary>
        public virtual async Task<string> Completion(string prompt, Action<string> callback = null,
            Action completionCallback = null, int id_slot = -1)
        {
            await CheckCaller();

            LlamaLib.CharArrayCallback wrappedCallback = null;
            if (callback != null)
            {
#if ENABLE_IL2CPP
                Action<string> mainThreadCallback = Utils.WrapActionForMainThread(callback, this);
                wrappedCallback = IL2CPP_Completion.CreateCallback(mainThreadCallback);
#else
                wrappedCallback = Utils.WrapCallbackForAsync(callback, this);
#endif
            }

            SetCompletionParameters();
            string result = await llmClient.CompletionAsync(prompt, wrappedCallback, id_slot);
            completionCallback?.Invoke();
            return result;
        }

        /// <summary>
        /// Cancels an active request in the specified slot.
        /// </summary>
        public void CancelRequest(int id_slot)
        {
            llmClient?.Cancel(id_slot);
        }

        #endregion
    }
}
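
For reference, a minimal usage sketch of the component above, driven from a separate script. This assumes a scene that already contains an LLM and an LLMClient component; the ChatBehaviour name and the exact chunking of the streamed text are illustrative, not part of the API shown above.

using System.Threading.Tasks;
using UnityEngine;
using LLMUnity;

// Minimal usage sketch (illustrative names): drives an LLMClient from another MonoBehaviour.
public class ChatBehaviour : MonoBehaviour
{
    public LLMClient client; // assign in the inspector

    async void Start()
    {
        // Stream generated text to the console as it arrives, then log the full result.
        string reply = await client.Completion(
            "Once upon a time",
            callback: text => Debug.Log(text),
            completionCallback: () => Debug.Log("completion finished"));
        Debug.Log(reply);

        // Tokenize / Detokenize round-trip.
        var tokens = await client.Tokenize("Hello world");
        string restored = await client.Detokenize(tokens);
        Debug.Log($"{tokens.Count} tokens -> \"{restored}\"");
    }
}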
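
The grammar property accepts GBNF or JSON schema text (see the _grammar tooltip above). A sketch of constraining output through that property, assuming the GBNF dialect used by llama.cpp-based servers and the usings from the sketch above; the one-line grammar rule is illustrative.

// Constrain generation to "yes" or "no" using a GBNF grammar (illustrative rule).
const string yesNoGrammar = "root ::= \"yes\" | \"no\"";

async Task<string> AskYesNo(LLMClient client, string question)
{
    client.grammar = yesNoGrammar;   // routed through SetGrammar
    string answer = await client.Completion(question);
    client.grammar = "";             // remove the constraint afterwards
    return answer;
}

LoadGrammar provides the same effect from a grammar file on disk.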