3using System;
4using System.Collections.Generic;
5using System.IO;
6using System.Threading;
7using System.Threading.Tasks;
8using UnityEditor;
9using UnityEngine;
11namespace LLMUnity
13 [DefaultExecutionOrder(-1)]
18 public class LLM : MonoBehaviour
19 {
// hidden from the inspector; presumably toggled by the custom editor to show advanced options — confirm against the editor code
[HideInInspector] public bool advancedOptions = false;
// when true, a server is started on `port` (see GetLlamaccpArguments and InitService)
[LocalRemote] public bool remote = false;
// port passed as "--port" when remote is enabled
[Remote] public int port = 13333;
// thread count passed as "-t"; values <= 0 omit the argument (on Android the big-core count is used instead)
[LLM] public int numThreads = -1;
// layer count passed as "-ngl"; a value > 0 also makes StartLLMServer ask for GPU architectures
[LLM] public int numGPULayers = 0;
// when true, InitService routes the native library log to LLMUnitySetup.LogWarning (see SetupLogging)
[LLM] public bool debug = false;
// slot count passed as "-np"; -1 means one slot per registered LLMCharacter (see GetNumClients)
[LLMAdvanced] public int parallelPrompts = -1;
// when true, Awake calls DontDestroyOnLoad on the root GameObject once the service has started
[LLMAdvanced] public bool dontDestroyOnLoad = true;
// context size passed as "-c"; SetModel clamps it to maxContextLength
[DynamicRange("minContextLength", "maxContextLength", false), Model] public int contextSize = 8192;
// batch size passed as "-b"
[ModelAdvanced] public int batchSize = 512;
// prompt submitted through SetBasePrompt at the end of Awake when non-empty
[TextArea(5, 10), ChatAdvanced] public string basePrompt = "";
// set to true at the end of StartService; reset to false by StartLLMServer and Destroy
public bool started { get; protected set; } = false;
// set to true when Awake or StartLLMServer gives up on creating the service
public bool failed { get; protected set; } = false;
// negated result of LLMManager.Setup(), assigned in Awake
public static bool modelSetupFailed { get; protected set; } = false;
// set to true in Awake after LLMManager.Setup() has been awaited
public static bool modelSetupComplete { get; protected set; } = false;
// model path as resolved by GetLLMManagerAsset; assigned by SetModel
[ModelAdvanced] public string model = "";
// chat template name handed to the native library; assigned by SetTemplate
[ModelAdvanced] public string chatTemplate = ChatTemplate.DefaultTemplate;
// serialized LoRA paths, kept in sync with loraManager (see UpdateLoras and OnValidate)
[ModelAdvanced] public string lora = "";
// serialized LoRA weights, kept in sync with loraManager (see UpdateLoras and OnValidate)
[ModelAdvanced] public string loraWeights = "";
// adds "--flash-attn" to the arguments when LLMUnitySetup.FullLlamaLib is also true
[ModelExtras] public bool flashAttention = false;
// value passed as "--api-key" when remote is enabled and the key is non-empty
public string APIKey;
// SSL certificate: file contents (filled by SetSSLCert) and the path they were read from
[SerializeField]
private string SSLCert = "";
public string SSLCertPath = "";
// SSL key: file contents (filled by SetSSLKey) and the path they were read from
[SerializeField]
private string SSLKey = "";
public string SSLKeyPath = "";
// bounds named by the DynamicRange attribute on contextSize; maxContextLength is assigned by SetModel
public int minContextLength = 0;
public int maxContextLength = 0;
// handle returned by LLM_Construct; IntPtr.Zero while no native object exists
IntPtr LLMObject = IntPtr.Zero;
// LLMCharacters added through Register
List<LLMCharacter> clients = new List<LLMCharacter>();
// wrapper around the native library for the architecture in use; null when not loaded
LLMLib llmlib;
// stream wrapper created by SetupLogging for the native log
StreamWrapper logStreamWrapper = null;
// thread running LLM_Start (created in StartService, joined in Destroy)
Thread llmThread = null;
// every stream wrapper created through ConstructStreamWrapper; all are updated in Update
List<StreamWrapper> streamWrappers = new List<StreamWrapper>();
public LLMManager llmManager = new LLMManager();
// per-instance lock taken by CallWithLock and Destroy
private readonly object startLock = new object();
// lock shared by all LLM instances, taken by InitService and Destroy
static readonly object staticLock = new object();
public LoraManager loraManager = new LoraManager();
// last lora / loraWeights values seen, used by OnValidate to detect changes
string loraPre = "";
string loraWeightsPre = "";
/// <summary>
/// Creates the component and registers this instance with the LLMManager.
/// </summary>
public LLM() => LLMManager.Register(this);
/// <summary>
/// Re-parses the serialized LoRA strings into the LoRA manager, but only when
/// they differ from the values remembered from the previous call.
/// </summary>
void OnValidate()
{
    bool unchanged = lora == loraPre && loraWeights == loraWeightsPre;
    if (unchanged) return;

    loraManager.FromStrings(lora, loraWeights);
    // remember what was parsed so the next call can detect a change
    loraPre = lora;
    loraWeightsPre = loraWeights;
}
/// <summary>
/// Unity Awake callback. Waits for the model setup, builds the server arguments
/// and starts the LLM service on a worker thread.
/// Sets <c>failed</c> and returns early when the model setup failed or the
/// arguments could not be built.
/// </summary>
public async void Awake()
{
    if (!enabled) return;

    modelSetupFailed = !await LLMManager.Setup();
    modelSetupComplete = true;
    // only abort when the setup actually failed; without this guard the
    // service would never be started
    if (modelSetupFailed)
    {
        failed = true;
        return;
    }

    string arguments = GetLlamaccpArguments();
    if (arguments == null)
    {
        failed = true;
        return;
    }

    // the native startup blocks, so keep it off the main thread
    await Task.Run(() => StartLLMServer(arguments));
    if (!started) return;
    if (dontDestroyOnLoad) DontDestroyOnLoad(transform.root.gameObject);
    if (basePrompt != "") await SetBasePrompt(basePrompt);
}
/// <summary>
/// Completes once <c>started</c> is true, yielding between checks.
/// </summary>
public async Task WaitUntilReady()
{
    while (true)
    {
        if (started) return;
        await Task.Yield();
    }
}
/// <summary>
/// Completes once the model setup has finished, yielding between checks.
/// </summary>
/// <param name="downloadProgressCallback">optional callback added to LLMManager.downloadProgressCallbacks</param>
/// <returns>true when the model setup did not fail</returns>
public static async Task<bool> WaitUntilModelSetup(Callback<float> downloadProgressCallback = null)
{
    if (downloadProgressCallback != null)
    {
        LLMManager.downloadProgressCallbacks.Add(downloadProgressCallback);
    }

    while (!modelSetupComplete)
    {
        await Task.Yield();
    }

    bool succeeded = !modelSetupFailed;
    return succeeded;
}
/// <summary>
/// Resolves a model/LoRA path. In the editor, outside play mode, the editor
/// resolution is used; otherwise the runtime resolution is used.
/// </summary>
/// <param name="path">path or name of the asset</param>
/// <returns>the resolved path</returns>
public static string GetLLMManagerAsset(string path)
{
    // EditorApplication only exists in the editor assembly; guard it so player builds compile
#if UNITY_EDITOR
    if (!EditorApplication.isPlaying) return GetLLMManagerAssetEditor(path);
#endif
    return GetLLMManagerAssetRuntime(path);
}
/// <summary>
/// Editor-side resolution of a model/LoRA path.
/// Returns, in order of preference: the LLMManager entry filename, the path
/// relative to the asset base path, or the input path unchanged (with an error
/// if the file is missing or a warning if it lies outside the asset base path).
/// </summary>
/// <param name="path">path or name of the asset</param>
/// <returns>the resolved path</returns>
public static string GetLLMManagerAssetEditor(string path)
{
    // nothing to resolve
    if (string.IsNullOrEmpty(path)) return path;

    // known to the LLMManager: use the filename stored in its entry
    ModelEntry managedEntry = LLMManager.Get(path);
    if (managedEntry != null) return managedEntry.filename;

    // note: GetAssetPath returns the full path if a full path is passed
    string resolvedPath = LLMUnitySetup.GetAssetPath(path);
    string assetRoot = LLMUnitySetup.GetAssetPath();
    if (File.Exists(resolvedPath) && LLMUnitySetup.IsSubPath(resolvedPath, assetRoot))
    {
        return LLMUnitySetup.RelativePath(resolvedPath, assetRoot);
    }

    // neither managed nor inside the asset base path: report and hand back the input
    if (File.Exists(resolvedPath))
    {
        string warning = $"The model {path} was loaded locally. You can include it in the build in one of these ways:";
        warning += $"\n-Copy the model inside the StreamingAssets folder and use its StreamingAssets path";
        warning += $"\n-Load the model with the model manager inside the LLM GameObject and use its filename";
        LLMUnitySetup.LogWarning(warning);
    }
    else
    {
        LLMUnitySetup.LogError($"Model {path} was not found.");
    }
    return path;
}
/// <summary>
/// Runtime resolution of a model/LoRA path: the LLMManager location if that file
/// exists, else the LLMUnitySetup asset location if that file exists, else the
/// input path unchanged.
/// </summary>
/// <param name="path">path or name of the asset</param>
/// <returns>the resolved path</returns>
public static string GetLLMManagerAssetRuntime(string path)
{
    if (string.IsNullOrEmpty(path)) return path;

    string fromManager = LLMManager.GetAssetPath(path);
    bool managerHit = !string.IsNullOrEmpty(fromManager) && File.Exists(fromManager);
    if (managerHit) return fromManager;

    string fromAssets = LLMUnitySetup.GetAssetPath(path);
    return File.Exists(fromAssets) ? fromAssets : path;
}
/// <summary>
/// Sets the model to use. Resolves the path, applies the chat template of the
/// model entry, clamps <c>contextSize</c> to the context length of the model and
/// warns when the default context (0) would be very large.
/// In the editor, outside play mode, the component is marked dirty.
/// </summary>
/// <param name="path">path or name of the model</param>
public void SetModel(string path)
{
    model = GetLLMManagerAsset(path);
    if (!string.IsNullOrEmpty(model))
    {
        ModelEntry modelEntry = LLMManager.Get(model);
        // not managed by the LLMManager: build an entry straight from the file
        if (modelEntry == null) modelEntry = new ModelEntry(GetLLMManagerAssetRuntime(model));
        SetTemplate(modelEntry.chatTemplate);

        maxContextLength = modelEntry.contextLength;
        if (contextSize > maxContextLength) contextSize = maxContextLength;
        if (contextSize == 0 && modelEntry.contextLength > 32768)
        {
            LLMUnitySetup.LogWarning($"The model {path} has very large context size ({modelEntry.contextLength}), consider setting it to a smaller value (<=32768) to avoid filling up the RAM");
        }
    }
    // editor-only API; guard it so player builds compile
#if UNITY_EDITOR
    if (!EditorApplication.isPlaying) EditorUtility.SetDirty(this);
#endif
}
/// <summary>
/// Replaces all configured LoRAs with a single one.
/// Throws if the LLM has already started.
/// </summary>
/// <param name="path">path of the LoRA</param>
/// <param name="weight">weight of the LoRA</param>
public void SetLora(string path, float weight = 1)
{
    AssertNotStarted();
    loraManager.Clear();
    AddLora(path: path, weight: weight);
}
/// <summary>
/// Adds a LoRA to the LoRA manager and refreshes the serialized LoRA strings.
/// Throws if the LLM has already started.
/// </summary>
/// <param name="path">path of the LoRA</param>
/// <param name="weight">weight of the LoRA</param>
public void AddLora(string path, float weight = 1)
{
    AssertNotStarted();
    loraManager.Add(path: path, weight: weight);
    UpdateLoras();
}
/// <summary>
/// Removes a LoRA from the LoRA manager and refreshes the serialized LoRA strings.
/// Throws if the LLM has already started.
/// </summary>
/// <param name="path">path of the LoRA to remove</param>
public void RemoveLora(string path)
{
    AssertNotStarted();

    loraManager.Remove(path);
    UpdateLoras();
}
/// <summary>
/// Removes every LoRA from the LoRA manager and refreshes the serialized LoRA strings.
/// Throws if the LLM has already started.
/// </summary>
public void RemoveLoras()
{
    AssertNotStarted();

    loraManager.Clear();
    UpdateLoras();
}
/// <summary>
/// Changes the weight of one LoRA, refreshes the serialized LoRA strings and,
/// when the LLM is running, pushes the weights to the native library.
/// </summary>
/// <param name="path">path of the LoRA</param>
/// <param name="weight">new weight</param>
public void SetLoraWeight(string path, float weight)
{
    loraManager.SetWeight(path, weight);
    UpdateLoras();
    if (!started) return;
    ApplyLoras();
}
/// <summary>
/// Changes the weights of several LoRAs, refreshes the serialized LoRA strings
/// and, when the LLM is running, pushes the weights to the native library.
/// </summary>
/// <param name="loraToWeight">map from LoRA path to its new weight</param>
public void SetLoraWeights(Dictionary<string, float> loraToWeight)
{
    foreach (var pair in loraToWeight)
    {
        loraManager.SetWeight(pair.Key, pair.Value);
    }
    UpdateLoras();
    if (!started) return;
    ApplyLoras();
}
/// <summary>
/// Copies the state of the LoRA manager into the serialized <c>lora</c> and
/// <c>loraWeights</c> strings and records them as the last seen values, so that
/// OnValidate does not re-parse them.
/// In the editor, outside play mode, the component is marked dirty.
/// </summary>
public void UpdateLoras()
{
    (lora, loraWeights) = loraManager.ToStrings();
    (loraPre, loraWeightsPre) = (lora, loraWeights);
    // editor-only API; guard it so player builds compile
#if UNITY_EDITOR
    if (!EditorApplication.isPlaying) EditorUtility.SetDirty(this);
#endif
}
/// <summary>
/// Sets the chat template and, when the LLM is running, forwards it to the
/// native library.
/// </summary>
/// <param name="templateName">name of the chat template</param>
/// <param name="setDirty">when true, the component is marked dirty in the editor outside play mode</param>
public void SetTemplate(string templateName, bool setDirty = true)
{
    chatTemplate = templateName;
    if (started) llmlib?.LLM_SetTemplate(LLMObject, chatTemplate);
    // editor-only API; guard it so player builds compile
#if UNITY_EDITOR
    if (setDirty && !EditorApplication.isPlaying) EditorUtility.SetDirty(this);
#endif
}
/// <summary>
/// Reads a text file.
/// </summary>
/// <param name="path">path of the file</param>
/// <returns>
/// the file contents; an empty string when the path is null/empty or the file
/// does not exist (the latter is also logged as an error)
/// </returns>
string ReadFileContents(string path)
{
    if (String.IsNullOrEmpty(path)) return "";
    if (File.Exists(path)) return File.ReadAllText(path);

    LLMUnitySetup.LogError($"File {path} not found!");
    return "";
}
/// <summary>
/// Stores the SSL certificate path and loads the certificate contents from it.
/// </summary>
/// <param name="path">path of the SSL certificate file</param>
public void SetSSLCert(string path)
{
    string contents = ReadFileContents(path: SSLCertPath = path);
    SSLCert = contents;
}
/// <summary>
/// Stores the SSL key path and loads the key contents from it.
/// </summary>
/// <param name="path">path of the SSL key file</param>
public void SetSSLKey(string path)
{
    string contents = ReadFileContents(path: SSLKeyPath = path);
    SSLKey = contents;
}
/// <summary>
/// Returns the chat template currently set on the LLM.
/// </summary>
/// <returns>chat template name</returns>
public string GetTemplate() => chatTemplate;
/// <summary>
/// Builds the argument string handed to the native server and logs the
/// equivalent command line.
/// </summary>
/// <returns>
/// the argument string, or null (after logging an error) when only one of the
/// SSL certificate/key is set, no model is set, or the model or a LoRA file is missing
/// </returns>
protected virtual string GetLlamaccpArguments()
{
    // SSL needs the certificate and the key together
    bool hasCert = SSLCert != "";
    bool hasKey = SSLKey != "";
    if (hasCert != hasKey)
    {
        LLMUnitySetup.LogError($"Both SSL certificate and key need to be provided!");
        return null;
    }

    if (model == "")
    {
        LLMUnitySetup.LogError("No model file provided!");
        return null;
    }
    string modelPath = GetLLMManagerAssetRuntime(model);
    if (!File.Exists(modelPath))
    {
        LLMUnitySetup.LogError($"File {modelPath} not found!");
        return null;
    }

    // rebuild the LoRA manager from the serialized strings, then resolve every LoRA file
    loraManager.FromStrings(lora, loraWeights);
    string loraArgument = "";
    foreach (string loraName in loraManager.GetLoras())
    {
        string loraPath = GetLLMManagerAssetRuntime(loraName);
        if (!File.Exists(loraPath))
        {
            LLMUnitySetup.LogError($"File {loraPath} not found!");
            return null;
        }
        loraArgument += $" --lora \"{loraPath}\"";
    }

    int numThreadsToUse = numThreads;
    if (Application.platform == RuntimePlatform.Android && numThreads <= 0) numThreadsToUse = LLMUnitySetup.AndroidGetNumBigCores();

    int slots = GetNumClients();
    string arguments = $"-m \"{modelPath}\" -c {contextSize} -b {batchSize} --log-disable -np {slots}";
    if (remote)
    {
        // --host takes a value; without one the next argument would be consumed as the host.
        // 0.0.0.0 listens on all interfaces so that remote clients can connect.
        arguments += $" --port {port} --host 0.0.0.0";
        if (!String.IsNullOrEmpty(APIKey)) arguments += $" --api-key {APIKey}";
    }
    if (numThreadsToUse > 0) arguments += $" -t {numThreadsToUse}";
    arguments += loraArgument;
    arguments += $" -ngl {numGPULayers}";
    if (LLMUnitySetup.FullLlamaLib && flashAttention) arguments += $" --flash-attn";

    // the following is the equivalent for running from command line (logged only)
    string serverCommand;
    if (Application.platform == RuntimePlatform.WindowsEditor || Application.platform == RuntimePlatform.WindowsPlayer) serverCommand = "undreamai_server.exe";
    else serverCommand = "./undreamai_server";
    serverCommand += " " + arguments;
    serverCommand += $" --template \"{chatTemplate}\"";
    if (remote && hasCert && hasKey) serverCommand += $" --ssl-cert-file {SSLCertPath} --ssl-key-file {SSLKeyPath}";
    LLMUnitySetup.Log($"Deploy server command: {serverCommand}");
    return arguments;
}
/// <summary>
/// Creates the log stream wrapper (forwarding to LLMUnitySetup.LogWarning and
/// cleared on update) and registers it with the native library, if loaded.
/// </summary>
private void SetupLogging()
{
    StreamWrapper wrapper = ConstructStreamWrapper(streamCallback: LLMUnitySetup.LogWarning, clearOnUpdate: true);
    logStreamWrapper = wrapper;
    llmlib?.Logging(wrapper.GetStringWrapper());
}
/// <summary>
/// Stops native logging and destroys the log stream wrapper.
/// Does nothing when no log stream wrapper was created.
/// </summary>
private void StopLogging()
{
    if (logStreamWrapper != null)
    {
        llmlib?.StopLogging();
        DestroyStreamWrapper(logStreamWrapper);
    }
}
/// <summary>
/// Tries each architecture returned by LLMLib.PossibleArchitectures in turn
/// until the library and the service initialise, then starts the service.
/// Sets <c>failed</c> when no library is loaded after the loop.
/// </summary>
/// <param name="arguments">argument string passed to the native service</param>
private void StartLLMServer(string arguments)
{
    started = false;
    failed = false;
    // GPU architectures are requested only when layers are offloaded
    bool useGPU = numGPULayers > 0;

    foreach (string arch in LLMLib.PossibleArchitectures(useGPU))
    {
        string error;
        try
        {
            InitLib(arch);
            InitService(arguments);
            LLMUnitySetup.Log($"Using architecture: {arch}");
            // first architecture that initialises wins
            break;
        }
        catch (LLMException e)
        {
            error = e.Message;
            // the native side reported an error: tear everything down before the next attempt
            Destroy();
        }
        catch (DestroyException)
        {
            // thrown by CallWithLock when llmlib is null; stop trying further architectures
            break;
        }
        catch (Exception e)
        {
            // any other failure is only logged; note that Destroy() is not called here
            error = $"{e.GetType()}: {e.Message}";
        }
        LLMUnitySetup.Log($"Tried architecture: {arch}, " + error);
    }
    if (llmlib == null)
    {
        LLMUnitySetup.LogError("LLM service couldn't be created");
        failed = true;
        return;
    }
    CallWithLock(StartService);
    LLMUnitySetup.Log("LLM service created");
}
/// <summary>
/// Loads the native library for the given architecture and checks its status
/// without logging (a positive status throws an LLMException).
/// </summary>
/// <param name="arch">architecture of the library to load</param>
private void InitLib(string arch)
{
    LLMLib library = new LLMLib(arch);
    llmlib = library;
    CheckLLMStatus(log: false);
}
/// <summary>
/// Runs a callback while holding <c>startLock</c>.
/// Throws a DestroyException instead when the native library is not loaded.
/// </summary>
/// <param name="fn">callback to run</param>
void CallWithLock(EmptyCallback fn)
{
    lock (startLock)
    {
        if (llmlib != null)
        {
            fn();
            return;
        }
        throw new DestroyException();
    }
}
/// <summary>
/// Constructs the native LLM object, sets its chat template and, for remote
/// use, configures SSL and starts the server. Each native call runs through
/// CallWithLock, so a DestroyException is thrown if the library was unloaded
/// in between. Finishes with a status check that throws on error.
/// </summary>
/// <param name="arguments">argument string passed to LLM_Construct</param>
private void InitService(string arguments)
{
    // staticLock is shared by all LLM instances and is also taken by Destroy
    lock (staticLock)
    {
        if (debug) CallWithLock(SetupLogging);
        CallWithLock(() => { LLMObject = llmlib.LLM_Construct(arguments); });
        CallWithLock(() => llmlib.LLM_SetTemplate(LLMObject, chatTemplate));
        if (remote)
        {
            // SSL is only applied when both the certificate and the key contents are present
            if (SSLCert != "" && SSLKey != "")
            {
                LLMUnitySetup.Log("Using SSL");
                CallWithLock(() => llmlib.LLM_SetSSL(LLMObject, SSLCert, SSLKey));
            }
            CallWithLock(() => llmlib.LLM_StartServer(LLMObject));
        }
        // status check without logging; a positive status throws an LLMException
        CallWithLock(() => CheckLLMStatus(false));
    }
}
/// <summary>
/// Runs LLM_Start on a dedicated thread, waits until the native side reports
/// that it has started, applies the LoRA weights and sets <c>started</c>.
/// </summary>
private void StartService()
{
    llmThread = new Thread(() => llmlib.LLM_Start(LLMObject));
    llmThread.Start();
    // busy-wait (no sleep or yield) until LLM_Started returns true
    while (!llmlib.LLM_Started(LLMObject)) {}
    ApplyLoras();
    started = true;
}
/// <summary>
/// Registers an LLMCharacter with the LLM and assigns it a slot.
/// </summary>
/// <param name="llmCharacter">character to register</param>
/// <returns>
/// the index of the character in the client list when <c>parallelPrompts</c> is -1,
/// otherwise that index wrapped into the number of slots
/// </returns>
public int Register(LLMCharacter llmCharacter)
{
    clients.Add(llmCharacter);
    int index = clients.IndexOf(llmCharacter);
    if (parallelPrompts == -1) return index;

    // Wrap with the same clamp GetNumClients applies to the slot count:
    // a value of 0 would divide by zero, and values below 1 would yield
    // slot ids outside the single slot that is actually created.
    return index % Math.Max(parallelPrompts, 1);
}
/// <summary>
/// Returns the number of slots to create: <c>parallelPrompts</c>, or the number
/// of registered clients when it is -1, but never less than 1.
/// </summary>
/// <returns>number of slots</returns>
protected int GetNumClients()
{
    int requested = parallelPrompts;
    if (requested == -1) requested = clients.Count;
    return requested < 1 ? 1 : requested;
}
/// <summary> Callback taking the native LLM handle and a string wrapper handle. </summary>
public delegate void LLMStatusCallback(IntPtr LLMObject, IntPtr stringWrapper);
/// <summary> Callback for native calls without input; used by LLMNoInputReply, which reads the result from the string wrapper. </summary>
public delegate void LLMNoInputReplyCallback(IntPtr LLMObject, IntPtr stringWrapper);
/// <summary> Callback for native calls with a JSON input; used by LLMReply, which reads the result from the string wrapper. </summary>
public delegate void LLMReplyCallback(IntPtr LLMObject, string json_data, IntPtr stringWrapper);
/// <summary>
/// Creates a stream wrapper bound to the native library and adds it to the
/// list of wrappers that Update refreshes.
/// </summary>
/// <param name="streamCallback">callback passed to the stream wrapper</param>
/// <param name="clearOnUpdate">flag passed to the stream wrapper</param>
/// <returns>the new stream wrapper</returns>
StreamWrapper ConstructStreamWrapper(Callback<string> streamCallback = null, bool clearOnUpdate = false)
{
    var created = new StreamWrapper(llmlib, streamCallback, clearOnUpdate);
    streamWrappers.Add(created);
    return created;
}
/// <summary>
/// Removes a stream wrapper from the list refreshed by Update, then destroys it.
/// </summary>
/// <param name="streamWrapper">stream wrapper to destroy</param>
void DestroyStreamWrapper(StreamWrapper streamWrapper)
{
    // unregister first so Update no longer touches the wrapper
    streamWrappers.Remove(item: streamWrapper);
    streamWrapper.Destroy();
}
/// <summary>
/// Unity Update callback: refreshes every registered stream wrapper.
/// </summary>
public void Update()
{
    foreach (var wrapper in streamWrappers)
    {
        wrapper.Update();
    }
}
/// <summary>
/// Logs an error and throws when the LLM service failed to be created or has
/// not started yet; returns normally otherwise.
/// </summary>
void AssertStarted()
{
    string error;
    if (failed) error = "LLM service couldn't be created";
    else if (!started) error = "LLM service not started";
    else return;

    LLMUnitySetup.LogError(error);
    throw new Exception(error);
}
/// <summary>
/// Logs an error and throws when the LLM has already started; returns normally otherwise.
/// </summary>
void AssertNotStarted()
{
    if (!started) return;

    const string error = "This method can't be called when the LLM has started";
    LLMUnitySetup.LogError(error);
    throw new Exception(error);
}
/// <summary>
/// Queries the status of the native LLM. A positive status throws an
/// LLMException; a negative status is treated as a warning; zero is fine.
/// Does nothing when the native library is not loaded.
/// </summary>
/// <param name="log">when true, errors and warnings are also logged</param>
void CheckLLMStatus(bool log = true)
{
    if (llmlib == null) return;

    IntPtr wrapper = llmlib.StringWrapper_Construct();
    int status = llmlib.LLM_Status(LLMObject, wrapper);
    string result = llmlib.GetStringWrapperResult(wrapper);
    llmlib.StringWrapper_Delete(wrapper);

    if (status == 0) return;

    string message = $"LLM {status}: {result}";
    if (status < 0)
    {
        if (log) LLMUnitySetup.LogWarning(message);
        return;
    }

    if (log) LLMUnitySetup.LogError(message);
    throw new LLMException(message, status);
}
/// <summary>
/// Runs a native call that takes no input on a worker thread and returns the
/// string it wrote into a temporary string wrapper. Checks the LLM status afterwards.
/// </summary>
/// <param name="callback">native call to run</param>
/// <returns>the result string, or null when the library was unloaded meanwhile</returns>
async Task<string> LLMNoInputReply(LLMNoInputReplyCallback callback)
{
    AssertStarted();

    IntPtr wrapper = llmlib.StringWrapper_Construct();
    await Task.Run(() => callback(LLMObject, wrapper));

    // llmlib may have been destroyed while the call was running
    string reply = llmlib?.GetStringWrapperResult(wrapper);
    llmlib?.StringWrapper_Delete(wrapper);
    CheckLLMStatus();
    return reply;
}
/// <summary>
/// Runs a native call that takes a JSON input on a worker thread and returns
/// the string it wrote into a temporary string wrapper. Checks the LLM status afterwards.
/// </summary>
/// <param name="callback">native call to run</param>
/// <param name="json">JSON input passed to the call</param>
/// <returns>the result string, or null when the library was unloaded meanwhile</returns>
async Task<string> LLMReply(LLMReplyCallback callback, string json)
{
    AssertStarted();

    IntPtr wrapper = llmlib.StringWrapper_Construct();
    await Task.Run(() => callback(LLMObject, json, wrapper));

    // llmlib may have been destroyed while the call was running
    string reply = llmlib?.GetStringWrapperResult(wrapper);
    llmlib?.StringWrapper_Delete(wrapper);
    CheckLLMStatus();
    return reply;
}
/// <summary>
/// Sends a tokenize request to the native library.
/// </summary>
/// <param name="json">JSON request</param>
/// <returns>the reply string of the native call</returns>
public async Task<string> Tokenize(string json)
{
    AssertStarted();
    return await LLMReply(
        (llmObject, payload, strWrapper) => llmlib.LLM_Tokenize(llmObject, payload, strWrapper),
        json);
}
/// <summary>
/// Sends a detokenize request to the native library.
/// </summary>
/// <param name="json">JSON request</param>
/// <returns>the reply string of the native call</returns>
public async Task<string> Detokenize(string json)
{
    AssertStarted();
    return await LLMReply(
        (llmObject, payload, strWrapper) => llmlib.LLM_Detokenize(llmObject, payload, strWrapper),
        json);
}
/// <summary>
/// Sends an embeddings request to the native library.
/// </summary>
/// <param name="json">JSON request</param>
/// <returns>the reply string of the native call</returns>
public async Task<string> Embeddings(string json)
{
    AssertStarted();
    return await LLMReply(
        (llmObject, payload, strWrapper) => llmlib.LLM_Embeddings(llmObject, payload, strWrapper),
        json);
}
/// <summary>
/// Sends the weights held by the LoRA manager to the native library, using the
/// position of each weight as its LoRA id.
/// </summary>
public void ApplyLoras()
{
    var request = new LoraWeightRequestList { loraWeights = new List<LoraWeightRequest>() };
    float[] scales = loraManager.GetWeights();
    int nextId = 0;
    foreach (float scale in scales)
    {
        request.loraWeights.Add(new LoraWeightRequest() { id = nextId, scale = scale });
        nextId++;
    }

    // keep only the JSON array: everything from the first "[" to the last "]"
    string serialized = JsonUtility.ToJson(request);
    int arrayStart = serialized.IndexOf("[");
    int arrayEnd = serialized.LastIndexOf("]") + 1;
    string payload = serialized.Substring(arrayStart, arrayEnd - arrayStart);

    IntPtr wrapper = llmlib.StringWrapper_Construct();
    llmlib.LLM_Lora_Weight(LLMObject, payload, wrapper);
    llmlib.StringWrapper_Delete(wrapper);
}
/// <summary>
/// Asks the native library for its list of LoRAs.
/// </summary>
/// <returns>the parsed list of LoRA results, or null when the reply is null or empty</returns>
public async Task<List<LoraWeightResult>> ListLoras()
{
    AssertStarted();
    string reply = await LLMNoInputReply(
        (llmObject, strWrapper) => llmlib.LLM_LoraList(llmObject, strWrapper));
    if (String.IsNullOrEmpty(reply)) return null;

    // wrap the reply in an object so JsonUtility can deserialize it
    string wrapped = "{\"loraWeights\": " + reply + "}";
    LoraWeightResultList parsed = JsonUtility.FromJson<LoraWeightResultList>(wrapped);
    return parsed.loraWeights;
}
/// <summary>
/// Sends a slot request to the native library.
/// </summary>
/// <param name="json">JSON request</param>
/// <returns>the reply string of the native call</returns>
public async Task<string> Slot(string json)
{
    AssertStarted();
    return await LLMReply(
        (llmObject, payload, strWrapper) => llmlib.LLM_Slot(llmObject, payload, strWrapper),
        json);
}
/// <summary>
/// Sends a completion request to the native library, streaming through a
/// temporary stream wrapper.
/// </summary>
/// <param name="json">JSON request</param>
/// <param name="streamCallback">callback passed to the stream wrapper; a no-op is used when null</param>
/// <returns>the final string of the stream wrapper, or null when the LLM is no longer started after the call</returns>
public async Task<string> Completion(string json, Callback<string> streamCallback = null)
{
    AssertStarted();

    Callback<string> onStream = streamCallback ?? ((string s) => {});
    StreamWrapper wrapper = ConstructStreamWrapper(onStream);
    await Task.Run(() => llmlib.LLM_Completion(LLMObject, json, wrapper.GetStringWrapper()));

    if (started)
    {
        // refresh once more before reading the final string
        wrapper.Update();
        string reply = wrapper.GetString();
        DestroyStreamWrapper(wrapper);
        CheckLLMStatus();
        return reply;
    }
    // stopped while the request was running: the wrapper is not destroyed here
    return null;
}
/// <summary>
/// Submits the base prompt as the system prompt through a completion request
/// that predicts no tokens.
/// </summary>
/// <param name="base_prompt">base prompt to set</param>
public async Task SetBasePrompt(string base_prompt)
{
    AssertStarted();

    var request = new SystemPromptRequest() { system_prompt = base_prompt, prompt = " ", n_predict = 0 };
    string payload = JsonUtility.ToJson(request);
    await Completion(payload);
}
/// <summary>
/// Asks the native library to cancel the request running in a slot, then
/// checks the LLM status.
/// </summary>
/// <param name="id_slot">id of the slot to cancel</param>
public void CancelRequest(int id_slot)
{
    AssertStarted();

    if (llmlib != null)
    {
        llmlib.LLM_Cancel(LLMObject, id_slot);
    }
    CheckLLMStatus();
}
/// <summary>
/// Stops and deletes the native LLM object, unloads the native library and
/// resets <c>started</c> and <c>failed</c>. Any exception raised during teardown
/// is logged and not rethrown.
/// </summary>
public void Destroy()
{
    // both locks are taken, staticLock first
    lock (staticLock)
        lock (startLock)
        {
            try
            {
                if (llmlib != null)
                {
                    if (LLMObject != IntPtr.Zero)
                    {
                        llmlib.LLM_Stop(LLMObject);
                        if (remote) llmlib.LLM_StopServer(LLMObject);
                        StopLogging();
                        // wait for the thread running LLM_Start before deleting the object
                        llmThread?.Join();
                        llmlib.LLM_Delete(LLMObject);
                        LLMObject = IntPtr.Zero;
                    }
                    llmlib.Destroy();
                    // a null llmlib makes CallWithLock throw a DestroyException from now on
                    llmlib = null;
                }
                started = false;
                failed = false;
            }
            catch (Exception e)
            {
                LLMUnitySetup.LogError(e.Message);
            }
        }
}
/// <summary>
/// Unity OnDestroy callback: tears down the LLM service, then unregisters
/// this instance from the LLMManager.
/// </summary>
public void OnDestroy()
{
    this.Destroy();
    LLMManager.Unregister(llm: this);
}
814 }
