4using System.Collections.Generic;
5using System.Threading.Tasks;
7using UnityEngine.Networking;
/// <summary>
/// Base component for objects that call an LLM: either a local LLM GameObject or a
/// remote LLM server reached over HTTP with UnityWebRequest. Provides tokenize,
/// detokenize and embeddings requests plus cancellation of in-flight requests.
/// </summary>
11 [DefaultExecutionOrder(-2)]
16 public class LLMCaller : MonoBehaviour
/// <summary> Toggle to show/hide advanced options in the GameObject inspector. </summary>
19 [Tooltip(
"show/hide advanced options in the GameObject")]
20 [HideInInspector]
public bool advancedOptions =
false;
/// <summary> When true, requests are sent to the remote LLM server (host/port) instead of the local LLM. </summary>
22 [Tooltip(
"use remote LLM server")]
23 [LocalRemote]
public bool remote =
false;
/// <summary> LLM GameObject to use in local mode (presumably backs the llm property used throughout this class — confirm). </summary>
25 [Tooltip(
"LLM GameObject to use")]
26 [Local, SerializeField]
protected LLM _llm;
/// <summary> API key for the remote server; when non-empty it is sent as an "Authorization: Bearer" header. </summary>
33 [Tooltip(
"API key for the remote server")]
34 [Remote]
public string APIKey;
/// <summary> Host of the remote LLM server; requests go to {host}:{port}/{endpoint}. </summary>
36 [Tooltip(
"host of the remote LLM server")]
37 [Remote]
public string host =
"localhost";
/// <summary> Port of the remote LLM server; a value of 0 omits the port from the request URL. </summary>
39 [Tooltip(
"port of the remote LLM server")]
40 [Remote]
public int port = 13333;
/// <summary> Number of retries for remote LLM server requests (-1 = infinite). </summary>
42 [Tooltip(
"number of retries to use for the remote LLM server requests (-1 = infinite)")]
43 [Remote]
public int numRetries = 10;
/// <summary> LLM last applied; OnValidate calls SetLLM when _llm differs from it (presumably updated by SetLLM — confirm). </summary>
45 protected LLM _prellm;
/// <summary> Headers set on every remote request; built in Awake (Content-Type plus optional Authorization). </summary>
46 protected List<(string, string)> requestHeaders;
/// <summary> Remote requests currently in flight; added and removed by PostRequestRemote, iterated by CancelRequestsRemote. </summary>
47 protected List<UnityWebRequest> WIPRequests =
new List<UnityWebRequest>();
/// <summary>
/// Unity Awake handler. Initialises requestHeaders with a JSON Content-Type header.
/// Logs and then throws when no LLM is assigned or detected. Adds a Bearer
/// Authorization header when APIKey is non-empty.
/// </summary>
/// <exception cref="Exception">No LLM assigned or detected for this caller.</exception>
58 public virtual void Awake()
// Content-Type header applied to every remote request (see PostRequestRemote).
63 requestHeaders =
new List<(string, string)> { (
"Content-Type",
"application/json") };
// NOTE(review): presumably reached in local mode when no LLM was found — confirm.
// The message says "LLMCharacter" even when this component is another LLMCaller type,
// and throwing the base System.Exception prevents callers from catching it selectively.
69 string error = $
"No LLM assigned or detected for LLMCharacter {name}!";
70 LLMUnitySetup.LogError(error);
71 throw new Exception(error);
// Authenticate against the remote server only when an API key is configured.
76 if (!String.IsNullOrEmpty(APIKey)) requestHeaders.Add((
"Authorization",
"Bearer " + APIKey));
/// <summary>
/// Sets the LLM used by this caller. An error is logged when a non-null LLM fails
/// IsValidLLM.
/// </summary>
/// <param name="llmSet">LLM to assign; null is accepted without validation.</param>
84 protected virtual void SetLLM(LLM llmSet)
86 if (llmSet !=
null && !IsValidLLM(llmSet))
// NOTE(review): NotValidLLMError() formats `llm.name`, i.e. the currently assigned LLM
// rather than `llmSet`; if `llm` is still null at this point this throws — confirm.
88 LLMUnitySetup.LogError(NotValidLLMError());
/// <summary>
/// Returns whether <paramref name="llmSet"/> can be used by this caller.
/// SetLLM logs an error for LLMs that fail this check, and AssignLLM only
/// auto-assigns LLMs that pass it.
/// </summary>
/// <param name="llmSet">candidate LLM</param>
/// <returns>true if the LLM is valid for this caller</returns>
100 public virtual bool IsValidLLM(LLM llmSet)
/// <summary>
/// Returns whether <paramref name="llmSet"/> may be picked automatically by AssignLLM.
/// This is checked in addition to IsValidLLM.
/// </summary>
/// <param name="llmSet">candidate LLM</param>
/// <returns>true if AssignLLM may auto-assign this LLM</returns>
110 public virtual bool IsAutoAssignableLLM(LLM llmSet)
/// <summary>
/// Builds the error message that SetLLM logs when an LLM fails IsValidLLM.
/// </summary>
/// <returns>message naming the LLM and this caller</returns>
115 protected virtual string NotValidLLMError()
// NOTE(review): reads the current `llm` (not the LLM being set) and throws if `llm` is null.
117 return $
"Can't set LLM {llm.name} to {name}";
/// <summary>
/// Unity OnValidate handler. Re-applies SetLLM when the serialized _llm field no longer
/// matches _prellm (presumably after it is changed in the inspector).
/// </summary>
120 protected virtual void OnValidate()
122 if (_llm != _prellm) SetLLM(_llm);
/// <summary> Unity Reset handler (called by the editor when the component is added or reset). </summary>
126 protected virtual void Reset()
/// <summary>
/// Auto-assigns an LLM in local mode when none is set. Gathers the LLMs found via
/// FindObjectsByType / FindObjectsOfType that pass IsValidLLM and IsAutoAssignableLLM,
/// ranks them with SortLLMsByBestMatching and assigns the first. The choice is logged,
/// including the LLM's scene when it differs from this GameObject's scene.
/// </summary>
131 protected virtual void AssignLLM()
// Nothing to do in remote mode or when an LLM is already set.
133 if (remote || llm !=
null)
return;
135 List<LLM> validLLMs =
new List<LLM>();
// Unity 6+ uses FindObjectsByType (unsorted); older versions use FindObjectsOfType.
136#if UNITY_6000_0_OR_NEWER
137 foreach (LLM foundllm
in FindObjectsByType(typeof(LLM), FindObjectsSortMode.None))
139 foreach (LLM foundllm
in FindObjectsOfType<LLM>())
142 if (IsValidLLM(foundllm) && IsAutoAssignableLLM(foundllm)) validLLMs.Add(foundllm);
// Leave the LLM unassigned when no candidate qualifies.
144 if (validLLMs.Count == 0)
return;
146 llm = SortLLMsByBestMatching(validLLMs.ToArray())[0];
147 string msg = $
"Assigning LLM {llm.name} to {GetType()} {name}";
// NOTE(review): interpolates the Scene value itself; if Scene does not override
// ToString this prints the type name rather than the scene name — confirm.
148 if (llm.gameObject.scene != gameObject.scene) msg += $
" from scene {llm.gameObject.scene}";
149 LLMUnitySetup.Log(msg);
/// <summary>
/// Returns a copy of <paramref name="arrayIn"/> ordered by preference using a bubble sort.
/// An LLM in this GameObject's scene is moved ahead of a neighbour from another scene.
/// LLMs in the same scene are ordered by ascending sibling index.
/// The input array itself is not reordered.
/// </summary>
/// <param name="arrayIn">candidate LLMs</param>
/// <returns>presumably the reordered copy; AssignLLM takes element [0] as the best match</returns>
152 protected virtual LLM[] SortLLMsByBestMatching(LLM[] arrayIn)
// Work on a shallow copy so the caller's array keeps its order.
154 LLM[] array = (LLM[])arrayIn.Clone();
// `swapped` presumably enables an early exit once a pass makes no swaps.
155 for (
int i = 0; i < array.Length - 1; i++)
157 bool swapped =
false;
158 for (
int j = 0; j < array.Length - i - 1; j++)
// Swap j and j+1 in two cases: j+1 is in this GameObject's scene and j is not, or
// both are in the same scene and j has the higher sibling index.
// NOTE(review): GetSiblingIndex is the position under the parent transform only, so
// LLMs under different parents are compared by that local index — confirm intended.
160 bool sameScene = array[j].gameObject.scene == array[j + 1].gameObject.scene;
162 (!sameScene && array[j + 1].gameObject.scene == gameObject.scene) ||
163 (sameScene && array[j].transform.GetSiblingIndex() > array[j + 1].transform.GetSiblingIndex())
168 array[j] = array[j + 1];
/// <summary> Extracts the token ids from a parsed tokenize response (the extractor used by Tokenize). </summary>
/// <param name="result">parsed tokenize response</param>
/// <returns>the response's token ids</returns>
178 protected virtual List<int> TokenizeContent(TokenizeResult result)
181 return result.tokens;
/// <summary>
/// Extracts the text from a parsed detokenize response (the extractor used by Detokenize).
/// The response is parsed into a TokenizeRequest, whose content field holds the text.
/// </summary>
/// <param name="result">parsed detokenize response</param>
/// <returns>the detokenized text</returns>
184 protected virtual string DetokenizeContent(TokenizeRequest result)
187 return result.content;
/// <summary> Extracts the embedding vector from a parsed embeddings response (the extractor used by Embeddings). </summary>
/// <param name="result">parsed embeddings response</param>
/// <returns>the embedding values</returns>
190 protected virtual List<float> EmbeddingsContent(EmbeddingsResult result)
193 return result.embedding;
/// <summary>
/// Converts a raw response into the value returned to the caller. The text is parsed
/// with JsonUtility into Res and then passed to <paramref name="getContent"/>.
/// A response made of "data: "-prefixed chunks (which looks like server-sent-events
/// framing) is first merged into one JSON object of the form {"data": [chunk, ...]}.
/// </summary>
/// <typeparam name="Res">type the JSON response is deserialized into</typeparam>
/// <typeparam name="Ret">type of the value extracted from it</typeparam>
/// <param name="response">raw response text; may be null</param>
/// <param name="getContent">extractor applied to the parsed response</param>
/// <returns>the extracted value, or default(Ret) when <paramref name="response"/> is null</returns>
196 protected virtual Ret ConvertContent<Res, Ret>(
string response, ContentCallback<Res, Ret> getContent =
null)
199 if (response ==
null)
return default;
200 response = response.Trim();
// Merge the "data: " chunks into a single JSON array so the whole response is parsed
// by one JsonUtility call below.
// NOTE(review): splitting on the literal "data: " would also split a chunk whose JSON
// text contains that substring — confirm payloads cannot contain it.
201 if (response.StartsWith(
"data: "))
203 string responseArray =
"";
204 foreach (
string responsePart
in response.Replace(
"\n\n",
"").Split(
"data: "))
206 if (responsePart ==
"")
continue;
207 if (responseArray !=
"") responseArray +=
",\n";
208 responseArray += responsePart;
210 response = $
"{{\"data\": [{responseArray}]}}";
// NOTE(review): getContent defaults to null but is invoked unconditionally, so calling
// this without an extractor throws NullReferenceException for any non-null response.
212 return getContent(JsonUtility.FromJson<Res>(response));
/// <summary> Cancels pending local requests. This base implementation does nothing; subclasses may override it. </summary>
215 protected virtual void CancelRequestsLocal() {}
/// <summary>
/// Cancels the remote requests currently tracked in WIPRequests
/// (presumably by aborting each one in the loop body — confirm).
/// </summary>
217 protected virtual void CancelRequestsRemote()
219 foreach (UnityWebRequest request
in WIPRequests)
/// <summary>
/// Cancels in-flight requests. Uses the remote strategy when remote is set and the
/// local strategy otherwise.
/// </summary>
230 public virtual void CancelRequests()
232 if (remote) CancelRequestsRemote();
233 else CancelRequestsLocal();
/// <summary>
/// Runs a request against the local LLM and converts the reply.
/// Yields until llm has started or failed, then forwards <paramref name="json"/> to
/// llm.Tokenize / Detokenize / Embeddings / Slot according to <paramref name="endpoint"/>.
/// Unknown endpoints are logged as errors. The raw reply is converted with ConvertContent
/// and passed to <paramref name="callback"/>.
/// </summary>
/// <param name="json">JSON request body</param>
/// <param name="endpoint">endpoint name; presumably "tokenize", "detokenize", "embeddings" or "slots" — confirm</param>
/// <param name="getContent">extracts the returned value from the parsed reply</param>
/// <param name="callback">optional callback receiving the converted result</param>
/// <returns>presumably the converted result that is also passed to the callback</returns>
236 protected virtual async Task<Ret> PostRequestLocal<Res, Ret>(
string json,
string endpoint, ContentCallback<Res, Ret> getContent, Callback<Ret> callback =
null)
// NOTE(review): the wait also ends when llm.failed is set — confirm the dispatch below
// is skipped in that case.
240 while (!llm.failed && !llm.started) await Task.Yield();
241 string callResult =
null;
245 callResult = await llm.Tokenize(json);
248 callResult = await llm.Detokenize(json);
251 callResult = await llm.Embeddings(json);
254 callResult = await llm.Slot(json);
257 LLMUnitySetup.LogError($
"Unknown endpoint {endpoint}");
// If no endpoint matched, callResult is still null and ConvertContent yields default(Ret).
261 Ret result = ConvertContent(callResult, getContent);
262 callback?.Invoke(result);
/// <summary>
/// Sends <paramref name="json"/> as a POST request to {host}:{port}/{endpoint} on the
/// remote LLM server, using requestHeaders.
/// Failed attempts are retried (the budget comes from numRetries), except that HTTP 401
/// stops immediately.
/// Partial results are passed to <paramref name="callback"/> while the response
/// downloads, and the final result is passed to it at the end.
/// The "slots" endpoint is rejected with a logged error.
/// </summary>
/// <param name="json">request body, sent UTF-8 encoded</param>
/// <param name="endpoint">endpoint path appended to the server URL</param>
/// <param name="getContent">extracts the returned value from the parsed response</param>
/// <param name="callback">optional; receives partial and final results</param>
/// <returns>presumably the final converted result; default(Ret) if no attempt succeeded</returns>
266 protected virtual async Task<Ret> PostRequestRemote<Res, Ret>(
string json,
string endpoint, ContentCallback<Res, Ret> getContent, Callback<Ret> callback =
null)
// Slot save/load is only available with a local LLM.
270 if (endpoint ==
"slots")
272 LLMUnitySetup.LogError(
"Saving and loading is not currently supported in remote setting");
276 Ret result =
default;
277 byte[] jsonToSend =
new System.Text.UTF8Encoding().GetBytes(json);
278 UnityWebRequest request =
null;
// Remaining attempts. numRetries == -1 ("infinite") makes this start negative.
280 int tryNr = numRetries;
// URL is {host}:{port}/{endpoint}, or {host}/{endpoint} when port is 0.
284 using (request = UnityWebRequest.Put($
"{host}{(port != 0 ? $":{port}
" : "")}/{endpoint}", jsonToSend))
// Track the request so CancelRequestsRemote can reach it.
286 WIPRequests.Add(request);
// Created with Put (raw upload body) and switched to POST; presumably a way to send a
// raw JSON body with POST — confirm.
288 request.method =
"POST";
289 if (requestHeaders !=
null)
291 for (
int i = 0; i < requestHeaders.Count; i++)
292 request.SetRequestHeader(requestHeaders[i].Item1, requestHeaders[i].Item2);
296 UnityWebRequestAsyncOperation asyncOperation = request.SendWebRequest();
// Stream partial results: whenever more of the body has arrived, convert the text
// received so far and hand it to the callback.
// NOTE(review): a partially downloaded body may not be valid JSON for JsonUtility —
// confirm this only happens for responses that parse at every progress step.
299 float lastProgress = 0f;
301 while (!asyncOperation.isDone)
303 float currentProgress = request.downloadProgress;
305 if (currentProgress != lastProgress && callback !=
null)
307 callback?.Invoke(ConvertContent(request.downloadHandler.text, getContent));
308 lastProgress = currentProgress;
// NOTE(review): Add/Remove are not paired in a try/finally here; if anything above
// throws, the request may stay in WIPRequests — confirm.
313 WIPRequests.Remove(request);
314 if (request.result == UnityWebRequest.Result.Success)
316 result = ConvertContent(request.downloadHandler.text, getContent);
323 error = request.error;
// A 401 Unauthorized will not succeed on retry, so stop immediately.
324 if (request.responseCode == (
int)System.Net.HttpStatusCode.Unauthorized)
break;
// Delay before the next attempt: 200 ms times the attempts consumed so far
// (numRetries - tryNr).
// NOTE(review): when numRetries == -1 ("infinite"), tryNr is negative, so tryNr > 0
// never holds and retries would run back-to-back with no delay — confirm.
328 if (tryNr > 0) await Task.Delay(200 * (numRetries - tryNr));
// Log the last recorded error, if any, then report the final result.
331 if (error !=
null) LLMUnitySetup.LogError(error);
332 callback?.Invoke(result);
/// <summary>
/// Sends a request to the remote server when remote is set, otherwise to the local LLM.
/// </summary>
/// <typeparam name="Res">type the JSON response is deserialized into</typeparam>
/// <typeparam name="Ret">type of the value returned to the caller</typeparam>
/// <param name="json">JSON request body</param>
/// <param name="endpoint">endpoint name, e.g. "tokenize"</param>
/// <param name="getContent">extracts the returned value from the parsed response</param>
/// <param name="callback">optional callback receiving the result</param>
/// <returns>the value produced by the remote or local request path</returns>
336 protected virtual async Task<Ret> PostRequest<Res, Ret>(
string json,
string endpoint, ContentCallback<Res, Ret> getContent, Callback<Ret> callback =
null)
338 if (remote)
return await PostRequestRemote(json, endpoint, getContent, callback);
339 return await PostRequestLocal(json, endpoint, getContent, callback);
/// <summary>
/// Tokenizes <paramref name="query"/> with the local LLM or the remote server, depending on remote.
/// </summary>
/// <param name="query">text to tokenize</param>
/// <param name="callback">optional callback receiving the token ids (in remote mode it may also receive intermediate values while the response downloads)</param>
/// <returns>the token ids</returns>
348 public virtual async Task<List<int>> Tokenize(
string query, Callback<List<int>> callback =
null)
351 TokenizeRequest tokenizeRequest =
new TokenizeRequest();
352 tokenizeRequest.content = query;
353 string json = JsonUtility.ToJson(tokenizeRequest);
354 return await PostRequest<TokenizeResult, List<int>>(json,
"tokenize", TokenizeContent, callback);
/// <summary>
/// Converts <paramref name="tokens"/> back to text with the local LLM or the remote server.
/// </summary>
/// <param name="tokens">token ids to detokenize</param>
/// <param name="callback">optional callback receiving the text</param>
/// <returns>the detokenized text</returns>
363 public virtual async Task<string> Detokenize(List<int> tokens, Callback<string> callback =
null)
// The request/response types are mirrored here: the body is a TokenizeResult (tokens
// field) and the reply is parsed as a TokenizeRequest (content field).
// NOTE(review): the local is named tokenizeRequest although it is a TokenizeResult.
366 TokenizeResult tokenizeRequest =
new TokenizeResult();
367 tokenizeRequest.tokens = tokens;
368 string json = JsonUtility.ToJson(tokenizeRequest);
369 return await PostRequest<TokenizeRequest, string>(json,
"detokenize", DetokenizeContent, callback);
/// <summary>
/// Computes the embedding of <paramref name="query"/> with the local LLM or the remote server.
/// </summary>
/// <param name="query">text to embed</param>
/// <param name="callback">optional callback receiving the embedding values</param>
/// <returns>the embedding values</returns>
378 public virtual async Task<List<float>> Embeddings(
string query, Callback<List<float>> callback =
null)
// The request body reuses the TokenizeRequest shape (content field).
381 TokenizeRequest tokenizeRequest =
new TokenizeRequest();
382 tokenizeRequest.content = query;
383 string json = JsonUtility.ToJson(tokenizeRequest);
384 return await PostRequest<EmbeddingsResult, List<float>>(json,
"embeddings", EmbeddingsContent, callback);