public class LLM : MonoBehaviour
{
    // expose the LLM service as a server that remote clients can reach
    [LocalRemote] public bool remote = false;
    // port of the LLM server
    [Remote] public int port = 13333;
    // size of the prompt context (0 = use the model's context size)
    [DynamicRange("minContextLength", "maxContextLength", false), Model] public int contextSize = 8192;
    // initial prompt provided to the clients of this LLM
    [TextArea(5, 10), ChatAdvanced] public string basePrompt = "";
    // true once the LLM service has started and can accept requests
    public bool started { get; protected set; } = false;
    // true if the LLM service failed to start
    public bool failed { get; protected set; } = false;
    // path of the model (.gguf file) to use
    [ModelAdvanced] public string model = "";
    // path(s) of the LoRA adapter(s) to use
    [ModelAdvanced] public string lora = "";
    // SSL certificate and key (contents and paths) for serving over HTTPS
    private string SSLCert = "";
    public string SSLCertPath = "";
    private string SSLKey = "";
    public string SSLKeyPath = "";
    // context-length bounds reported by the loaded model
    public int minContextLength = 0;
    public int maxContextLength = 0;

    // handle to the native LLM service
    IntPtr LLMObject = IntPtr.Zero;
    // LLMCharacter clients served by this LLM
    List<LLMCharacter> clients = new List<LLMCharacter>();
    StreamWrapper logStreamWrapper = null;
    Thread llmThread = null;
    List<StreamWrapper> streamWrappers = new List<StreamWrapper>();
    private readonly object startLock = new object();
    static readonly object staticLock = new object();
    public LoraManager loraManager = new LoraManager();
    string loraWeightsPre = "";
        // skip startup if the component is disabled
        if (!enabled) return;

        string arguments = GetLlamaccpArguments();
        if (arguments == null) return;

        // start the server off the main thread to avoid blocking Unity
        await Task.Run(() => StartLLMServer(arguments));
    // awaitable that completes once the LLM service has started
    public async Task WaitUntilReady()
    {
        while (!started) await Task.Yield();
    }

    // awaitable that completes once the model files are set up, optionally reporting download progress
    public static async Task<bool> WaitUntilModelSetup(Callback<float> downloadProgressCallback = null)
    {
        if (downloadProgressCallback != null) LLMManager.downloadProgressCallbacks.Add(downloadProgressCallback);
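
    // Example (illustrative, not part of this file): code elsewhere can block until
    // the model files are in place and the server is up, e.g. with an `llm` reference
    // to this component:
    //   await LLM.WaitUntilModelSetup(progress => Debug.Log($"Download progress: {progress}"));
    //   await llm.WaitUntilReady();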
    public static string GetLLMManagerAsset(string path)
    {
        // in the editor, resolve through the LLMManager / asset database;
        // at runtime, resolve to an actual file on disk
        if (!EditorApplication.isPlaying) return GetLLMManagerAssetEditor(path);
        return GetLLMManagerAssetRuntime(path);
    }

    public static string GetLLMManagerAssetEditor(string path)
    {
        if (string.IsNullOrEmpty(path)) return path;
        ModelEntry modelEntry = LLMManager.Get(path);
        if (modelEntry != null) return modelEntry.filename;

        string assetPath = LLMUnitySetup.GetAssetPath(path);
        string basePath = LLMUnitySetup.GetAssetPath();
        if (File.Exists(assetPath))
        {
            if (LLMUnitySetup.IsSubPath(assetPath, basePath)) return LLMUnitySetup.RelativePath(assetPath, basePath);
        }
        if (!File.Exists(assetPath))
        {
            LLMUnitySetup.LogError($"Model {path} was not found.");
        }

        string errorMessage = $"The model {path} was loaded locally. You can include it in the build in one of these ways:";
        errorMessage += "\n-Copy the model inside the StreamingAssets folder and use its StreamingAssets path";
        errorMessage += "\n-Load the model with the model manager inside the LLM GameObject and use its filename";
        LLMUnitySetup.LogWarning(errorMessage);
    public static string GetLLMManagerAssetRuntime(string path)
    {
        if (string.IsNullOrEmpty(path)) return path;
        // prefer the file managed by the LLMManager, then fall back to the asset path
        string managerPath = LLMManager.GetAssetPath(path);
        if (!string.IsNullOrEmpty(managerPath) && File.Exists(managerPath)) return managerPath;
        string assetPath = LLMUnitySetup.GetAssetPath(path);
        if (File.Exists(assetPath)) return assetPath;
        model = GetLLMManagerAsset(path);
        if (!string.IsNullOrEmpty(model))
        {
            ModelEntry modelEntry = LLMManager.Get(model);
            if (modelEntry == null) modelEntry = new ModelEntry(GetLLMManagerAssetRuntime(model));
            maxContextLength = modelEntry.contextLength;
            if (contextSize == 0 && modelEntry.contextLength > 32768)
            {
                LLMUnitySetup.LogWarning($"The model {path} has a very large context size ({modelEntry.contextLength}); consider setting it to a smaller value (<=32768) to avoid filling up the RAM.");
            }
        }
        if (!EditorApplication.isPlaying) EditorUtility.SetDirty(this);
    // use a single LoRA adapter, replacing any previously set
    public void SetLora(string path, float weight = 1)

    // add a LoRA adapter to those in use
    public void AddLora(string path, float weight = 1)
    {
        loraManager.Add(path, weight);

        loraManager.Remove(path);

        loraManager.SetWeight(path, weight);

        foreach (KeyValuePair<string, float> entry in loraToWeight) loraManager.SetWeight(entry.Key, entry.Value);

    // re-apply the current LoRA configuration to the service
    public void UpdateLoras()

        if (!EditorApplication.isPlaying) EditorUtility.SetDirty(this);
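
    // Example (illustrative; "style.gguf" and "facts.gguf" are hypothetical adapter
    // files, llm a reference to this component):
    //   llm.SetLora("style.gguf", 0.8f);  // use only this adapter, at weight 0.8
    //   llm.AddLora("facts.gguf");        // stack a second adapter at the default weight 1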
    public void SetTemplate(string templateName, bool setDirty = true)
    {
        if (setDirty && !EditorApplication.isPlaying) EditorUtility.SetDirty(this);

    // read a file into a string, returning "" for an empty or missing path
    string ReadFileContents(string path)
    {
        if (String.IsNullOrEmpty(path)) return "";
        else if (!File.Exists(path))
        {
            LLMUnitySetup.LogError($"File {path} not found!");
            return "";
        }
        return File.ReadAllText(path);
    }

    // the SSL certificate and key contents are loaded from their paths:
        SSLCert = ReadFileContents(path);

        SSLKey = ReadFileContents(path);
    protected virtual string GetLlamaccpArguments()
    {
        // SSL requires both the certificate and the key
        if ((SSLCert != "" && SSLKey == "") || (SSLCert == "" && SSLKey != ""))
        {
            LLMUnitySetup.LogError("Both SSL certificate and key need to be provided!");

        LLMUnitySetup.LogError("No model file provided!");

        string modelPath = GetLLMManagerAssetRuntime(model);
        if (!File.Exists(modelPath))
        {
            LLMUnitySetup.LogError($"File {modelPath} not found!");

        // build one --lora argument per adapter
        string loraArgument = "";
        foreach (string lora in loraManager.GetLoras())
        {
            string loraPath = GetLLMManagerAssetRuntime(lora);
            if (!File.Exists(loraPath))
            {
                LLMUnitySetup.LogError($"File {loraPath} not found!");

            loraArgument += $" --lora \"{loraPath}\"";

        // on Android, default the thread count to the number of big cores
        if (Application.platform == RuntimePlatform.Android && numThreads <= 0) numThreadsToUse = LLMUnitySetup.AndroidGetNumBigCores();

        int slots = GetNumClients();
        string arguments = $"-m \"{modelPath}\" -c {contextSize} -b {batchSize} --log-disable -np {slots}";

        arguments += $" --port {port} --host 0.0.0.0";
        if (!String.IsNullOrEmpty(APIKey)) arguments += $" --api-key {APIKey}";

        if (numThreadsToUse > 0) arguments += $" -t {numThreadsToUse}";
        arguments += loraArgument;
        arguments += $" -ngl {numGPULayers}";
        if (LLMUnitySetup.FullLlamaLib && flashAttention) arguments += " --flash-attn";

        // the standalone server binary receives the same arguments plus the chat template
        string serverCommand;
        if (Application.platform == RuntimePlatform.WindowsEditor || Application.platform == RuntimePlatform.WindowsPlayer) serverCommand = "undreamai_server.exe";
        else serverCommand = "./undreamai_server";
        serverCommand += " " + arguments;
        serverCommand += $" --template \"{chatTemplate}\"";
        if (remote && SSLCert != "" && SSLKey != "") serverCommand += $" --ssl-cert-file {SSLCertPath} --ssl-key-file {SSLKeyPath}";
        LLMUnitySetup.Log($"Deploy server command: {serverCommand}");
    private void SetupLogging()
    {
        // forward native log output to the Unity console as warnings
        logStreamWrapper = ConstructStreamWrapper(LLMUnitySetup.LogWarning, true);
        llmlib?.Logging(logStreamWrapper.GetStringWrapper());
    }

    private void StopLogging()
    {
        if (logStreamWrapper == null) return;
        llmlib?.StopLogging();
        DestroyStreamWrapper(logStreamWrapper);
    }
    private void StartLLMServer(string arguments)
    {
        // try the available native architectures until one loads and starts
        foreach (string arch in LLMLib.PossibleArchitectures(useGPU))
        {
            try
            {
                InitService(arguments);
                LLMUnitySetup.Log($"Using architecture: {arch}");

            catch (LLMException e)

            catch (DestroyException)

            catch (Exception e)
            {
                error = $"{e.GetType()}: {e.Message}";
            }
            LLMUnitySetup.Log($"Tried architecture: {arch}, " + error);

        LLMUnitySetup.LogError("LLM service couldn't be created");

        CallWithLock(StartService);
        LLMUnitySetup.Log("LLM service created");
    private void InitLib(string arch)
    {
        llmlib = new LLMLib(arch);
        CheckLLMStatus(false);
    }

    // run fn under the start lock; fail if the library has been destroyed
    void CallWithLock(EmptyCallback fn)
    {
        if (llmlib == null) throw new DestroyException();
    private void InitService(string arguments)
    {
        if (debug) CallWithLock(SetupLogging);
        CallWithLock(() => { LLMObject = llmlib.LLM_Construct(arguments); });
        CallWithLock(() => llmlib.LLM_SetTemplate(LLMObject, chatTemplate));

        if (SSLCert != "" && SSLKey != "")
        {
            LLMUnitySetup.Log("Using SSL");
            CallWithLock(() => llmlib.LLM_SetSSL(LLMObject, SSLCert, SSLKey));
        }
        CallWithLock(() => llmlib.LLM_StartServer(LLMObject));

        CallWithLock(() => CheckLLMStatus(false));
    private void StartService()
    {
        llmThread = new Thread(() => llmlib.LLM_Start(LLMObject));
        // busy-wait until the native service reports it has started
        while (!llmlib.LLM_Started(LLMObject)) {}

    // client registration: each LLMCharacter gets its slot index
        clients.Add(llmCharacter);
        int index = clients.IndexOf(llmCharacter);

    protected int GetNumClients()
    public delegate void LLMStatusCallback(IntPtr LLMObject, IntPtr stringWrapper);
    public delegate void LLMNoInputReplyCallback(IntPtr LLMObject, IntPtr stringWrapper);
    public delegate void LLMReplyCallback(IntPtr LLMObject, string json_data, IntPtr stringWrapper);
    StreamWrapper ConstructStreamWrapper(Callback<string> streamCallback = null, bool clearOnUpdate = false)
    {
        StreamWrapper streamWrapper = new StreamWrapper(llmlib, streamCallback, clearOnUpdate);
        streamWrappers.Add(streamWrapper);
        return streamWrapper;
    }

    void DestroyStreamWrapper(StreamWrapper streamWrapper)
    {
        streamWrappers.Remove(streamWrapper);
        streamWrapper.Destroy();
    }

        foreach (StreamWrapper streamWrapper in streamWrappers) streamWrapper.Update();
        if (failed) error = "LLM service couldn't be created";
        else if (!started) error = "LLM service not started";

        throw new Exception(error);

    void AssertNotStarted()
    {
        string error = "This method can't be called when the LLM has started";
        LLMUnitySetup.LogError(error);
        throw new Exception(error);
    }

    void CheckLLMStatus(bool log = true)
    {
        if (llmlib == null) { return; }
        IntPtr stringWrapper = llmlib.StringWrapper_Construct();
        int status = llmlib.LLM_Status(LLMObject, stringWrapper);
        string result = llmlib.GetStringWrapperResult(stringWrapper);
        llmlib.StringWrapper_Delete(stringWrapper);
        string message = $"LLM {status}: {result}";

        if (log) LLMUnitySetup.LogError(message);
        throw new LLMException(message, status);

        if (log) LLMUnitySetup.LogWarning(message);
    async Task<string> LLMNoInputReply(LLMNoInputReplyCallback callback)
    {
        IntPtr stringWrapper = llmlib.StringWrapper_Construct();
        await Task.Run(() => callback(LLMObject, stringWrapper));
        string result = llmlib?.GetStringWrapperResult(stringWrapper);
        llmlib?.StringWrapper_Delete(stringWrapper);

    async Task<string> LLMReply(LLMReplyCallback callback, string json)
    {
        IntPtr stringWrapper = llmlib.StringWrapper_Construct();
        await Task.Run(() => callback(LLMObject, json, stringWrapper));
        string result = llmlib?.GetStringWrapperResult(stringWrapper);
        llmlib?.StringWrapper_Delete(stringWrapper);
        // tokenization binds the native LLM_Tokenize call and delegates to LLMReply
        LLMReplyCallback callback = (IntPtr LLMObject, string jsonData, IntPtr strWrapper) =>
        {
            llmlib.LLM_Tokenize(LLMObject, jsonData, strWrapper);
        };
        return await LLMReply(callback, json);

        // detokenization follows the same pattern with LLM_Detokenize
        LLMReplyCallback callback = (IntPtr LLMObject, string jsonData, IntPtr strWrapper) =>
        {
            llmlib.LLM_Detokenize(LLMObject, jsonData, strWrapper);
        };
        return await LLMReply(callback, json);

        // as do embeddings with LLM_Embeddings
        LLMReplyCallback callback = (IntPtr LLMObject, string jsonData, IntPtr strWrapper) =>
        {
            llmlib.LLM_Embeddings(LLMObject, jsonData, strWrapper);
        };
        return await LLMReply(callback, json);
        LoraWeightRequestList loraWeightRequest = new LoraWeightRequestList();
        loraWeightRequest.loraWeights = new List<LoraWeightRequest>();
        float[] weights = loraManager.GetWeights();
        for (int i = 0; i < weights.Length; i++)
        {
            loraWeightRequest.loraWeights.Add(new LoraWeightRequest() { id = i, scale = weights[i] });
        }

        // JsonUtility cannot serialize a bare array, so serialize the wrapper
        // object and cut out the JSON array it contains
        string json = JsonUtility.ToJson(loraWeightRequest);
        int startIndex = json.IndexOf("[");
        int endIndex = json.LastIndexOf("]") + 1;
        json = json.Substring(startIndex, endIndex - startIndex);

        IntPtr stringWrapper = llmlib.StringWrapper_Construct();
        llmlib.LLM_Lora_Weight(LLMObject, json, stringWrapper);
        llmlib.StringWrapper_Delete(stringWrapper);
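
    // For illustration: with a single adapter at weight 0.5 the wrapper serializes to
    //   {"loraWeights":[{"id":0,"scale":0.5}]}
    // and the Substring call above reduces it to the bare array the server expects:
    //   [{"id":0,"scale":0.5}]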
        LLMNoInputReplyCallback callback = (IntPtr LLMObject, IntPtr strWrapper) =>
        {
            llmlib.LLM_LoraList(LLMObject, strWrapper);
        };
        string json = await LLMNoInputReply(callback);
        if (String.IsNullOrEmpty(json)) return null;
        // wrap the bare array so that JsonUtility can deserialize it
        LoraWeightResultList loraRequest = JsonUtility.FromJson<LoraWeightResultList>("{\"loraWeights\": " + json + "}");
        return loraRequest.loraWeights;
    public async Task<string> Slot(string json)
    {
        LLMReplyCallback callback = (IntPtr LLMObject, string jsonData, IntPtr strWrapper) =>
        {
            llmlib.LLM_Slot(LLMObject, jsonData, strWrapper);
        };
        return await LLMReply(callback, json);
    public async Task<string> Completion(string json, Callback<string> streamCallback = null)
    {
        if (streamCallback == null) streamCallback = (string s) => {};
        StreamWrapper streamWrapper = ConstructStreamWrapper(streamCallback);
        await Task.Run(() => llmlib.LLM_Completion(LLMObject, json, streamWrapper.GetStringWrapper()));

        // flush any remaining streamed content before reading the full result
        streamWrapper.Update();
        string result = streamWrapper.GetString();
        DestroyStreamWrapper(streamWrapper);
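
    // Example (illustrative; the JSON follows the llama.cpp server completion format):
    //   string request = "{\"prompt\": \"Hello\", \"n_predict\": 16}";
    //   string reply = await llm.Completion(request, chunk => Debug.Log(chunk));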
    public async Task SetBasePrompt(string base_prompt)
    {
        // a zero-token completion makes the server process the base prompt without generating text
        SystemPromptRequest request = new SystemPromptRequest() { system_prompt = base_prompt, prompt = " ", n_predict = 0 };
        await Completion(JsonUtility.ToJson(request));
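
    // For illustration: for base_prompt "You are a pirate", JsonUtility would produce
    // (with field order following the SystemPromptRequest declaration) something like:
    //   {"prompt":" ","system_prompt":"You are a pirate","n_predict":0}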
        llmlib?.LLM_Cancel(LLMObject, id_slot);

        if (LLMObject != IntPtr.Zero)
        {
            llmlib.LLM_Stop(LLMObject);
            if (remote) llmlib.LLM_StopServer(LLMObject);

            llmlib.LLM_Delete(LLMObject);
            LLMObject = IntPtr.Zero;