LLM for Unity  v2.4.0
Create characters in Unity with LLMs!
Loading...
Searching...
No Matches
LLMCaller.cs
Go to the documentation of this file.
1
3using System;
4using System.Collections.Generic;
5using System.Threading.Tasks;
6using UnityEngine;
7using UnityEngine.Networking;
8
9namespace LLMUnity
10{
11 [DefaultExecutionOrder(-2)]
16 public class LLMCaller : MonoBehaviour
17 {
        /// <summary> toggle to show/hide advanced options in the GameObject </summary>
        [HideInInspector] public bool advancedOptions = false;
        /// <summary> toggle to use remote LLM server or local LLM </summary>
        [LocalRemote] public bool remote = false;
        // backing field for the llm property; serialized so the inspector can set it
        [Local, SerializeField] protected LLM _llm;
        /// <summary> the LLM used by this caller; setting it validates the value via SetLLM </summary>
        public LLM llm
        {
            get => _llm;
            set => SetLLM(value);
        }

        /// <summary> allows to use a server with API key </summary>
        [Remote] public string APIKey;

        /// <summary> host to use for the LLM server </summary>
        [Remote] public string host = "localhost";
        /// <summary> port to use for the LLM server </summary>
        [Remote] public int port = 13333;
        /// <summary> number of retries to use for the LLM server requests (-1 = infinite) </summary>
        [Remote] public int numRetries = 10;

        // previously assigned LLM, used by OnValidate to detect inspector changes to _llm
        protected LLM _prellm;
        // HTTP headers sent with every remote request
        protected List<(string, string)> requestHeaders;
        // in-flight (work-in-progress) remote requests, tracked so they can be aborted
        protected List<UnityWebRequest> WIPRequests = new List<UnityWebRequest>();
43
53 public virtual void Awake()
54 {
55 // Start the LLM server in a cross-platform way
56 if (!enabled) return;
57
58 requestHeaders = new List<(string, string)> { ("Content-Type", "application/json") };
59 if (!remote)
60 {
61 AssignLLM();
62 if (llm == null)
63 {
64 string error = $"No LLM assigned or detected for LLMCharacter {name}!";
65 LLMUnitySetup.LogError(error);
66 throw new Exception(error);
67 }
68 }
69 else
70 {
71 if (!String.IsNullOrEmpty(APIKey)) requestHeaders.Add(("Authorization", "Bearer " + APIKey));
72 }
73 }
74
79 protected virtual void SetLLM(LLM llmSet)
80 {
81 if (llmSet != null && !IsValidLLM(llmSet))
82 {
83 LLMUnitySetup.LogError(NotValidLLMError());
84 llmSet = null;
85 }
86 _llm = llmSet;
87 _prellm = _llm;
88 }
89
95 public virtual bool IsValidLLM(LLM llmSet)
96 {
97 return true;
98 }
99
105 public virtual bool IsAutoAssignableLLM(LLM llmSet)
106 {
107 return true;
108 }
109
110 protected virtual string NotValidLLMError()
111 {
112 return $"Can't set LLM {llm.name} to {name}";
113 }
114
115 protected virtual void OnValidate()
116 {
117 if (_llm != _prellm) SetLLM(_llm);
118 AssignLLM();
119 }
120
121 protected virtual void Reset()
122 {
123 AssignLLM();
124 }
125
126 protected virtual void AssignLLM()
127 {
128 if (remote || llm != null) return;
129
130 List<LLM> validLLMs = new List<LLM>();
131 foreach (LLM foundllm in FindObjectsOfType<LLM>())
132 {
133 if (IsValidLLM(foundllm) && IsAutoAssignableLLM(foundllm)) validLLMs.Add(foundllm);
134 }
135 if (validLLMs.Count == 0) return;
136
137 llm = SortLLMsByBestMatching(validLLMs.ToArray())[0];
138 string msg = $"Assigning LLM {llm.name} to {GetType()} {name}";
139 if (llm.gameObject.scene != gameObject.scene) msg += $" from scene {llm.gameObject.scene}";
140 LLMUnitySetup.Log(msg);
141 }
142
        /// <summary>
        /// Orders candidate LLMs by preference: LLMs in the caller's own scene come first,
        /// and within the same scene, LLMs with a lower sibling index come first.
        /// Implemented as a stable bubble sort; the swap predicate is not a total order,
        /// so a comparison-based library sort would not be equivalent.
        /// </summary>
        /// <param name="arrayIn">candidate LLMs (not modified; a clone is sorted)</param>
        /// <returns>a new array sorted by best matching</returns>
        protected virtual LLM[] SortLLMsByBestMatching(LLM[] arrayIn)
        {
            LLM[] array = (LLM[])arrayIn.Clone();
            for (int i = 0; i < array.Length - 1; i++)
            {
                bool swapped = false;
                for (int j = 0; j < array.Length - i - 1; j++)
                {
                    bool sameScene = array[j].gameObject.scene == array[j + 1].gameObject.scene;
                    // swap when the next LLM is in this caller's scene (and the current one is not),
                    // or both share a scene and the next one comes earlier in the hierarchy
                    bool swap = (
                        (!sameScene && array[j + 1].gameObject.scene == gameObject.scene) ||
                        (sameScene && array[j].transform.GetSiblingIndex() > array[j + 1].transform.GetSiblingIndex())
                    );
                    if (swap)
                    {
                        LLM temp = array[j];
                        array[j] = array[j + 1];
                        array[j + 1] = temp;
                        swapped = true;
                    }
                }
                // a full pass with no swaps means the array is ordered
                if (!swapped) break;
            }
            return array;
        }
168
169 protected virtual List<int> TokenizeContent(TokenizeResult result)
170 {
171 // get the tokens from a tokenize result received from the endpoint
172 return result.tokens;
173 }
174
175 protected virtual string DetokenizeContent(TokenizeRequest result)
176 {
177 // get content from a chat result received from the endpoint
178 return result.content;
179 }
180
181 protected virtual List<float> EmbeddingsContent(EmbeddingsResult result)
182 {
183 // get content from a chat result received from the endpoint
184 return result.embedding;
185 }
186
187 protected virtual Ret ConvertContent<Res, Ret>(string response, ContentCallback<Res, Ret> getContent = null)
188 {
189 // template function to convert the json received and get the content
190 if (response == null) return default;
191 response = response.Trim();
192 if (response.StartsWith("data: "))
193 {
194 string responseArray = "";
195 foreach (string responsePart in response.Replace("\n\n", "").Split("data: "))
196 {
197 if (responsePart == "") continue;
198 if (responseArray != "") responseArray += ",\n";
199 responseArray += responsePart;
200 }
201 response = $"{{\"data\": [{responseArray}]}}";
202 }
203 return getContent(JsonUtility.FromJson<Res>(response));
204 }
205
        /// <summary> Cancels ongoing local requests (no-op here; subclasses may override). </summary>
        protected virtual void CancelRequestsLocal() {}
207
208 protected virtual void CancelRequestsRemote()
209 {
210 foreach (UnityWebRequest request in WIPRequests)
211 {
212 request.Abort();
213 }
214 WIPRequests.Clear();
215 }
216
        /// <summary>
        /// Cancel the ongoing requests e.g. Chat, Complete.
        /// </summary>
        public virtual void CancelRequests()
        {
            if (remote) CancelRequestsRemote();
            else CancelRequestsLocal();
        }
226
227 protected virtual async Task<Ret> PostRequestLocal<Res, Ret>(string json, string endpoint, ContentCallback<Res, Ret> getContent, Callback<Ret> callback = null)
228 {
229 // send a post request to the server and call the relevant callbacks to convert the received content and handle it
230 // this function has streaming functionality i.e. handles the answer while it is being received
231 while (!llm.failed && !llm.started) await Task.Yield();
232 string callResult = null;
233 switch (endpoint)
234 {
235 case "tokenize":
236 callResult = await llm.Tokenize(json);
237 break;
238 case "detokenize":
239 callResult = await llm.Detokenize(json);
240 break;
241 case "embeddings":
242 callResult = await llm.Embeddings(json);
243 break;
244 case "slots":
245 callResult = await llm.Slot(json);
246 break;
247 default:
248 LLMUnitySetup.LogError($"Unknown endpoint {endpoint}");
249 break;
250 }
251
252 Ret result = ConvertContent(callResult, getContent);
253 callback?.Invoke(result);
254 return result;
255 }
256
        /// <summary>
        /// Sends a POST request to the remote LLM server and converts the received content via the callbacks.
        /// Streams partial results: the callback is invoked with the content received so far while the
        /// request is in flight. Retries up to numRetries times with increasing delay; stops immediately
        /// on 401 Unauthorized. The "slots" endpoint (save/restore) is not supported remotely.
        /// </summary>
        /// <param name="json">request payload, sent as UTF-8</param>
        /// <param name="endpoint">server endpoint to call</param>
        /// <param name="getContent">callback extracting the content from the parsed result</param>
        /// <param name="callback">optional callback invoked with streamed partial and final results</param>
        protected virtual async Task<Ret> PostRequestRemote<Res, Ret>(string json, string endpoint, ContentCallback<Res, Ret> getContent, Callback<Ret> callback = null)
        {
            if (endpoint == "slots")
            {
                LLMUnitySetup.LogError("Saving and loading is not currently supported in remote setting");
                return default;
            }

            Ret result = default;
            byte[] jsonToSend = new System.Text.UTF8Encoding().GetBytes(json);
            UnityWebRequest request = null;
            string error = null;
            // a negative numRetries means retry forever (tryNr never reaches 0)
            int tryNr = numRetries;

            while (tryNr != 0)
            {
                using (request = UnityWebRequest.Put($"{host}:{port}/{endpoint}", jsonToSend))
                {
                    // track the request so CancelRequestsRemote can abort it
                    WIPRequests.Add(request);

                    // Put attaches the body; the verb is then forced to POST
                    request.method = "POST";
                    if (requestHeaders != null)
                    {
                        for (int i = 0; i < requestHeaders.Count; i++)
                            request.SetRequestHeader(requestHeaders[i].Item1, requestHeaders[i].Item2);
                    }

                    // Start the request asynchronously
                    var asyncOperation = request.SendWebRequest();
                    float lastProgress = 0f;
                    // Continue updating progress until the request is completed
                    while (!asyncOperation.isDone)
                    {
                        float currentProgress = request.downloadProgress;
                        // stream the partial content to the callback whenever new data has arrived
                        if (currentProgress != lastProgress && callback != null)
                        {
                            callback?.Invoke(ConvertContent(request.downloadHandler.text, getContent));
                            lastProgress = currentProgress;
                        }
                        // Wait for the next frame
                        await Task.Yield();
                    }
                    WIPRequests.Remove(request);
                    if (request.result == UnityWebRequest.Result.Success)
                    {
                        result = ConvertContent(request.downloadHandler.text, getContent);
                        error = null;
                        break;
                    }
                    else
                    {
                        result = default;
                        error = request.error;
                        // authentication failure: retrying cannot succeed, stop immediately
                        if (request.responseCode == (int)System.Net.HttpStatusCode.Unauthorized) break;
                    }
                }
                tryNr--;
                // linear backoff between attempts: 200ms, 400ms, ...
                if (tryNr > 0) await Task.Delay(200 * (numRetries - tryNr));
            }

            if (error != null) LLMUnitySetup.LogError(error);
            callback?.Invoke(result);
            return result;
        }
324
325 protected virtual async Task<Ret> PostRequest<Res, Ret>(string json, string endpoint, ContentCallback<Res, Ret> getContent, Callback<Ret> callback = null)
326 {
327 if (remote) return await PostRequestRemote(json, endpoint, getContent, callback);
328 return await PostRequestLocal(json, endpoint, getContent, callback);
329 }
330
337 public virtual async Task<List<int>> Tokenize(string query, Callback<List<int>> callback = null)
338 {
339 // handle the tokenization of a message by the user
340 TokenizeRequest tokenizeRequest = new TokenizeRequest();
341 tokenizeRequest.content = query;
342 string json = JsonUtility.ToJson(tokenizeRequest);
343 return await PostRequest<TokenizeResult, List<int>>(json, "tokenize", TokenizeContent, callback);
344 }
345
352 public virtual async Task<string> Detokenize(List<int> tokens, Callback<string> callback = null)
353 {
354 // handle the detokenization of a message by the user
355 TokenizeResult tokenizeRequest = new TokenizeResult();
356 tokenizeRequest.tokens = tokens;
357 string json = JsonUtility.ToJson(tokenizeRequest);
358 return await PostRequest<TokenizeRequest, string>(json, "detokenize", DetokenizeContent, callback);
359 }
360
367 public virtual async Task<List<float>> Embeddings(string query, Callback<List<float>> callback = null)
368 {
369 // handle the tokenization of a message by the user
370 TokenizeRequest tokenizeRequest = new TokenizeRequest();
371 tokenizeRequest.content = query;
372 string json = JsonUtility.ToJson(tokenizeRequest);
373 return await PostRequest<EmbeddingsResult, List<float>>(json, "embeddings", EmbeddingsContent, callback);
374 }
375 }
376}
Class implementing calling of LLM functions (local and remote).
Definition LLMCaller.cs:17
virtual async Task< List< int > > Tokenize(string query, Callback< List< int > > callback=null)
Tokenises the provided query.
Definition LLMCaller.cs:337
virtual bool IsValidLLM(LLM llmSet)
Checks if a LLM is valid for the LLMCaller.
Definition LLMCaller.cs:95
virtual async Task< List< float > > Embeddings(string query, Callback< List< float > > callback=null)
Computes the embeddings of the provided input.
Definition LLMCaller.cs:367
virtual void Awake()
The Unity Awake function that initializes the state before the application starts....
Definition LLMCaller.cs:53
int numRetries
number of retries to use for the LLM server requests (-1 = infinite)
Definition LLMCaller.cs:38
int port
port to use for the LLM server
Definition LLMCaller.cs:36
string host
host to use for the LLM server
Definition LLMCaller.cs:34
virtual void CancelRequests()
Cancel the ongoing requests e.g. Chat, Complete.
Definition LLMCaller.cs:221
virtual async Task< string > Detokenize(List< int > tokens, Callback< string > callback=null)
Detokenises the provided tokens to a string.
Definition LLMCaller.cs:352
bool remote
toggle to use remote LLM server or local LLM
Definition LLMCaller.cs:21
bool advancedOptions
toggle to show/hide advanced options in the GameObject
Definition LLMCaller.cs:19
string APIKey
allows to use a server with API key
Definition LLMCaller.cs:31
virtual bool IsAutoAssignableLLM(LLM llmSet)
Checks if a LLM can be auto-assigned if the LLM of the LLMCaller is null.
Definition LLMCaller.cs:105
Class implementing helper functions for setup and process management.
Class implementing the LLM server.
Definition LLM.cs:19
async Task< string > Slot(string json)
Allows to save / restore the state of a slot.
Definition LLM.cs:748
async Task< string > Detokenize(string json)
Detokenises the provided query.
Definition LLM.cs:676
bool started
Boolean set to true if the server has started and is ready to receive requests, false otherwise.
Definition LLM.cs:44
async Task< string > Tokenize(string json)
Tokenises the provided query.
Definition LLM.cs:661
bool failed
Boolean set to true if the server has failed to start.
Definition LLM.cs:46
async Task< string > Embeddings(string json)
Computes the embeddings of the provided query.
Definition LLM.cs:691