LLM for Unity  v2.4.2
Create characters in Unity with LLMs!
Loading...
Searching...
No Matches
LLMCaller.cs
Go to the documentation of this file.
1
3using System;
4using System.Collections.Generic;
5using System.Threading.Tasks;
6using UnityEngine;
7using UnityEngine.Networking;
8
9namespace LLMUnity
10{
11 [DefaultExecutionOrder(-2)]
16 public class LLMCaller : MonoBehaviour
17 {
/// <summary>Show/hide advanced options in the GameObject inspector.</summary>
19 [Tooltip("show/hide advanced options in the GameObject")]
20 [HideInInspector] public bool advancedOptions = false;
/// <summary>Use a remote LLM server instead of a local LLM GameObject.</summary>
22 [Tooltip("use remote LLM server")]
23 [LocalRemote] public bool remote = false;
/// <summary>Serialized backing field of <see cref="llm"/>: the LLM GameObject to use.</summary>
25 [Tooltip("LLM GameObject to use")] // Tooltip: ignore
26 [Local, SerializeField] protected LLM _llm;
/// <summary>LLM used by this caller; assignments go through SetLLM, which validates the value.</summary>
27 public LLM llm
28 {
29 get => _llm; // read the backing field directly
30 set => SetLLM(value);
31 }
/// <summary>API key for the remote server (sent as a Bearer token by Awake when not empty).</summary>
33 [Tooltip("API key for the remote server")]
34 [Remote] public string APIKey;
/// <summary>Host of the remote LLM server.</summary>
36 [Tooltip("host of the remote LLM server")]
37 [Remote] public string host = "localhost";
/// <summary>Port of the remote LLM server.</summary>
39 [Tooltip("port of the remote LLM server")]
40 [Remote] public int port = 13333;
/// <summary>Number of retries to use for the remote LLM server requests (-1 = infinite).</summary>
42 [Tooltip("number of retries to use for the remote LLM server requests (-1 = infinite)")]
43 [Remote] public int numRetries = 10;
44
// Last value stored by SetLLM; OnValidate compares _llm against it to detect changes
45 protected LLM _prellm;
// (name, value) header pairs attached to remote requests; created in Awake
46 protected List<(string, string)> requestHeaders;
// Remote requests currently in flight; aborted and cleared by CancelRequestsRemote
47 protected List<UnityWebRequest> WIPRequests = new List<UnityWebRequest>();
48
/// <summary>
/// Unity Awake callback: builds the request headers and, in local mode, resolves the LLM to use.
/// </summary>
/// <remarks>
/// Does nothing when the component is disabled. In remote mode an "Authorization: Bearer ..."
/// header is added when <see cref="APIKey"/> is not empty.
/// </remarks>
/// <exception cref="Exception">
/// Thrown in local mode when no LLM is assigned and none could be auto-assigned.
/// </exception>
public virtual void Awake()
{
    if (!enabled) return;

    requestHeaders = new List<(string, string)> { ("Content-Type", "application/json") };
    if (remote)
    {
        if (!String.IsNullOrEmpty(APIKey)) requestHeaders.Add(("Authorization", "Bearer " + APIKey));
        return;
    }

    AssignLLM();
    if (llm == null)
    {
        // Name the concrete component type (this base class is also used by subclasses),
        // in the same way the assignment log message of AssignLLM does.
        string error = $"No LLM assigned or detected for {GetType()} {name}!";
        LLMUnitySetup.LogError(error);
        throw new Exception(error);
    }
}
79
/// <summary>
/// Stores the LLM to use after validating it with <see cref="IsValidLLM"/>.
/// An invalid LLM is reported through LLMUnitySetup.LogError and replaced by null.
/// </summary>
/// <param name="llmSet">LLM to assign (may be null)</param>
protected virtual void SetLLM(LLM llmSet)
{
    LLM accepted = llmSet;
    if (accepted != null)
    {
        if (!IsValidLLM(accepted))
        {
            LLMUnitySetup.LogError(NotValidLLMError());
            accepted = null;
        }
    }
    // keep the change-tracking copy in sync with the stored value
    _prellm = _llm = accepted;
}
94
/// <summary>
/// Checks if an LLM is valid for this LLMCaller. The base implementation accepts every LLM.
/// </summary>
/// <param name="llmSet">LLM to check</param>
/// <returns>always true in the base class</returns>
public virtual bool IsValidLLM(LLM llmSet) => true;
104
/// <summary>
/// Checks if an LLM may be picked by the automatic assignment in AssignLLM.
/// The base implementation accepts every LLM.
/// </summary>
/// <param name="llmSet">LLM to check</param>
/// <returns>always true in the base class</returns>
public virtual bool IsAutoAssignableLLM(LLM llmSet) => true;
114
/// <summary>
/// Builds the error message logged by SetLLM when an LLM is rejected.
/// </summary>
/// <returns>the error message</returns>
protected virtual string NotValidLLMError()
{
    // When SetLLM is reached through the property setter, `llm` still holds the previous
    // value, which may be unassigned; avoid dereferencing it in that case.
    string llmName = llm != null ? llm.name : "<none>";
    return $"Can't set LLM {llmName} to {name}";
}
119
/// <summary>
/// Unity OnValidate callback: re-validates the LLM when the serialized field differs from the
/// last value stored by SetLLM, then tries the automatic assignment.
/// </summary>
protected virtual void OnValidate()
{
    bool llmChanged = _llm != _prellm;
    if (llmChanged)
    {
        SetLLM(_llm);
    }
    AssignLLM();
}
125
/// <summary>
/// Unity Reset callback: tries the automatic LLM assignment.
/// </summary>
protected virtual void Reset() => AssignLLM();
130
/// <summary>
/// Automatically assigns an LLM when running locally and none is set.
/// Candidates are the LLM objects found in the loaded scenes that pass both
/// <see cref="IsValidLLM"/> and <see cref="IsAutoAssignableLLM"/>; the first entry returned
/// by <see cref="SortLLMsByBestMatching"/> is selected and the choice is logged.
/// </summary>
protected virtual void AssignLLM()
{
    if (remote) return;
    if (llm != null) return;

    List<LLM> candidates = new List<LLM>();
#if UNITY_6000_0_OR_NEWER
    foreach (LLM sceneLLM in FindObjectsByType(typeof(LLM), FindObjectsSortMode.None))
#else
    foreach (LLM sceneLLM in FindObjectsOfType<LLM>())
#endif
    {
        if (!IsValidLLM(sceneLLM)) continue;
        if (!IsAutoAssignableLLM(sceneLLM)) continue;
        candidates.Add(sceneLLM);
    }
    if (candidates.Count == 0) return;

    LLM[] ranked = SortLLMsByBestMatching(candidates.ToArray());
    llm = ranked[0];

    string msg = $"Assigning LLM {llm.name} to {GetType()} {name}";
    // mention the scene only when the LLM lives in a different scene than this object
    bool otherScene = llm.gameObject.scene != gameObject.scene;
    if (otherScene)
    {
        msg += $" from scene {llm.gameObject.scene}";
    }
    LLMUnitySetup.Log(msg);
}
151
/// <summary>
/// Returns a sorted copy of the provided LLMs (the input array is not modified).
/// Adjacent entries are swapped when they are in different scenes and the second one is in
/// the scene of this object, or when they share a scene and the first one has the larger
/// sibling index.
/// </summary>
/// <param name="arrayIn">LLMs to sort</param>
/// <returns>new array holding the sorted LLMs</returns>
protected virtual LLM[] SortLLMsByBestMatching(LLM[] arrayIn)
{
    LLM[] sorted = (LLM[])arrayIn.Clone();

    // Repeated adjacent-swap passes; each pass shortens the examined range by one
    // and the procedure stops as soon as a pass performs no swap.
    int passEnd = sorted.Length - 1;
    bool swappedInPass = true;
    while (swappedInPass && passEnd > 0)
    {
        swappedInPass = false;
        for (int j = 0; j < passEnd; j++)
        {
            var leftScene = sorted[j].gameObject.scene;
            var rightScene = sorted[j + 1].gameObject.scene;

            bool shouldSwap;
            if (leftScene == rightScene)
            {
                shouldSwap = sorted[j].transform.GetSiblingIndex() > sorted[j + 1].transform.GetSiblingIndex();
            }
            else
            {
                shouldSwap = rightScene == gameObject.scene;
            }

            if (!shouldSwap) continue;
            (sorted[j], sorted[j + 1]) = (sorted[j + 1], sorted[j]);
            swappedInPass = true;
        }
        passEnd--;
    }
    return sorted;
}
177
/// <summary>
/// Extracts the token list from a tokenize result received from the endpoint.
/// </summary>
/// <param name="result">parsed tokenize result</param>
/// <returns>the tokens of the result</returns>
protected virtual List<int> TokenizeContent(TokenizeResult result) => result.tokens;
183
/// <summary>
/// Extracts the text content from a detokenize result received from the endpoint.
/// </summary>
/// <param name="result">parsed detokenize result</param>
/// <returns>the content of the result</returns>
protected virtual string DetokenizeContent(TokenizeRequest result) => result.content;
189
/// <summary>
/// Extracts the embedding vector from an embeddings result received from the endpoint.
/// </summary>
/// <param name="result">parsed embeddings result</param>
/// <returns>the embedding of the result</returns>
protected virtual List<float> EmbeddingsContent(EmbeddingsResult result) => result.embedding;
195
/// <summary>
/// Parses the JSON received from the endpoint and extracts its content.
/// A response made of "data: " prefixed parts is first wrapped into a single JSON object of
/// the form {"data": [part1, part2, ...]} before parsing.
/// </summary>
/// <typeparam name="Res">type the JSON is parsed into</typeparam>
/// <typeparam name="Ret">type of the extracted content</typeparam>
/// <param name="response">raw response text</param>
/// <param name="getContent">function extracting the content from the parsed response</param>
/// <returns>
/// the extracted content, or the default value of Ret when the response or getContent is null
/// </returns>
protected virtual Ret ConvertContent<Res, Ret>(string response, ContentCallback<Res, Ret> getContent = null)
{
    // getContent is optional (defaults to null): without it there is nothing to extract
    if (response == null || getContent == null) return default;

    response = response.Trim();
    // ordinal comparison: the prefix is a protocol marker, not linguistic text
    if (response.StartsWith("data: ", StringComparison.Ordinal))
    {
        List<string> parts = new List<string>();
        foreach (string responsePart in response.Replace("\n\n", "").Split("data: "))
        {
            if (responsePart != "") parts.Add(responsePart);
        }
        response = $"{{\"data\": [{string.Join(",\n", parts)}]}}";
    }
    return getContent(JsonUtility.FromJson<Res>(response));
}
214
/// <summary>
/// Cancels the ongoing local requests. The base implementation does nothing.
/// </summary>
protected virtual void CancelRequestsLocal()
{
}
216
/// <summary>
/// Aborts every remote request that is currently in flight and empties the tracking list.
/// </summary>
protected virtual void CancelRequestsRemote()
{
    WIPRequests.ForEach(pending => pending.Abort());
    WIPRequests.Clear();
}
225
/// <summary>
/// Cancels the ongoing requests, using the remote or the local implementation
/// depending on the <see cref="remote"/> setting.
/// </summary>
public virtual void CancelRequests()
{
    if (!remote)
    {
        CancelRequestsLocal();
        return;
    }
    CancelRequestsRemote();
}
235
/// <summary>
/// Sends a request to the local LLM and converts the answer.
/// Waits until the LLM has either started or failed, dispatches the JSON to the LLM function
/// matching the endpoint, converts the answer with getContent and passes it to the callback.
/// An unknown endpoint is logged as an error and leaves the answer null.
/// </summary>
/// <param name="json">request payload</param>
/// <param name="endpoint">one of "tokenize", "detokenize", "embeddings", "slots"</param>
/// <param name="getContent">function extracting the content from the parsed answer</param>
/// <param name="callback">optional function called with the converted result</param>
/// <returns>the converted result</returns>
protected virtual async Task<Ret> PostRequestLocal<Res, Ret>(string json, string endpoint, ContentCallback<Res, Ret> getContent, Callback<Ret> callback = null)
{
    while (!(llm.failed || llm.started))
    {
        await Task.Yield();
    }

    string answer = null;
    if (endpoint == "tokenize")
    {
        answer = await llm.Tokenize(json);
    }
    else if (endpoint == "detokenize")
    {
        answer = await llm.Detokenize(json);
    }
    else if (endpoint == "embeddings")
    {
        answer = await llm.Embeddings(json);
    }
    else if (endpoint == "slots")
    {
        answer = await llm.Slot(json);
    }
    else
    {
        LLMUnitySetup.LogError($"Unknown endpoint {endpoint}");
    }

    Ret converted = ConvertContent(answer, getContent);
    if (callback != null)
    {
        callback.Invoke(converted);
    }
    return converted;
}
265
/// <summary>
/// Sends a POST request to the remote server and converts the answer.
/// While the answer is being received, the callback is invoked with the content converted
/// from the partial download whenever the download progress changes; it is invoked once more
/// with the final result.
/// </summary>
/// <remarks>
/// Failed requests are repeated according to <see cref="numRetries"/> (negative = no limit)
/// with a delay growing by 200 ms per attempt. No retry is done after an Unauthorized
/// response or after the request was cancelled through CancelRequestsRemote.
/// The "slots" endpoint is not supported remotely.
/// </remarks>
/// <param name="json">request payload</param>
/// <param name="endpoint">endpoint appended to host:port</param>
/// <param name="getContent">function extracting the content from the parsed answer</param>
/// <param name="callback">optional function called with partial and final results</param>
/// <returns>the converted result, or the default value of Ret on failure or cancellation</returns>
protected virtual async Task<Ret> PostRequestRemote<Res, Ret>(string json, string endpoint, ContentCallback<Res, Ret> getContent, Callback<Ret> callback = null)
{
    if (endpoint == "slots")
    {
        LLMUnitySetup.LogError("Saving and loading is not currently supported in remote setting");
        return default;
    }

    const int retryDelayStepMs = 200;
    // upper bound of the back-off (in steps) when the number of retries is unlimited
    const int maxUnboundedDelaySteps = 10;

    Ret result = default;
    byte[] jsonToSend = System.Text.Encoding.UTF8.GetBytes(json);
    string error = null;
    int tryNr = numRetries;
    int attempt = 0;

    while (tryNr != 0)
    {
        using (UnityWebRequest request = UnityWebRequest.Put($"{host}:{port}/{endpoint}", jsonToSend))
        {
            bool cancelled;
            WIPRequests.Add(request);
            try
            {
                request.method = "POST";
                if (requestHeaders != null)
                {
                    foreach ((string headerName, string headerValue) in requestHeaders)
                    {
                        request.SetRequestHeader(headerName, headerValue);
                    }
                }

                // Start the request asynchronously
                UnityWebRequestAsyncOperation asyncOperation = request.SendWebRequest();
                await Task.Yield();

                float lastProgress = 0f;
                while (!asyncOperation.isDone)
                {
                    float currentProgress = request.downloadProgress;
                    // forward the partial answer only when new data arrived
                    if (currentProgress != lastProgress && callback != null)
                    {
                        callback.Invoke(ConvertContent(request.downloadHandler.text, getContent));
                        lastProgress = currentProgress;
                    }
                    await Task.Yield();
                }
            }
            finally
            {
                // Always stop tracking the request before it is disposed, even when the
                // callback or the conversion throws. Remove returns false when
                // CancelRequestsRemote has already emptied the list, i.e. on cancellation.
                cancelled = !WIPRequests.Remove(request);
            }

            if (cancelled)
            {
                // a cancelled request is neither an error nor a reason to retry
                result = default;
                error = null;
                break;
            }

            if (request.result == UnityWebRequest.Result.Success)
            {
                result = ConvertContent(request.downloadHandler.text, getContent);
                error = null;
                break;
            }

            result = default;
            error = request.error;
            if (request.responseCode == (int)System.Net.HttpStatusCode.Unauthorized) break;
        }

        // a negative counter means unlimited retries and is left untouched
        if (tryNr > 0) tryNr--;
        attempt++;
        if (tryNr != 0)
        {
            int delaySteps = tryNr < 0 ? Math.Min(attempt, maxUnboundedDelaySteps) : attempt;
            await Task.Delay(retryDelayStepMs * delaySteps);
        }
    }

    if (error != null) LLMUnitySetup.LogError(error);
    callback?.Invoke(result);
    return result;
}
335
/// <summary>
/// Sends a request either to the remote server or to the local LLM,
/// depending on the <see cref="remote"/> setting.
/// </summary>
/// <param name="json">request payload</param>
/// <param name="endpoint">endpoint to call</param>
/// <param name="getContent">function extracting the content from the parsed answer</param>
/// <param name="callback">optional function called with the converted result</param>
/// <returns>the converted result</returns>
protected virtual async Task<Ret> PostRequest<Res, Ret>(string json, string endpoint, ContentCallback<Res, Ret> getContent, Callback<Ret> callback = null)
{
    if (!remote)
    {
        return await PostRequestLocal(json, endpoint, getContent, callback);
    }
    else
    {
        return await PostRequestRemote(json, endpoint, getContent, callback);
    }
}
341
/// <summary>
/// Tokenises the provided query.
/// </summary>
/// <param name="query">text to tokenise</param>
/// <param name="callback">optional function called with the resulting tokens</param>
/// <returns>list of tokens</returns>
public virtual async Task<List<int>> Tokenize(string query, Callback<List<int>> callback = null)
{
    string json = JsonUtility.ToJson(new TokenizeRequest { content = query });
    return await PostRequest<TokenizeResult, List<int>>(json, "tokenize", TokenizeContent, callback);
}
356
/// <summary>
/// Detokenises the provided tokens to a string.
/// </summary>
/// <param name="tokens">tokens to detokenise</param>
/// <param name="callback">optional function called with the resulting text</param>
/// <returns>the detokenised text</returns>
public virtual async Task<string> Detokenize(List<int> tokens, Callback<string> callback = null)
{
    // the request carries tokens and the answer carries content, so the two
    // tokenize types are used in swapped roles here
    string json = JsonUtility.ToJson(new TokenizeResult { tokens = tokens });
    return await PostRequest<TokenizeRequest, string>(json, "detokenize", DetokenizeContent, callback);
}
371
/// <summary>
/// Computes the embeddings of the provided input.
/// </summary>
/// <param name="query">text to embed</param>
/// <param name="callback">optional function called with the resulting embedding</param>
/// <returns>the embedding vector</returns>
public virtual async Task<List<float>> Embeddings(string query, Callback<List<float>> callback = null)
{
    // the embeddings endpoint takes the same {content} payload as tokenize
    string json = JsonUtility.ToJson(new TokenizeRequest { content = query });
    return await PostRequest<EmbeddingsResult, List<float>>(json, "embeddings", EmbeddingsContent, callback);
}
386 }
387}
Class implementing calling of LLM functions (local and remote).
Definition LLMCaller.cs:17
virtual async Task< List< int > > Tokenize(string query, Callback< List< int > > callback=null)
Tokenises the provided query.
Definition LLMCaller.cs:348
virtual bool IsValidLLM(LLM llmSet)
Checks if an LLM is valid for the LLMCaller.
Definition LLMCaller.cs:100
virtual async Task< List< float > > Embeddings(string query, Callback< List< float > > callback=null)
Computes the embeddings of the provided input.
Definition LLMCaller.cs:378
virtual void Awake()
The Unity Awake function that initializes the state before the application starts....
Definition LLMCaller.cs:58
int numRetries
number of retries to use for the remote LLM server requests (-1 = infinite)
Definition LLMCaller.cs:43
int port
port of the remote LLM server
Definition LLMCaller.cs:40
string host
host of the remote LLM server
Definition LLMCaller.cs:37
virtual void CancelRequests()
Cancel the ongoing requests e.g. Chat, Complete.
Definition LLMCaller.cs:230
virtual async Task< string > Detokenize(List< int > tokens, Callback< string > callback=null)
Detokenises the provided tokens to a string.
Definition LLMCaller.cs:363
bool remote
use remote LLM server
Definition LLMCaller.cs:23
bool advancedOptions
show/hide advanced options in the GameObject
Definition LLMCaller.cs:20
string APIKey
API key for the remote server.
Definition LLMCaller.cs:34
virtual bool IsAutoAssignableLLM(LLM llmSet)
Checks if an LLM can be auto-assigned if the LLM of the LLMCaller is null.
Definition LLMCaller.cs:110
Class implementing helper functions for setup and process management.
Class implementing the LLM server.
Definition LLM.cs:19
async Task< string > Slot(string json)
Allows saving / restoring the state of a slot.
Definition LLM.cs:763
async Task< string > Detokenize(string json)
Detokenises the provided query.
Definition LLM.cs:691
bool started
Boolean set to true if the server has started and is ready to receive requests, false otherwise.
Definition LLM.cs:53
async Task< string > Tokenize(string json)
Tokenises the provided query.
Definition LLM.cs:676
bool failed
Boolean set to true if the server has failed to start.
Definition LLM.cs:55
async Task< string > Embeddings(string json)
Computes the embeddings of the provided query.
Definition LLM.cs:706