3using System;
4using System.Collections.Generic;
5using System.IO;
6using System.Threading;
7using System.Threading.Tasks;
8using UnityEditor;
9using UnityEngine;
10using UnityEngine.Networking;
12namespace LLMUnity
14 [DefaultExecutionOrder(-2)]
19 public class LLMCharacter : MonoBehaviour
20 {
// ---- Inspector / configuration fields of LLMCharacter ----
// Inspector flag; not read anywhere in the visible part of this file.
22 [HideInInspector] public bool advancedOptions = false;
// false: use a local LLM component (assigned and registered in Awake); true: talk to a remote server over HTTP.
24 [LocalRemote] public bool remote = false;
// Local LLM component; when left empty AssignLLM tries to fill it in.
26 [Local] public LLM llm;
// Host and port used to build the remote request URL "{host}:{port}/{endpoint}".
28 [Remote] public string host = "localhost";
30 [Remote] public int port = 13333;
// Number of attempts made by PostRequestRemote before giving up.
32 [Remote] public int numRetries = 10;
// When non-empty, sent by remote requests as an "Authorization: Bearer <key>" header.
34 [Remote] public string APIKey;
// Base file name (no extension) of the saved chat history; "" disables loading/saving the history.
38 [LLM] public string save = "";
// When true (and not remote), Save/Load also save/restore the server slot to a ".cache" file.
40 [LLM] public bool saveCache = false;
// When true, GenerateRequest logs the prompt it is given.
42 [LLM] public bool debugPrompt = false;
// Selects how CompletionRequest parses the reply: multi-part (streamed) when true, single result otherwise.
45 [Model] public bool stream = true;
// Path of the grammar file; InitGrammar reads its contents into grammarString.
47 [ModelAdvanced] public string grammar = null;
// NOTE(review): cachePrompt, seed, numPredict and the sampling options further down are not read in the
// visible part of this file; presumably they are copied into the request by GenerateRequest — confirm.
49 [ModelAdvanced] public bool cachePrompt = true;
// Server slot of this character; -1 means "use the slot returned by llm.Register" (see Awake).
51 [ModelAdvanced] public int slot = -1;
53 [ModelAdvanced] public int seed = 0;
58 [ModelAdvanced] public int numPredict = 256;
63 [ModelAdvanced, Float(0f, 2f)] public float temperature = 0.2f;
66 [ModelAdvanced, Int(-1, 100)] public int topK = 40;
71 [ModelAdvanced, Float(0f, 1f)] public float topP = 0.9f;
74 [ModelAdvanced, Float(0f, 1f)] public float minP = 0.05f;
77 [ModelAdvanced, Float(0f, 2f)] public float repeatPenalty = 1.1f;
80 [ModelAdvanced, Float(0f, 1f)] public float presencePenalty = 0f;
83 [ModelAdvanced, Float(0f, 1f)] public float frequencyPenalty = 0f;
86 [ModelAdvanced, Float(0f, 1f)] public float tfsZ = 1f;
88 [ModelAdvanced, Float(0f, 1f)] public float typicalP = 1f;
90 [ModelAdvanced, Int(0, 2048)] public int repeatLastN = 64;
92 [ModelAdvanced] public bool penalizeNl = true;
95 [ModelAdvanced] public string penaltyPrompt;
97 [ModelAdvanced, Int(0, 2)] public int mirostat = 0;
99 [ModelAdvanced, Float(0f, 10f)] public float mirostatTau = 5f;
101 [ModelAdvanced, Float(0f, 1f)] public float mirostatEta = 0.1f;
103 [ModelAdvanced, Int(0, 10)] public int nProbs = 0;
105 [ModelAdvanced] public bool ignoreEos = false;
// Number of tokens to keep; -1 means "not computed yet" (InitNKeep then derives it from the system prompt).
108 public int nKeep = -1;
// Extra stop words appended to the template's own stop words in GetStopwords.
110 public List<string> stop = new List<string>();
// NOTE(review): not read in the visible part of this file.
113 public Dictionary<int, string> logitBias = null;
// Role names used for the player's and the AI's chat messages.
116 [Chat] public string playerName = "user";
118 [Chat] public string AIName = "assistant";
// System prompt; InitPrompt stores it as the first ("system") message of the chat.
120 [TextArea(5, 10), Chat] public string prompt = "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.";
// When true, InitNKeep sets nKeep to the token count of the system prompt.
122 public bool setNKeepToPrompt = true;
// Chat history; element 0 is the system prompt message (see InitPrompt).
125 public List<ChatMessage> chat;
// Serialises modifications of the chat list across asynchronous calls.
126 private SemaphoreSlim chatLock = new SemaphoreSlim(1, 1);
// Name of the currently selected chat template and the template object created from it.
127 private string chatTemplate;
128 private ChatTemplate template = null;
// Contents of the grammar file, filled by InitGrammar.
129 public string grammarString;
// HTTP headers (name, value) added to every remote request; built in Awake.
130 private List<(string, string)> requestHeaders;
// Remote requests currently in flight, tracked by PostRequestRemote and cleared by CancelRequestsRemote.
131 private List<UnityWebRequest> WIPRequests = new List<UnityWebRequest>();
/// <summary>
/// Unity Awake callback. Builds the request headers, then either binds this character to a
/// local LLM (registering it to obtain a slot) or, in remote mode, adds the API key header.
/// Finally initialises the grammar and the chat history.
/// Does nothing when the component is disabled; stops early when no local LLM is available.
/// </summary>
public void Awake()
{
    if (!enabled) return;

    requestHeaders = new List<(string, string)> { ("Content-Type", "application/json") };
    if (remote)
    {
        // remote servers may require an API key, sent as a bearer token
        bool hasKey = !String.IsNullOrEmpty(APIKey);
        if (hasKey) requestHeaders.Add(("Authorization", "Bearer " + APIKey));
    }
    else
    {
        AssignLLM();
        if (llm == null)
        {
            LLMUnitySetup.LogError($"No LLM assigned or detected for LLMCharacter {name}!");
            return;
        }
        // always register with the LLM; a slot chosen by the user (anything but -1) is kept
        int registeredSlot = llm.Register(this);
        if (slot == -1) slot = registeredSlot;
    }

    InitGrammar();
    InitHistory();
}
/// <summary>
/// Unity OnValidate callback: makes sure an LLM is assigned and reports an error when the
/// configured slot lies outside the range allowed by the LLM's number of parallel prompts.
/// </summary>
void OnValidate()
{
    AssignLLM();
    // nothing to check without an LLM or when the LLM does not limit the parallel prompts
    if (llm == null || !(llm.parallelPrompts > -1)) return;

    bool slotOutOfRange = slot < -1 || slot >= llm.parallelPrompts;
    if (slotOutOfRange)
    {
        LLMUnitySetup.LogError($"The slot needs to be between 0 and {llm.parallelPrompts-1}, or -1 to be automatically set");
    }
}
/// <summary>
/// Unity Reset callback: tries to assign an LLM to this character.
/// </summary>
void Reset() => AssignLLM();
/// <summary>
/// Assigns an LLM to this character when none is set and the character is not remote.
/// The candidates are ordered with SortBySceneAndHierarchy and the first one is used;
/// the choice is logged, including the scene when it differs from this object's scene.
/// </summary>
void AssignLLM()
{
    if (remote || llm != null) return;

    // collect the LLM components currently loaded (the candidate list was used below without being declared)
    LLM[] existingLLMs = FindObjectsOfType<LLM>();
    if (existingLLMs.Length == 0) return;

    SortBySceneAndHierarchy(existingLLMs);
    llm = existingLLMs[0];
    string msg = $"Assigning LLM {llm.name} to LLMCharacter {name}";
    if (llm.gameObject.scene != gameObject.scene) msg += $" from scene {llm.gameObject.scene}";
    LLMUnitySetup.Log(msg);
}
/// <summary>
/// Orders the LLMs in place with a bubble sort: an LLM in this object's scene is moved in front
/// of a neighbour from a different scene, and within the same scene the LLM with the lower
/// sibling index comes first.
/// </summary>
/// <param name="array">LLMs to reorder in place</param>
void SortBySceneAndHierarchy(LLM[] array)
{
    int unsortedEnd = array.Length - 1;
    while (unsortedEnd > 0)
    {
        bool anySwap = false;
        for (int k = 0; k < unsortedEnd; k++)
        {
            LLM left = array[k];
            LLM right = array[k + 1];

            bool shouldSwap;
            if (left.gameObject.scene == right.gameObject.scene)
            {
                // same scene: order by sibling index
                shouldSwap = left.transform.GetSiblingIndex() > right.transform.GetSiblingIndex();
            }
            else
            {
                // different scenes: prefer the LLM that lives in our own scene
                shouldSwap = right.gameObject.scene == gameObject.scene;
            }

            if (!shouldSwap) continue;
            array[k] = right;
            array[k + 1] = left;
            anySwap = true;
        }
        // a full pass without swaps means the array is ordered
        if (!anySwap) break;
        unsortedEnd--;
    }
}
/// <summary>
/// Resets the chat to the system prompt and starts loading the saved history.
/// The load is started without awaiting it (fire-and-forget).
/// </summary>
protected void InitHistory()
{
    InitPrompt();
    _ = LoadHistory();
}
/// <summary>
/// Loads the saved chat history while holding the chat lock.
/// Does nothing when no save name is set or the JSON save file does not exist.
/// </summary>
protected async Task LoadHistory()
{
    bool nothingToLoad = save == "" || !File.Exists(GetJsonSavePath(save));
    if (nothingToLoad) return;

    await chatLock.WaitAsync();
    try
    {
        await Load(save);
    }
    finally
    {
        // always give the lock back, even when loading fails
        chatLock.Release();
    }
}
/// <summary>
/// Builds the path of a file inside Application.persistentDataPath, using forward slashes.
/// </summary>
/// <param name="filename">file name to place in the persistent data folder</param>
/// <returns>combined path with every backslash replaced by a forward slash</returns>
public virtual string GetSavePath(string filename)
{
    string combined = Path.Combine(Application.persistentDataPath, filename);
    return combined.Replace('\\', '/');
}
/// <summary>
/// Returns the save path of the chat history file: the given name with ".json" appended.
/// </summary>
/// <param name="filename">save name without extension</param>
public virtual string GetJsonSavePath(string filename) => GetSavePath(filename + ".json");
/// <summary>
/// Returns the save path of the cache file: the given name with ".cache" appended.
/// </summary>
/// <param name="filename">save name without extension</param>
public virtual string GetCacheSavePath(string filename) => GetSavePath(filename + ".cache");
/// <summary>
/// Makes sure the chat list exists and that its first element is the system message
/// built from the current prompt.
/// </summary>
/// <param name="clearChat">when true an existing chat list is emptied first</param>
private void InitPrompt(bool clearChat = true)
{
    if (chat == null) chat = new List<ChatMessage>();
    else if (clearChat) chat.Clear();

    ChatMessage systemMessage = new ChatMessage { role = "system", content = prompt };
    // an existing first message is overwritten, otherwise the system message is appended
    if (chat.Count > 0) chat[0] = systemMessage;
    else chat.Add(systemMessage);
}
/// <summary>
/// Sets a new system prompt for the character.
/// </summary>
/// <param name="newPrompt">the new system prompt</param>
/// <param name="clearChat">when true the chat history is cleared as well</param>
public void SetPrompt(string newPrompt, bool clearChat = true)
{
    // store the new prompt first: InitPrompt builds the system message from the prompt field
    prompt = newPrompt;
    // the number of tokens to keep depends on the prompt, so force it to be recomputed
    nKeep = -1;
    InitPrompt(clearChat);
}
/// <summary>
/// Checks that a chat template is set, logging an error when it is not.
/// </summary>
/// <returns>true when the template is set, false otherwise</returns>
private bool CheckTemplate()
{
    bool hasTemplate = template != null;
    if (!hasTemplate) LLMUnitySetup.LogError("Template not set!");
    return hasTemplate;
}
/// <summary>
/// Computes nKeep from the system prompt when setNKeepToPrompt is enabled and nKeep has not
/// been computed yet (-1): the system message is formatted with the template and tokenized.
/// </summary>
/// <returns>false when the template is missing or tokenization returned nothing, true otherwise</returns>
private async Task<bool> InitNKeep()
{
    bool needsCompute = setNKeepToPrompt && nKeep == -1;
    if (!needsCompute) return true;
    if (!CheckTemplate()) return false;

    List<ChatMessage> systemOnly = new List<ChatMessage> { chat[0] };
    string systemPrompt = template.ComputePrompt(systemOnly, playerName, "", false);
    List<int> promptTokens = await Tokenize(systemPrompt);
    if (promptTokens == null) return false;

    SetNKeep(promptTokens);
    return true;
}
/// <summary>
/// Reads the grammar file into grammarString when a grammar path is set.
/// </summary>
private void InitGrammar()
{
    if (String.IsNullOrEmpty(grammar)) return;

    string grammarPath = LLMUnitySetup.GetAssetPath(grammar);
    grammarString = File.ReadAllText(grammarPath);
}
/// <summary>
/// Sets the number of tokens to keep to the number of tokens provided.
/// </summary>
/// <param name="tokens">tokens whose count becomes nKeep</param>
private void SetNKeep(List<int> tokens) => nKeep = tokens.Count;
// Body of a method whose signature is not visible in this excerpt.
// It obtains a template name into llmTemplate and, when that differs from the stored
// chatTemplate, switches the active template.
328 {
329 string llmTemplate;
330 if (remote)
331 {
// NOTE(review): the statement assigning llmTemplate for the remote case is not visible here.
333 }
334 else
335 {
// NOTE(review): the statement assigning llmTemplate for the local case is not visible here.
337 }
338 if (llmTemplate != chatTemplate)
339 {
340 chatTemplate = llmTemplate;
// a null template name clears the template object
341 template = chatTemplate == null ? null : ChatTemplate.GetTemplate(chatTemplate);
// reset nKeep to -1 so that InitNKeep recomputes it
342 nKeep = -1;
343 }
344 }
/// <summary>
/// Sets the grammar file of the character and loads its contents.
/// </summary>
/// <param name="path">path of the grammar file</param>
public async void SetGrammar(string path)
{
#if UNITY_EDITOR
    // EditorApplication only exists inside the Unity editor, so this step must be compiled out of player builds
    if (!EditorApplication.isPlaying) path = LLMUnitySetup.AddAsset(path);
#endif
    await LLMUnitySetup.AndroidExtractAsset(path, true);
    grammar = path;
    InitGrammar();
}
/// <summary>
/// Collects the stop words of a request: the template's stop words for the player and AI
/// names, followed by the user-defined stop words (if any).
/// </summary>
/// <returns>the combined list, or null when no template is set</returns>
List<string> GetStopwords()
{
    if (!CheckTemplate()) return null;

    List<string> combined = new List<string>();
    combined.AddRange(template.GetStop(playerName, AIName));
    if (stop != null) combined.AddRange(stop);
    return combined;
}
// Builds the ChatRequest sent to the completion endpoint for the given prompt.
// NOTE(review): most of this method is not visible in this excerpt, including the
// declaration of chatRequest and every field assignment other than `stop`.
368 ChatRequest GenerateRequest(string prompt)
369 {
370 // setup the request struct
// optionally log the prompt for debugging
372 if (debugPrompt) LLMUnitySetup.Log(prompt);
// stop words: template stop words plus the user-defined ones (null when no template is set)
382 chatRequest.stop = GetStopwords();
400 return chatRequest;
401 }
/// <summary>
/// Appends a message to the chat history.
/// </summary>
/// <param name="role">role of the message author</param>
/// <param name="content">text of the message</param>
public void AddMessage(string role, string content)
{
    ChatMessage message = new ChatMessage { role = role, content = content };
    chat.Add(message);
}
/// <summary>
/// Appends a message authored by the player (role = playerName) to the chat history.
/// </summary>
/// <param name="content">text of the message</param>
public void AddPlayerMessage(string content) => AddMessage(playerName, content);
/// <summary>
/// Appends a message authored by the AI (role = AIName) to the chat history.
/// </summary>
/// <param name="content">text of the message</param>
public void AddAIMessage(string content) => AddMessage(AIName, content);
/// <summary>
/// Extracts the text of a single chat result, with surrounding whitespace removed.
/// </summary>
/// <param name="result">chat result received from the endpoint</param>
protected string ChatContent(ChatResult result) => result.content.Trim();
/// <summary>
/// Joins the content of all parts of a multi-part (streamed) chat result and trims the
/// surrounding whitespace.
/// </summary>
/// <param name="result">multi-part chat result received from the endpoint</param>
/// <returns>the concatenated, trimmed content</returns>
protected string MultiChatContent(MultiChatResult result)
{
    // a builder avoids re-copying the whole text for every part; a null part content
    // contributes nothing, as with string concatenation
    System.Text.StringBuilder merged = new System.Text.StringBuilder();
    foreach (ChatResult resultPart in result.data)
    {
        merged.Append(resultPart.content);
    }
    return merged.ToString().Trim();
}
/// <summary>
/// Posts a request to the "completion" endpoint and returns the generated text.
/// With streaming enabled the reply is parsed as a multi-part result, otherwise as a single result.
/// </summary>
/// <param name="json">serialized request</param>
/// <param name="callback">optional callback forwarded to PostRequest</param>
async Task<string> CompletionRequest(string json, Callback<string> callback = null)
{
    if (stream)
    {
        return await PostRequest<MultiChatResult, string>(json, "completion", MultiChatContent, callback);
    }
    return await PostRequest<ChatResult, string>(json, "completion", ChatContent, callback);
}
/// <summary>
/// Extracts the template name from a template result received from the endpoint.
/// </summary>
protected string TemplateContent(TemplateResult result) => result.template;
/// <summary>
/// Extracts the tokens from a tokenize result received from the endpoint.
/// </summary>
protected List<int> TokenizeContent(TokenizeResult result) => result.tokens;
/// <summary>
/// Extracts the text from a detokenize reply (delivered as a TokenizeRequest object).
/// </summary>
protected string DetokenizeContent(TokenizeRequest result) => result.content;
/// <summary>
/// Extracts the embedding vector from an embeddings result received from the endpoint.
/// </summary>
protected List<float> EmbeddingsContent(EmbeddingsResult result) => result.embedding;
/// <summary>
/// Extracts the file name from a slot result received from the endpoint.
/// </summary>
protected string SlotContent(SlotResult result) => result.filename;
// Body of the chat method; its signature is not visible in this excerpt.
// It uses `query`, `callback`, `completionCallback` and `addToHistory`, which are presumably its parameters.
492 {
493 // handle a chat message by the user
494 // call the callback function while the answer is received
495 // call the completionCallback function when the answer is fully received
// abort (returning null) when no template is set or nKeep could not be initialised
497 if (!CheckTemplate()) return null;
498 if (!await InitNKeep()) return null;
500 string json;
501 await chatLock.WaitAsync();
502 try
503 {
// the player message is added only temporarily so that it is part of the computed prompt;
// it is removed again below and re-added only once an answer has been received
504 AddPlayerMessage(query);
505 string prompt = template.ComputePrompt(chat, playerName, AIName);
506 json = JsonUtility.ToJson(GenerateRequest(prompt));
507 chat.RemoveAt(chat.Count - 1);
508 }
509 finally
510 {
511 chatLock.Release();
512 }
// the request itself runs outside the lock
514 string result = await CompletionRequest(json, callback);
// store the exchange in the history only when requested and when a result was received
516 if (addToHistory && result != null)
517 {
518 await chatLock.WaitAsync();
519 try
520 {
521 AddPlayerMessage(query);
522 AddAIMessage(result);
523 }
524 finally
525 {
526 chatLock.Release();
527 }
// saving is started without awaiting it (fire-and-forget)
528 if (save != "") _ = Save(save);
529 }
531 completionCallback?.Invoke();
532 return result;
533 }
// Body of the pure completion method; its signature is not visible in this excerpt.
// It sends `prompt` as given, without applying the chat template or touching the chat history.
545 {
546 // handle a completion request by the user
547 // call the callback function while the answer is received
548 // call the completionCallback function when the answer is fully received
551 string json = JsonUtility.ToJson(GenerateRequest(prompt));
552 string result = await CompletionRequest(json, callback);
553 completionCallback?.Invoke();
554 return result;
555 }
// Body of a method whose signature is not visible in this excerpt.
// It computes the prompt from the current chat, sends one completion request and discards the answer.
// NOTE(review): one line between GenerateRequest and ToJson is missing here; it presumably adjusts the request.
568 {
// abort when no template is set or nKeep could not be initialised
570 if (!CheckTemplate()) return;
571 if (!await InitNKeep()) return;
573 string prompt = template.ComputePrompt(chat, playerName, AIName);
574 ChatRequest request = GenerateRequest(prompt);
576 string json = JsonUtility.ToJson(request);
577 await CompletionRequest(json);
578 completionCallback?.Invoke();
579 }
// Body of a method whose signature is not visible in this excerpt.
// It posts an empty JSON object to the "template" endpoint and returns the template name from the reply.
586 {
587 return await PostRequest<TemplateResult, string>("{}", "template", TemplateContent);
588 }
// Tokenizes a text: posts to the "tokenize" endpoint and returns the tokens of the reply.
// The optional callback is forwarded to PostRequest.
// NOTE(review): the construction of tokenizeRequest (presumably from `query`) is not visible in this excerpt.
596 public async Task<List<int>> Tokenize(string query, Callback<List<int>> callback = null)
597 {
598 // handle the tokenization of a message by the user
601 string json = JsonUtility.ToJson(tokenizeRequest);
602 return await PostRequest<TokenizeResult, List<int>>(json, "tokenize", TokenizeContent, callback);
603 }
// Body of the detokenize method; its signature and the construction of tokenizeRequest
// are not visible in this excerpt.
// It posts to the "detokenize" endpoint and returns the text of the reply.
612 {
613 // handle the detokenization of a message by the user
616 string json = JsonUtility.ToJson(tokenizeRequest);
617 return await PostRequest<TokenizeRequest, string>(json, "detokenize", DetokenizeContent, callback);
618 }
// Body of the embeddings method; its signature and the construction of tokenizeRequest
// are not visible in this excerpt.
// It posts to the "embeddings" endpoint and returns the embedding vector of the reply.
627 {
628 // compute the embeddings of a message by the user
631 string json = JsonUtility.ToJson(tokenizeRequest);
632 return await PostRequest<EmbeddingsResult, List<float>>(json, "embeddings", EmbeddingsContent, callback);
633 }
// Posts a request to the "slots" endpoint and returns the file name reported in the reply.
// Save calls it with action "save" and Load with action "restore".
// NOTE(review): the construction of slotRequest (presumably from filepath and action) is not visible in this excerpt.
635 protected async Task<string> Slot(string filepath, string action)
636 {
641 string json = JsonUtility.ToJson(slotRequest);
642 return await PostRequest<SlotResult, string>(json, "slots", SlotContent);
643 }
/// <summary>
/// Saves the chat history (without the system prompt) as JSON and, for local characters with
/// saveCache enabled, asks the server to save the slot to the cache file.
/// </summary>
/// <param name="filename">save name without extension</param>
/// <returns>the result of the slot save, or null when no cache is saved</returns>
public virtual async Task<string> Save(string filename)
{
    string jsonPath = GetJsonSavePath(filename);
    string folder = Path.GetDirectoryName(jsonPath);
    if (!Directory.Exists(folder)) Directory.CreateDirectory(folder);

    // element 0 is the system prompt, which is not stored
    ChatListWrapper history = new ChatListWrapper { chat = chat.GetRange(1, chat.Count - 1) };
    File.WriteAllText(jsonPath, JsonUtility.ToJson(history));

    string cachepath = GetCacheSavePath(filename);
    if (remote || !saveCache) return null;
    return await Slot(cachepath, "save");
}
// Loads the chat history from the JSON save file and, for local characters with saveCache
// enabled and an existing cache file, asks the server to restore the slot from that file.
// Returns the result of the slot restore, or null when the file is missing or no cache is restored.
// NOTE(review): the declaration of chatHistory (presumably parsed from `json`) is not visible in this excerpt.
669 public virtual async Task<string> Load(string filename)
670 {
671 string filepath = GetJsonSavePath(filename);
672 if (!File.Exists(filepath))
673 {
674 LLMUnitySetup.LogError($"File {filepath} does not exist.");
675 return null;
676 }
677 string json = File.ReadAllText(filepath);
// reset the chat to the system prompt, then append the loaded messages
679 InitPrompt(true);
680 chat.AddRange(chatHistory);
681 LLMUnitySetup.Log($"Loaded {filepath}");
683 string cachepath = GetCacheSavePath(filename);
// NOTE(review): cachepath already comes from GetCacheSavePath and is passed through GetSavePath again here.
684 if (remote || !saveCache || !File.Exists(GetSavePath(cachepath))) return null;
685 string result = await Slot(cachepath, "restore");
686 return result;
687 }
/// <summary>
/// Parses the JSON received from the server and extracts its content.
/// A streamed response (a sequence of "data: {...}" chunks) is first rewritten into a single
/// JSON object of the form {"data": [chunk, chunk, ...]} so that it can be parsed in one go.
/// </summary>
/// <param name="response">raw response text; null yields the default value</param>
/// <param name="getContent">function extracting the content from the parsed object</param>
/// <returns>the extracted content, or default when the response is null</returns>
protected Ret ConvertContent<Res, Ret>(string response, ContentCallback<Res, Ret> getContent = null)
{
    if (response == null) return default;
    response = response.Trim();
    if (response.StartsWith("data: "))
    {
        // collect every non-empty chunk; the chunks have to end up inside the array,
        // otherwise a streamed response always parses as an empty "data" list
        List<string> parts = new List<string>();
        foreach (string responsePart in response.Replace("\n\n", "").Split("data: "))
        {
            if (responsePart == "") continue;
            parts.Add(responsePart);
        }
        string responseArray = String.Join(",\n", parts);
        response = $"{{\"data\": [{responseArray}]}}";
    }
    return getContent(JsonUtility.FromJson<Res>(response));
}
/// <summary>
/// Cancels the request running on this character's slot of the local LLM.
/// Does nothing while no slot is assigned (negative slot).
/// </summary>
protected void CancelRequestsLocal()
{
    if (slot < 0) return;
    llm.CancelRequest(slot);
}
// Aborts the remote requests and empties the list of in-flight requests.
// NOTE(review): the loop header introducing `request` is not visible in this excerpt;
// presumably it iterates over WIPRequests — confirm.
713 protected void CancelRequestsRemote()
714 {
716 {
717 request.Abort();
718 }
719 WIPRequests.Clear();
720 }
/// <summary>
/// Cancels the ongoing requests of this character, using the remote or the local
/// implementation depending on the remote flag.
/// </summary>
public void CancelRequests()
{
    if (!remote)
    {
        CancelRequestsLocal();
        return;
    }
    CancelRequestsRemote();
}
// Sends a request to the local LLM, dispatching on the endpoint name, and returns the converted result.
// NOTE(review): many lines of this method are not visible in this excerpt: the call made for each
// endpoint, the streaming callback body, and the declaration/computation of `result`.
732 protected async Task<Ret> PostRequestLocal<Res, Ret>(string json, string endpoint, ContentCallback<Res, Ret> getContent, Callback<Ret> callback = null)
733 {
734 // send a post request to the server and call the relevant callbacks to convert the received content and handle it
735 // this function has streaming functionality i.e. handles the answer while it is being received
736 string callResult = null;
737 bool callbackCalled = false;
// wait until the LLM has either started or failed
738 while (!llm.failed && !llm.started) await Task.Yield();
739 switch (endpoint)
740 {
741 case "tokenize":
743 break;
744 case "detokenize":
746 break;
747 case "embeddings":
749 break;
750 case "slots":
752 break;
753 case "completion":
// with streaming and a callback, the callback is handled in this branch and the final
// invocation at the end of the method is skipped (callbackCalled)
755 if (stream && callback != null)
756 {
// streamed content is text, so only a string callback is accepted
757 if (typeof(Ret) == typeof(string))
758 {
760 {
762 };
763 }
764 else
765 {
766 LLMUnitySetup.LogError($"wrong callback type, should be string");
767 }
768 callbackCalled = true;
769 }
771 break;
772 default:
773 LLMUnitySetup.LogError($"Unknown endpoint {endpoint}");
774 break;
775 }
// invoke the callback once with the final result unless the streaming branch took care of it
778 if (!callbackCalled) callback?.Invoke(result);
779 return result;
780 }
// Sends a POST request to the remote server at "{host}:{port}/{endpoint}", retrying on failure,
// and returns the converted content of the reply (default when all attempts fail).
// While the reply is being received the callback is invoked with the content converted so far;
// it is invoked once more with the final result at the end.
// NOTE(review): the declaration of `request` is not visible in this excerpt.
782 protected async Task<Ret> PostRequestRemote<Res, Ret>(string json, string endpoint, ContentCallback<Res, Ret> getContent, Callback<Ret> callback = null)
783 {
784 // send a post request to the server and call the relevant callbacks to convert the received content and handle it
785 // this function has streaming functionality i.e. handles the answer while it is being received
786 if (endpoint == "slots")
787 {
788 LLMUnitySetup.LogError("Saving and loading is not currently supported in remote setting");
789 return default;
790 }
792 Ret result = default;
793 byte[] jsonToSend = new System.Text.UTF8Encoding().GetBytes(json);
795 string error = null;
// remaining attempts; the loop below only stops when this reaches exactly 0
// (or on success / an Unauthorized reply), so a negative numRetries is not bounded by the counter
796 int tryNr = numRetries;
798 while (tryNr != 0)
799 {
// the request is created with Put to attach the body, then switched to POST below
800 using (request = UnityWebRequest.Put($"{host}:{port}/{endpoint}", jsonToSend))
801 {
// track the request so that CancelRequestsRemote can clear it
802 WIPRequests.Add(request);
804 request.method = "POST";
805 if (requestHeaders != null)
806 {
807 for (int i = 0; i < requestHeaders.Count; i++)
808 request.SetRequestHeader(requestHeaders[i].Item1, requestHeaders[i].Item2);
809 }
811 // Start the request asynchronously
812 var asyncOperation = request.SendWebRequest();
813 float lastProgress = 0f;
814 // Continue updating progress until the request is completed
815 while (!asyncOperation.isDone)
816 {
817 float currentProgress = request.downloadProgress;
818 // Check if progress has changed
819 if (currentProgress != lastProgress && callback != null)
820 {
// partial content may not be convertible yet; such failures are deliberately ignored
821 try
822 {
823 callback?.Invoke(ConvertContent(request.downloadHandler.text, getContent));
824 }
825 catch (Exception) {}
// NOTE(review): no update of lastProgress is visible in this excerpt — confirm it is done here.
827 }
828 // Wait for the next frame
829 await Task.Yield();
830 }
831 WIPRequests.Remove(request);
832 if (request.result == UnityWebRequest.Result.Success)
833 {
834 result = ConvertContent(request.downloadHandler.text, getContent);
835 error = null;
836 break;
837 }
838 else
839 {
840 result = default;
841 error = request.error;
// an Unauthorized reply will not improve by retrying, so stop immediately
842 if (request.responseCode == (int)System.Net.HttpStatusCode.Unauthorized) break;
843 }
844 }
845 tryNr--;
// wait before the next attempt; the delay grows by 200 ms with every attempt made so far
846 if (tryNr > 0) await Task.Delay(200 * (numRetries - tryNr));
847 }
// log the error of the last failed attempt, then report the final result (default on failure)
849 if (error != null) LLMUnitySetup.LogError(error);
850 callback?.Invoke(result);
851 return result;
852 }
// Common entry point used by all requests of this class (completion, template, tokenize, detokenize, embeddings, slots).
// NOTE(review): the body is not visible in this excerpt; presumably it forwards to
// PostRequestRemote or PostRequestLocal depending on `remote` — confirm.
854 protected async Task<Ret> PostRequest<Res, Ret>(string json, string endpoint, ContentCallback<Res, Ret> getContent, Callback<Ret> callback = null)
855 {
858 }
859 }
/// <summary>
/// Serializable wrapper around a list of chat messages, used by LLMCharacter.Save to write
/// the chat history as a single JSON object.
/// </summary>
[Serializable]
public class ChatListWrapper
{
    /// <summary> the wrapped chat messages </summary>
    public List<ChatMessage> chat;
}
