21 [LocalRemote]
public bool remote =
false;
23 [Local, SerializeField]
protected LLM _llm;
34 [Remote]
public string host =
"localhost";
36 [Remote]
public int port = 13333;
40 protected LLM _prellm;
41 protected List<(string, string)> requestHeaders;
42 protected List<UnityWebRequest> WIPRequests =
new List<UnityWebRequest>();
58 requestHeaders =
new List<(string, string)> { (
"Content-Type",
"application/json") };
64 string error = $
"No LLM assigned or detected for LLMCharacter {name}!";
66 throw new Exception(error);
71 if (!String.IsNullOrEmpty(
APIKey)) requestHeaders.Add((
"Authorization",
"Bearer " +
APIKey));
79 protected virtual void SetLLM(
LLM llmSet)
110 protected virtual string NotValidLLMError()
112 return $
"Can't set LLM {llm.name} to {name}";
115 protected virtual void OnValidate()
117 if (_llm != _prellm) SetLLM(_llm);
121 protected virtual void Reset()
126 protected virtual void AssignLLM()
128 if (
remote || llm !=
null)
return;
130 List<LLM> validLLMs =
new List<LLM>();
131 foreach (LLM foundllm
in FindObjectsOfType<LLM>())
135 if (validLLMs.Count == 0)
return;
137 llm = SortLLMsByBestMatching(validLLMs.ToArray())[0];
138 string msg = $
"Assigning LLM {llm.name} to {GetType()} {name}";
139 if (llm.gameObject.scene != gameObject.scene) msg += $
" from scene {llm.gameObject.scene}";
140 LLMUnitySetup.Log(msg);
143 protected virtual LLM[] SortLLMsByBestMatching(LLM[] arrayIn)
145 LLM[] array = (LLM[])arrayIn.Clone();
146 for (
int i = 0; i < array.Length - 1; i++)
148 bool swapped =
false;
149 for (
int j = 0; j < array.Length - i - 1; j++)
151 bool sameScene = array[j].gameObject.scene == array[j + 1].gameObject.scene;
153 (!sameScene && array[j + 1].gameObject.scene == gameObject.scene) ||
154 (sameScene && array[j].transform.GetSiblingIndex() > array[j + 1].transform.GetSiblingIndex())
159 array[j] = array[j + 1];
169 protected virtual List<int> TokenizeContent(TokenizeResult result)
172 return result.tokens;
175 protected virtual string DetokenizeContent(TokenizeRequest result)
178 return result.content;
181 protected virtual List<float> EmbeddingsContent(EmbeddingsResult result)
184 return result.embedding;
187 protected virtual Ret ConvertContent<Res, Ret>(
string response, ContentCallback<Res, Ret> getContent =
null)
190 if (response ==
null)
return default;
191 response = response.Trim();
192 if (response.StartsWith(
"data: "))
194 string responseArray =
"";
195 foreach (
string responsePart
in response.Replace(
"\n\n",
"").Split(
"data: "))
197 if (responsePart ==
"")
continue;
198 if (responseArray !=
"") responseArray +=
",\n";
199 responseArray += responsePart;
201 response = $
"{{\"data\": [{responseArray}]}}";
203 return getContent(JsonUtility.FromJson<Res>(response));
206 protected virtual void CancelRequestsLocal() {}
208 protected virtual void CancelRequestsRemote()
210 foreach (UnityWebRequest request
in WIPRequests)
223 if (
remote) CancelRequestsRemote();
224 else CancelRequestsLocal();
227 protected virtual async Task<Ret> PostRequestLocal<Res, Ret>(
string json,
string endpoint, ContentCallback<Res, Ret> getContent, Callback<Ret> callback =
null)
232 string callResult =
null;
236 callResult = await llm.
Tokenize(json);
245 callResult = await llm.
Slot(json);
252 Ret result = ConvertContent(callResult, getContent);
253 callback?.Invoke(result);
257 protected virtual async Task<Ret> PostRequestRemote<Res, Ret>(
string json,
string endpoint, ContentCallback<Res, Ret> getContent, Callback<Ret> callback =
null)
261 if (endpoint ==
"slots")
263 LLMUnitySetup.LogError(
"Saving and loading is not currently supported in remote setting");
267 Ret result =
default;
268 byte[] jsonToSend =
new System.Text.UTF8Encoding().GetBytes(json);
269 UnityWebRequest request =
null;
275 using (request = UnityWebRequest.Put($
"{host}:{port}/{endpoint}", jsonToSend))
277 WIPRequests.Add(request);
279 request.method =
"POST";
280 if (requestHeaders !=
null)
282 for (
int i = 0; i < requestHeaders.Count; i++)
283 request.SetRequestHeader(requestHeaders[i].Item1, requestHeaders[i].Item2);
287 var asyncOperation = request.SendWebRequest();
288 float lastProgress = 0f;
290 while (!asyncOperation.isDone)
292 float currentProgress = request.downloadProgress;
294 if (currentProgress != lastProgress && callback !=
null)
296 callback?.Invoke(ConvertContent(request.downloadHandler.text, getContent));
297 lastProgress = currentProgress;
302 WIPRequests.Remove(request);
303 if (request.result == UnityWebRequest.Result.Success)
305 result = ConvertContent(request.downloadHandler.text, getContent);
312 error = request.error;
313 if (request.responseCode == (
int)System.Net.HttpStatusCode.Unauthorized)
break;
317 if (tryNr > 0) await Task.Delay(200 * (
numRetries - tryNr));
320 if (error !=
null) LLMUnitySetup.LogError(error);
321 callback?.Invoke(result);
325 protected virtual async Task<Ret> PostRequest<Res, Ret>(
string json,
string endpoint, ContentCallback<Res, Ret> getContent, Callback<Ret> callback =
null)
327 if (
remote)
return await PostRequestRemote(json, endpoint, getContent, callback);
328 return await PostRequestLocal(json, endpoint, getContent, callback);
337 public virtual async Task<List<int>>
Tokenize(
string query, Callback<List<int>> callback =
null)
340 TokenizeRequest tokenizeRequest =
new TokenizeRequest();
341 tokenizeRequest.content = query;
342 string json = JsonUtility.ToJson(tokenizeRequest);
343 return await PostRequest<TokenizeResult, List<int>>(json,
"tokenize", TokenizeContent, callback);
352 public virtual async Task<string>
Detokenize(List<int> tokens, Callback<string> callback =
null)
355 TokenizeResult tokenizeRequest =
new TokenizeResult();
356 tokenizeRequest.tokens = tokens;
357 string json = JsonUtility.ToJson(tokenizeRequest);
358 return await PostRequest<TokenizeRequest, string>(json,
"detokenize", DetokenizeContent, callback);
367 public virtual async Task<List<float>>
Embeddings(
string query, Callback<List<float>> callback =
null)
370 TokenizeRequest tokenizeRequest =
new TokenizeRequest();
371 tokenizeRequest.content = query;
372 string json = JsonUtility.ToJson(tokenizeRequest);
373 return await PostRequest<EmbeddingsResult, List<float>>(json,
"embeddings", EmbeddingsContent, callback);