LLM for Unity  v2.4.1
Create characters in Unity with LLMs!
Loading...
Searching...
No Matches
LLMCaller.cs
Go to the documentation of this file.
1
3using System;
4using System.Collections.Generic;
5using System.Threading.Tasks;
6using UnityEngine;
7using UnityEngine.Networking;
8
namespace LLMUnity
{
    /// <summary>
    /// Class implementing calling of LLM functions (local and remote).
    /// </summary>
    [DefaultExecutionOrder(-2)]
    public class LLMCaller : MonoBehaviour
    {
        /// <summary> toggle to show/hide advanced options in the GameObject </summary>
        [HideInInspector] public bool advancedOptions = false;
        /// <summary> toggle to use remote LLM server or local LLM </summary>
        [LocalRemote] public bool remote = false;
        /// <summary> backing field for the llm property; assign through the property so validation runs </summary>
        [Local, SerializeField] protected LLM _llm;
        /// <summary> the LLM component to use; setting it validates the candidate via SetLLM </summary>
        public LLM llm
        {
            get => _llm;
            set => SetLLM(value);
        }

        /// <summary> allows to use a server with API key (remote setting) </summary>
        [Remote] public string APIKey;
        /// <summary> host to use for the LLM server (remote setting) </summary>
        [Remote] public string host = "localhost";
        /// <summary> port to use for the LLM server (remote setting) </summary>
        [Remote] public int port = 13333;
        /// <summary> number of retries to use for the LLM server requests (-1 = infinite) </summary>
        [Remote] public int numRetries = 10;

        protected LLM _prellm;  // last LLM seen by OnValidate, used to detect inspector changes
        protected List<(string, string)> requestHeaders;  // headers added to every remote request
        protected List<UnityWebRequest> WIPRequests = new List<UnityWebRequest>();  // in-flight remote requests, for cancellation

        /// <summary>
        /// The Unity Awake function that initializes the state before the application starts.
        /// Builds the request headers and resolves the LLM to use (local mode)
        /// or the Authorization header (remote mode).
        /// </summary>
        /// <exception cref="Exception">Thrown in local mode when no LLM is assigned or can be auto-detected.</exception>
        public virtual void Awake()
        {
            if (!enabled) return;

            requestHeaders = new List<(string, string)> { ("Content-Type", "application/json") };
            if (!remote)
            {
                AssignLLM();
                if (llm == null)
                {
                    // bug fix: the message previously always said "LLMCharacter" regardless of the actual component type
                    string error = $"No LLM assigned or detected for {GetType().Name} {name}!";
                    LLMUnitySetup.LogError(error);
                    throw new Exception(error);
                }
            }
            else
            {
                if (!string.IsNullOrEmpty(APIKey)) requestHeaders.Add(("Authorization", "Bearer " + APIKey));
            }
        }

        /// <summary>
        /// Sets the LLM, rejecting (and logging an error for) candidates that fail IsValidLLM.
        /// </summary>
        /// <param name="llmSet">LLM candidate, or null to clear the assignment</param>
        protected virtual void SetLLM(LLM llmSet)
        {
            if (llmSet != null && !IsValidLLM(llmSet))
            {
                LLMUnitySetup.LogError(NotValidLLMError());
                llmSet = null;
            }
            _llm = llmSet;
            _prellm = _llm;
        }

        /// <summary>
        /// Checks if a LLM is valid for the LLMCaller. Subclasses can override to add constraints.
        /// </summary>
        /// <param name="llmSet">LLM candidate to check</param>
        /// <returns>true if the LLM can be used by this caller</returns>
        public virtual bool IsValidLLM(LLM llmSet)
        {
            return true;
        }

        /// <summary>
        /// Checks if a LLM can be auto-assigned if the LLM of the LLMCaller is null.
        /// </summary>
        /// <param name="llmSet">LLM candidate to check</param>
        /// <returns>true if the LLM can be auto-assigned</returns>
        public virtual bool IsAutoAssignableLLM(LLM llmSet)
        {
            return true;
        }

        /// <summary>
        /// Error message logged when an invalid LLM assignment is rejected.
        /// Note: reports the currently assigned LLM, not the rejected candidate.
        /// </summary>
        protected virtual string NotValidLLMError()
        {
            // bug fix: llm.name threw NullReferenceException when no LLM was assigned yet
            return $"Can't set LLM {llm?.name} to {name}";
        }

        /// <summary> Editor callback: re-validates the LLM when the inspector changes. </summary>
        protected virtual void OnValidate()
        {
            if (_llm != _prellm) SetLLM(_llm);
            AssignLLM();
        }

        /// <summary> Editor callback: runs when the component is added or reset. </summary>
        protected virtual void Reset()
        {
            AssignLLM();
        }

        /// <summary>
        /// Auto-assigns the best-matching LLM found in the loaded scenes,
        /// if none is set yet (local mode only).
        /// </summary>
        protected virtual void AssignLLM()
        {
            if (remote || llm != null) return;

            List<LLM> validLLMs = new List<LLM>();
#if UNITY_6000_0_OR_NEWER
            foreach (LLM foundllm in FindObjectsByType(typeof(LLM), FindObjectsSortMode.None))
#else
            foreach (LLM foundllm in FindObjectsOfType<LLM>())
#endif
            {
                if (IsValidLLM(foundllm) && IsAutoAssignableLLM(foundllm)) validLLMs.Add(foundllm);
            }
            if (validLLMs.Count == 0) return;

            llm = SortLLMsByBestMatching(validLLMs.ToArray())[0];
            string msg = $"Assigning LLM {llm.name} to {GetType()} {name}";
            if (llm.gameObject.scene != gameObject.scene) msg += $" from scene {llm.gameObject.scene}";
            LLMUnitySetup.Log(msg);
        }

        /// <summary>
        /// Sorts LLMs so that the best match comes first: LLMs in the same scene as this
        /// caller are preferred, and within a scene a lower sibling index wins.
        /// </summary>
        /// <param name="arrayIn">candidate LLMs (not modified)</param>
        /// <returns>a sorted copy of the input array</returns>
        protected virtual LLM[] SortLLMsByBestMatching(LLM[] arrayIn)
        {
            // bubble sort with early exit: fine for the handful of LLMs expected in a scene
            LLM[] array = (LLM[])arrayIn.Clone();
            for (int i = 0; i < array.Length - 1; i++)
            {
                bool swapped = false;
                for (int j = 0; j < array.Length - i - 1; j++)
                {
                    bool sameScene = array[j].gameObject.scene == array[j + 1].gameObject.scene;
                    bool swap = (
                        (!sameScene && array[j + 1].gameObject.scene == gameObject.scene) ||
                        (sameScene && array[j].transform.GetSiblingIndex() > array[j + 1].transform.GetSiblingIndex())
                    );
                    if (swap)
                    {
                        LLM temp = array[j];
                        array[j] = array[j + 1];
                        array[j + 1] = temp;
                        swapped = true;
                    }
                }
                if (!swapped) break;
            }
            return array;
        }

        /// <summary> Extracts the tokens from a tokenize result received from the endpoint. </summary>
        protected virtual List<int> TokenizeContent(TokenizeResult result)
        {
            return result.tokens;
        }

        /// <summary> Extracts the content from a detokenize result received from the endpoint. </summary>
        protected virtual string DetokenizeContent(TokenizeRequest result)
        {
            return result.content;
        }

        /// <summary> Extracts the embedding from an embeddings result received from the endpoint. </summary>
        protected virtual List<float> EmbeddingsContent(EmbeddingsResult result)
        {
            return result.embedding;
        }

        /// <summary>
        /// Template function that deserializes a (possibly streaming) JSON response and extracts its content.
        /// Server-sent-event payloads ("data: {...}" chunks) are merged into a single
        /// {"data": [...]} array before deserialization.
        /// </summary>
        /// <param name="response">raw response text, or null</param>
        /// <param name="getContent">callback extracting the content from the deserialized result</param>
        /// <returns>the extracted content, or default when response is null</returns>
        protected virtual Ret ConvertContent<Res, Ret>(string response, ContentCallback<Res, Ret> getContent = null)
        {
            if (response == null) return default;
            response = response.Trim();
            if (response.StartsWith("data: "))
            {
                string responseArray = "";
                foreach (string responsePart in response.Replace("\n\n", "").Split("data: "))
                {
                    if (responsePart == "") continue;
                    if (responseArray != "") responseArray += ",\n";
                    responseArray += responsePart;
                }
                response = $"{{\"data\": [{responseArray}]}}";
            }
            return getContent(JsonUtility.FromJson<Res>(response));
        }

        /// <summary> Cancels ongoing local requests (no-op by default). </summary>
        protected virtual void CancelRequestsLocal() {}

        /// <summary> Aborts and clears all in-flight remote requests. </summary>
        protected virtual void CancelRequestsRemote()
        {
            foreach (UnityWebRequest request in WIPRequests)
            {
                request.Abort();
            }
            WIPRequests.Clear();
        }

        /// <summary>
        /// Cancels the ongoing requests e.g. Chat, Complete.
        /// </summary>
        public virtual void CancelRequests()
        {
            if (remote) CancelRequestsRemote();
            else CancelRequestsLocal();
        }

        /// <summary>
        /// Sends a request to the local LLM and converts the result with the provided callbacks.
        /// Waits until the LLM has either started or failed before dispatching.
        /// </summary>
        /// <param name="json">request payload</param>
        /// <param name="endpoint">one of "tokenize", "detokenize", "embeddings", "slots"</param>
        /// <param name="getContent">converts the deserialized result to the return value</param>
        /// <param name="callback">optional callback invoked with the final result</param>
        protected virtual async Task<Ret> PostRequestLocal<Res, Ret>(string json, string endpoint, ContentCallback<Res, Ret> getContent, Callback<Ret> callback = null)
        {
            while (!llm.failed && !llm.started) await Task.Yield();
            string callResult = null;
            switch (endpoint)
            {
                case "tokenize":
                    callResult = await llm.Tokenize(json);
                    break;
                case "detokenize":
                    callResult = await llm.Detokenize(json);
                    break;
                case "embeddings":
                    callResult = await llm.Embeddings(json);
                    break;
                case "slots":
                    callResult = await llm.Slot(json);
                    break;
                default:
                    LLMUnitySetup.LogError($"Unknown endpoint {endpoint}");
                    break;
            }

            Ret result = ConvertContent(callResult, getContent);
            callback?.Invoke(result);
            return result;
        }

        /// <summary>
        /// Sends a POST request to the remote server and converts the result with the provided callbacks.
        /// Supports streaming: the callback is invoked with partial content while the answer is received.
        /// Retries failed requests up to numRetries times with an increasing delay (-1 = infinite;
        /// authorization failures are never retried).
        /// </summary>
        /// <param name="json">request payload</param>
        /// <param name="endpoint">server endpoint ("slots" is not supported remotely)</param>
        /// <param name="getContent">converts the deserialized result to the return value</param>
        /// <param name="callback">optional callback invoked with partial and final results</param>
        protected virtual async Task<Ret> PostRequestRemote<Res, Ret>(string json, string endpoint, ContentCallback<Res, Ret> getContent, Callback<Ret> callback = null)
        {
            if (endpoint == "slots")
            {
                LLMUnitySetup.LogError("Saving and loading is not currently supported in remote setting");
                return default;
            }

            Ret result = default;
            // idiom fix: use the cached UTF8 encoding instead of allocating a new encoder per request
            byte[] jsonToSend = System.Text.Encoding.UTF8.GetBytes(json);
            UnityWebRequest request = null;
            string error = null;
            int tryNr = numRetries;

            while (tryNr != 0)
            {
                using (request = UnityWebRequest.Put($"{host}:{port}/{endpoint}", jsonToSend))
                {
                    WIPRequests.Add(request);

                    request.method = "POST";
                    if (requestHeaders != null)
                    {
                        for (int i = 0; i < requestHeaders.Count; i++)
                            request.SetRequestHeader(requestHeaders[i].Item1, requestHeaders[i].Item2);
                    }

                    // Start the request asynchronously
                    UnityWebRequestAsyncOperation asyncOperation = request.SendWebRequest();
                    await Task.Yield(); // Wait for the next frame so that asyncOperation is properly registered (especially if not in main thread)

                    float lastProgress = 0f;
                    // Stream partial results to the callback until the request completes
                    while (!asyncOperation.isDone)
                    {
                        float currentProgress = request.downloadProgress;
                        // Only invoke the callback when new data has arrived
                        if (currentProgress != lastProgress && callback != null)
                        {
                            callback?.Invoke(ConvertContent(request.downloadHandler.text, getContent));
                            lastProgress = currentProgress;
                        }
                        await Task.Yield();
                    }
                    WIPRequests.Remove(request);
                    if (request.result == UnityWebRequest.Result.Success)
                    {
                        result = ConvertContent(request.downloadHandler.text, getContent);
                        error = null;
                        break;
                    }
                    else
                    {
                        result = default;
                        error = request.error;
                        // don't retry on authorization failure - the API key won't change
                        if (request.responseCode == (int)System.Net.HttpStatusCode.Unauthorized) break;
                    }
                }
                tryNr--;
                // linear backoff: wait longer after each failed attempt
                if (tryNr > 0) await Task.Delay(200 * (numRetries - tryNr));
            }

            if (error != null) LLMUnitySetup.LogError(error);
            callback?.Invoke(result);
            return result;
        }

        /// <summary>
        /// Dispatches a request to the remote server or the local LLM depending on the remote setting.
        /// </summary>
        protected virtual async Task<Ret> PostRequest<Res, Ret>(string json, string endpoint, ContentCallback<Res, Ret> getContent, Callback<Ret> callback = null)
        {
            if (remote) return await PostRequestRemote(json, endpoint, getContent, callback);
            return await PostRequestLocal(json, endpoint, getContent, callback);
        }

        /// <summary>
        /// Tokenises the provided query.
        /// </summary>
        /// <param name="query">query to tokenise</param>
        /// <param name="callback">callback function called with the resulting tokens</param>
        /// <returns>list of token ids</returns>
        public virtual async Task<List<int>> Tokenize(string query, Callback<List<int>> callback = null)
        {
            TokenizeRequest tokenizeRequest = new TokenizeRequest();
            tokenizeRequest.content = query;
            string json = JsonUtility.ToJson(tokenizeRequest);
            return await PostRequest<TokenizeResult, List<int>>(json, "tokenize", TokenizeContent, callback);
        }

        /// <summary>
        /// Detokenises the provided tokens to a string.
        /// </summary>
        /// <param name="tokens">tokens to detokenise</param>
        /// <param name="callback">callback function called with the resulting string</param>
        /// <returns>the detokenised string</returns>
        public virtual async Task<string> Detokenize(List<int> tokens, Callback<string> callback = null)
        {
            TokenizeResult tokenizeRequest = new TokenizeResult();
            tokenizeRequest.tokens = tokens;
            string json = JsonUtility.ToJson(tokenizeRequest);
            return await PostRequest<TokenizeRequest, string>(json, "detokenize", DetokenizeContent, callback);
        }

        /// <summary>
        /// Computes the embeddings of the provided input.
        /// </summary>
        /// <param name="query">input to compute the embeddings for</param>
        /// <param name="callback">callback function called with the resulting embeddings</param>
        /// <returns>list of embedding values</returns>
        public virtual async Task<List<float>> Embeddings(string query, Callback<List<float>> callback = null)
        {
            TokenizeRequest tokenizeRequest = new TokenizeRequest();
            tokenizeRequest.content = query;
            string json = JsonUtility.ToJson(tokenizeRequest);
            return await PostRequest<EmbeddingsResult, List<float>>(json, "embeddings", EmbeddingsContent, callback);
        }
    }
}
Class implementing calling of LLM functions (local and remote).
Definition LLMCaller.cs:17
virtual async Task< List< int > > Tokenize(string query, Callback< List< int > > callback=null)
Tokenises the provided query.
Definition LLMCaller.cs:343
virtual bool IsValidLLM(LLM llmSet)
Checks if a LLM is valid for the LLMCaller.
Definition LLMCaller.cs:95
virtual async Task< List< float > > Embeddings(string query, Callback< List< float > > callback=null)
Computes the embeddings of the provided input.
Definition LLMCaller.cs:373
virtual void Awake()
The Unity Awake function that initializes the state before the application starts.
Definition LLMCaller.cs:53
int numRetries
number of retries to use for the LLM server requests (-1 = infinite)
Definition LLMCaller.cs:38
int port
port to use for the LLM server
Definition LLMCaller.cs:36
string host
host to use for the LLM server
Definition LLMCaller.cs:34
virtual void CancelRequests()
Cancels the ongoing requests, e.g. Chat and Complete.
Definition LLMCaller.cs:225
virtual async Task< string > Detokenize(List< int > tokens, Callback< string > callback=null)
Detokenises the provided tokens to a string.
Definition LLMCaller.cs:358
bool remote
toggle to use remote LLM server or local LLM
Definition LLMCaller.cs:21
bool advancedOptions
toggle to show/hide advanced options in the GameObject
Definition LLMCaller.cs:19
string APIKey
allows to use a server with API key
Definition LLMCaller.cs:31
virtual bool IsAutoAssignableLLM(LLM llmSet)
Checks if a LLM can be auto-assigned if the LLM of the LLMCaller is null.
Definition LLMCaller.cs:105
Class implementing helper functions for setup and process management.
Class implementing the LLM server.
Definition LLM.cs:19
async Task< string > Slot(string json)
Allows to save / restore the state of a slot.
Definition LLM.cs:751
async Task< string > Detokenize(string json)
Detokenises the provided query.
Definition LLM.cs:679
bool started
Boolean set to true if the server has started and is ready to receive requests, false otherwise.
Definition LLM.cs:44
async Task< string > Tokenize(string json)
Tokenises the provided query.
Definition LLM.cs:664
bool failed
Boolean set to true if the server has failed to start.
Definition LLM.cs:46
async Task< string > Embeddings(string json)
Computes the embeddings of the provided query.
Definition LLM.cs:694