#region Inspector Fields

[Tooltip("Show/hide advanced options in the inspector")]
// ... (field declaration elided)

[Tooltip("Use remote LLM server instead of local instance")]
[LocalRemote, SerializeField]
protected bool _remote;

[Tooltip("Local LLM GameObject to connect to")]
[Local, SerializeField]
protected LLM _llm;

[Tooltip("API key for remote server authentication")]
[Remote, SerializeField]
protected string _APIKey;

[Tooltip("Hostname or IP address of the remote LLM server")]
[Remote, SerializeField]
protected string _host = "localhost";

[Tooltip("Port number of the remote LLM server")]
[Remote, SerializeField]
protected int _port = 13333;
[Tooltip("Number of connection retries for the remote LLM server")]
[Remote, SerializeField]
protected int numRetries = 5;

[Tooltip("Grammar constraints for output formatting (GBNF or JSON schema format)")]
[ModelAdvanced, TextArea(1, 10), SerializeField]
protected string _grammar = "";
[Tooltip("Maximum tokens to generate (-1 = unlimited)")]
// ... (field declaration elided)

[Tooltip("Cache processed prompts to speed up subsequent requests")]
// ... (field declaration elided)

[Tooltip("Random seed for reproducible generation (0 = random)")]
[ModelAdvanced]
public int seed = 0;
[Tooltip("Sampling temperature (0.0 = deterministic, higher = more creative)")]
[ModelAdvanced, Range(0f, 2f)]
public float temperature = 0.2f;

[Tooltip("Top-k sampling: limit to k most likely tokens (0 = disabled)")]
[ModelAdvanced, Range(0, 100)]
public int topK = 40;

[Tooltip("Top-p (nucleus) sampling: cumulative probability threshold (1.0 = disabled)")]
[ModelAdvanced, Range(0f, 1f)]
public float topP = 0.9f;

[Tooltip("Minimum probability threshold for token selection")]
[ModelAdvanced, Range(0f, 1f)]
public float minP = 0.05f;
[Tooltip("Penalty for repeated tokens (1.0 = no penalty)")]
// ... (field declaration elided)

[Tooltip("Presence penalty: reduce likelihood of any repeated token (0.0 = disabled)")]
// ... (field declaration elided)

[Tooltip("Frequency penalty: reduce likelihood based on token frequency (0.0 = disabled)")]
// ... (field declaration elided)

[Tooltip("Locally typical sampling strength (1.0 = disabled)")]
[ModelAdvanced, Range(0f, 1f)]
public float typicalP = 1f;

[Tooltip("Number of recent tokens to consider for repetition penalty (0 = disabled, -1 = context size)")]
// ... (field declaration elided)

[Tooltip("Mirostat sampling mode (0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0)")]
[ModelAdvanced, Range(0, 2)]
public int mirostat = 0;

[Tooltip("Mirostat target entropy (tau) - balance between coherence and diversity")]
// ... (field declaration elided)

[Tooltip("Mirostat learning rate (eta) - adaptation speed")]
// ... (field declaration elided)

[Tooltip("Include top N token probabilities in response (0 = disabled)")]
[ModelAdvanced, Range(0, 10)]
public int nProbs = 0;

[Tooltip("Ignore end-of-stream token and continue generating")]
// ... (field declaration elided)

#endregion
#region Public Properties

// Each setter rebuilds the connection via SetupCaller() when its value
// changes after the component has started.
public bool remote
{
    get => _remote;
    set { if (_remote != value) { _remote = value; if (started) _ = SetupCaller(); } }
}

public LLM llm { get => _llm; set => _ = SetLLM(value); }

public string APIKey
{
    get => _APIKey;
    set { if (_APIKey != value) { _APIKey = value; if (started) _ = SetupCaller(); } }
}

// The host and port properties wrap _host and _port with the same change-detection setter.

#endregion
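// Illustrative example: pointing the caller at a remote server at runtime.
// The address and key are placeholders; the property setters above re-run
// SetupCaller() once the component has started.
public void ExampleUseRemoteServer()
{
    host = "192.168.1.10";  // hypothetical server address
    port = 13333;           // default port declared above
    APIKey = "my-api-key";  // hypothetical key; only needed if the server requires one
    remote = true;
}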
#region Private Fields

protected UndreamAI.LlamaLib.LLMClient llmClient;
private bool started = false;
private string completionParametersCache = "";
private readonly SemaphoreSlim startSemaphore = new SemaphoreSlim(1, 1);

#endregion
#region Unity Lifecycle

// ...
    if (!enabled) return;
    // ...
    if (llm == null)
        LLMUnitySetup.LogError($"No LLM assigned or detected for {GetType().Name} '{name}'!", true);
// ...
    if (!enabled) return;
// ...

protected virtual void OnValidate()
{
    // ...
}

protected virtual void Reset()
{
    // ...
}

#endregion
#region Initialization

protected virtual async Task CheckCaller(bool checkConnection = true)
{
    await startSemaphore.WaitAsync();
    startSemaphore.Release();
    if (GetCaller() == null) LLMUnitySetup.LogError("LLM caller not initialized", true);
    if (remote && checkConnection)
    {
        for (int attempt = 0; attempt <= numRetries; attempt++)
        {
            if (llmClient.IsServerAlive()) break;
            // ...
        }
    }
}
protected virtual async Task SetupCaller()
{
    await SetupCallerObject();
    await PostSetupCallerObject();
}
protected virtual async Task SetupCallerObject()
{
    await startSemaphore.WaitAsync();
    string exceptionMessage = "";
    try
    {
        if (!remote)
        {
            if (llm?.llmService == null) LLMUnitySetup.LogError("Local LLM service is not available", true);
            llmClient = new UndreamAI.LlamaLib.LLMClient(llm.llmService);
        }
        else
        {
            llmClient = new UndreamAI.LlamaLib.LLMClient(host, port, APIKey, numRetries);
        }
    }
    catch (Exception ex)
    {
        LLMUnitySetup.LogError(ex.Message);
        exceptionMessage = ex.Message;
    }
    finally
    {
        startSemaphore.Release();
    }

    if (llmClient == null || exceptionMessage != "")
    {
        string error = "llmClient not initialized";
        if (exceptionMessage != "") error += ", error: " + exceptionMessage;
        LLMUnitySetup.LogError(error, true);
    }
}
protected virtual async Task PostSetupCallerObject()
{
    // ...
    completionParametersCache = "";
    // ...
}
protected virtual LLMLocal GetCaller()
{
    // ...
}
protected virtual async Task SetLLM(LLM llmInstance)
{
    if (llmInstance == _llm) return;
    if (remote) { LLMUnitySetup.LogError("Cannot set LLM when client is in remote mode"); return; }
    _llm = llmInstance;
    if (started) await SetupCaller();
}

#endregion
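// Illustrative example: swapping the local model at runtime. Assigning the
// `llm` property fires SetLLM fire-and-forget; awaiting SetLLM directly, as
// sketched here, waits for the client to be rebuilt. `fallbackLLM` is a
// hypothetical second LLM component.
public async Task ExampleSwapModel(LLM fallbackLLM)
{
    if (remote) return;         // SetLLM rejects remote mode (see above)
    await SetLLM(fallbackLLM);  // rebuilds the caller if already started
}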
#region LLM Assignment

// ...

protected virtual void AssignLLM()
{
    // ...
    var validLLMs = new List<LLM>();
#if UNITY_6000_0_OR_NEWER
    foreach (LLM foundLlm in FindObjectsByType<LLM>(FindObjectsSortMode.None))
#else
    foreach (LLM foundLlm in FindObjectsOfType<LLM>())
#endif
    {
        // ...
        validLLMs.Add(foundLlm);
    }

    if (validLLMs.Count == 0) return;
    llm = SortLLMsByBestMatch(validLLMs.ToArray())[0];

    string message = $"Auto-assigned LLM '{llm.name}' to {GetType().Name} '{name}'";
    if (llm.gameObject.scene != gameObject.scene)
        message += $" (from scene '{llm.gameObject.scene.name}')";
    LLMUnitySetup.Log(message);
}
protected virtual LLM[] SortLLMsByBestMatch(LLM[] llmArray)
{
    // Bubble sort: prefer LLMs in this caller's scene, then lower sibling index.
    LLM[] array = (LLM[])llmArray.Clone();
    for (int i = 0; i < array.Length - 1; i++)
    {
        bool swapped = false;
        for (int j = 0; j < array.Length - i - 1; j++)
        {
            bool sameScene = array[j].gameObject.scene == array[j + 1].gameObject.scene;
            bool swap =
                (!sameScene && array[j + 1].gameObject.scene == gameObject.scene) ||
                (sameScene && array[j].transform.GetSiblingIndex() > array[j + 1].transform.GetSiblingIndex());
            if (swap)
            {
                (array[j], array[j + 1]) = (array[j + 1], array[j]);
                swapped = true;
            }
        }
        if (!swapped) break;
    }
    return array;
}

#endregion
#region Grammar Management

// ... (grammar setter; signature elided)
    _grammar = grammarString ?? "";
    GetCaller()?.SetGrammar(_grammar);

// ... (grammar file loader; signature elided)
    if (string.IsNullOrEmpty(path)) return;
    if (!File.Exists(path))
    {
        // ...
    }
    try
    {
        string grammarContent = File.ReadAllText(path);
        // ...
    }
    catch (Exception ex)
    {
        LLMUnitySetup.LogError($"Failed to load grammar file '{path}': {ex.Message}");
    }

#endregion
#region Completion Parameters

protected virtual void SetCompletionParameters()
{
    // ...
    string error = "LLM can't be used for completion, it is an embeddings-only model!";
    // ...
    var parameters = new JObject
    {
        // ... (sampling fields elided)
    };
    string parametersJson = parameters.ToString();
    if (parametersJson != completionParametersCache)
    {
        GetCaller()?.SetCompletionParameters(parameters);
        completionParametersCache = parametersJson;
    }
}

#endregion
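// Illustrative sketch of the payload SetCompletionParameters() builds. The
// JObject contents are elided above; the keys shown are the usual llama.cpp
// sampling field names and are an assumption, not the verified originals.
private JObject ExampleParameterPayload() => new JObject
{
    ["seed"] = seed,
    ["temperature"] = temperature,
    ["top_k"] = topK,
    ["top_p"] = topP,
    ["min_p"] = minP,
    ["typical_p"] = typicalP,
    ["mirostat"] = mirostat,
    ["n_probs"] = nProbs
};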
#region Core LLM Operations

public virtual async Task<List<int>> Tokenize(string query, Action<List<int>> callback = null)
{
    if (string.IsNullOrEmpty(query))
    {
        // ...
    }
    // ...
    List<int> tokens = llmClient.Tokenize(query);
    callback?.Invoke(tokens);
    return tokens;
}
public virtual async Task<string> Detokenize(List<int> tokens, Action<string> callback = null)
{
    // ...
    string text = llmClient.Detokenize(tokens);
    callback?.Invoke(text);
    return text;
}
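// Illustrative example: a tokenize/detokenize round trip, e.g. to measure
// prompt length in tokens before sending a completion request.
public async Task ExampleTokenRoundTrip()
{
    List<int> tokens = await Tokenize("Hello, world!");
    Debug.Log($"Prompt uses {tokens.Count} tokens");
    string restored = await Detokenize(tokens);
    Debug.Log($"Round-tripped text: {restored}");
}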
public virtual async Task<List<float>> Embeddings(string query, Action<List<float>> callback = null)
{
    if (string.IsNullOrEmpty(query))
    {
        // ...
    }
    // ...
    List<float> embeddings;
    // ... (if the model does not support embeddings:)
    LLMUnitySetup.LogError("You need to use an embedding model for embeddings (see \"RAG models\" in \"Download model\")");
    embeddings = new List<float>();
    // ... (otherwise:)
    embeddings = llmClient.Embeddings(query);

    if (embeddings.Count == 0) LLMUnitySetup.LogError("embeddings are empty!");
    callback?.Invoke(embeddings);
    return embeddings;
}
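// Illustrative example: cosine similarity between two texts via the
// Embeddings call above. Requires an embedding-capable model; the epsilon
// guards against division by zero.
public async Task<float> ExampleSimilarity(string a, string b)
{
    List<float> ea = await Embeddings(a);
    List<float> eb = await Embeddings(b);
    float dot = 0f, na = 0f, nb = 0f;
    int n = Mathf.Min(ea.Count, eb.Count);  // same model, so same dimension
    for (int i = 0; i < n; i++)
    {
        dot += ea[i] * eb[i];
        na += ea[i] * ea[i];
        nb += eb[i] * eb[i];
    }
    return dot / (Mathf.Sqrt(na) * Mathf.Sqrt(nb) + 1e-8f);
}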
public virtual async Task<string> Completion(string prompt, Action<string> callback = null,
    Action completionCallback = null, int id_slot = -1)
{
    // ...
    LlamaLib.CharArrayCallback wrappedCallback = null;
    if (callback != null)
    {
#if ENABLE_IL2CPP  // assumption: the elided directive targets the IL2CPP scripting backend
        Action<string> mainThreadCallback = Utils.WrapActionForMainThread(callback, this);
        wrappedCallback = IL2CPP_Completion.CreateCallback(mainThreadCallback);
#else
        wrappedCallback = Utils.WrapCallbackForAsync(callback, this);
#endif
    }
    // ...
    SetCompletionParameters();
    string result = await llmClient.CompletionAsync(prompt, wrappedCallback, id_slot);
    completionCallback?.Invoke();
    return result;
}

// ... (cancellation method; signature elided)
    llmClient?.Cancel(id_slot);
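// Illustrative example: a streaming completion using the method above. The
// callback receives text as it is generated; the prompt is arbitrary.
public async Task ExampleStreamingCompletion()
{
    string reply = await Completion(
        "Write a one-line haiku about Unity.",
        partial => Debug.Log(partial),          // streamed output
        () => Debug.Log("generation finished")  // fires once at the end
    );
    Debug.Log($"Final reply: {reply}");
}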