using System;
using System.Collections.Generic;
using System.IO;
using System.Threading;
using System.Threading.Tasks;
using UnityEditor;
using UnityEngine;
[DefaultExecutionOrder(-1)]
public class LLM : MonoBehaviour
[LocalRemote]
public bool remote = false;
[Remote]
public int port = 13333;
[DynamicRange("minContextLength", "maxContextLength", false), Model]
public int contextSize = 8192;
[TextArea(5, 10), ChatAdvanced]
public string basePrompt = "";
public bool started { get; protected set; } = false;
public bool failed { get; protected set; } = false;
[ModelAdvanced]
public string model = "";
[ModelAdvanced]
public string lora = "";
private string SSLCert = "";
public string SSLCertPath = "";
private string SSLKey = "";
public string SSLKeyPath = "";
public int minContextLength = 0;
public int maxContextLength = 0;
IntPtr LLMObject = IntPtr.Zero;
List<LLMCaller> clients = new List<LLMCaller>();
StreamWrapper logStreamWrapper = null;
Thread llmThread = null;
List<StreamWrapper> streamWrappers = new List<StreamWrapper>();
private readonly object startLock = new object();
static readonly object staticLock = new object();
string loraWeightsPre = "";
public bool embeddingsOnly = false;
public int embeddingLength = 0;
// from Awake()
if (!enabled) return;
string arguments = GetLlamaccpArguments();
if (arguments == null)
await Task.Run(() => StartLLMServer(arguments));
// from WaitUntilReady()
while (!started) await Task.Yield();
// from WaitUntilModelSetup()
if (downloadProgressCallback != null) LLMManager.downloadProgressCallbacks.Add(downloadProgressCallback);
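A minimal usage sketch of the setup wait above; the 0..1 progress value and the Callback<float> delegate follow the WaitUntilModelSetup signature, the rest is illustrative:

// Sketch: block until the model downloads finish, logging progress.
bool ok = await LLM.WaitUntilModelSetup(progress => Debug.Log($"Model setup: {progress * 100f:F0}%"));
if (!ok) Debug.LogError("Model setup failed (see LLM.modelSetupFailed).");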
public static string GetLLMManagerAsset(string path)
if (!EditorApplication.isPlaying) return GetLLMManagerAssetEditor(path);
return GetLLMManagerAssetRuntime(path);
public static string GetLLMManagerAssetEditor(string path)
if (string.IsNullOrEmpty(path)) return path;
ModelEntry modelEntry = LLMManager.Get(path);
if (modelEntry != null) return modelEntry.filename;
string assetPath = LLMUnitySetup.GetAssetPath(path);
string basePath = LLMUnitySetup.GetAssetPath();
if (File.Exists(assetPath))
if (LLMUnitySetup.IsSubPath(assetPath, basePath)) return LLMUnitySetup.RelativePath(assetPath, basePath);
if (!File.Exists(assetPath))
LLMUnitySetup.LogError($"Model {path} was not found.");
string errorMessage = $"The model {path} was loaded locally. You can include it in the build in one of these ways:";
errorMessage += "\n-Copy the model inside the StreamingAssets folder and use its StreamingAssets path";
errorMessage += "\n-Load the model with the model manager inside the LLM GameObject and use its filename";
LLMUnitySetup.LogWarning(errorMessage);
public static string GetLLMManagerAssetRuntime(string path)
if (string.IsNullOrEmpty(path)) return path;
string managerPath = LLMManager.GetAssetPath(path);
if (!string.IsNullOrEmpty(managerPath) && File.Exists(managerPath)) return managerPath;
string assetPath = LLMUnitySetup.GetAssetPath(path);
if (File.Exists(assetPath)) return assetPath;
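The runtime lookup order above is: the LLMManager download location first, then the StreamingAssets path; judging by the fragments, an unresolved path falls through unchanged (e.g. an absolute path). A sketch with a hypothetical filename:

// Sketch: resolves to the downloaded copy if the manager has one, else StreamingAssets.
string resolved = LLM.GetLLMManagerAssetRuntime("my-model.Q4_K_M.gguf");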
// from SetModel(string path)
model = GetLLMManagerAsset(path);
if (!string.IsNullOrEmpty(model))
if (modelEntry == null) modelEntry = new ModelEntry(GetLLMManagerAssetRuntime(model));
maxContextLength = modelEntry.contextLength;
SetEmbeddings(modelEntry.embeddingLength, modelEntry.embeddingOnly);
if (contextSize == 0 && modelEntry.contextLength > 32768)
LLMUnitySetup.LogWarning($"The model {path} has a very large context size ({modelEntry.contextLength}), consider setting it to a smaller value (<=32768) to avoid filling up the RAM");
if (!EditorApplication.isPlaying) EditorUtility.SetDirty(this);
public void SetLora(string path, float weight = 1)
public void AddLora(string path, float weight = 1)
loraManager.Add(path, weight);
// from SetLoraWeights(Dictionary<string, float> loraToWeight)
foreach (KeyValuePair<string, float> entry in loraToWeight) loraManager.SetWeight(entry.Key, entry.Value);
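A short usage sketch for the lora setters above; the adapter paths are hypothetical and must match loras that were previously added:

// Sketch: add two adapters, then rebalance their scales at runtime.
llm.AddLora("style.gguf", 1f);
llm.AddLora("domain.gguf", 1f);
llm.SetLoraWeights(new Dictionary<string, float> { { "style.gguf", 0.8f }, { "domain.gguf", 0.2f } });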
public void UpdateLoras()
if (!EditorApplication.isPlaying) EditorUtility.SetDirty(this);
public void SetTemplate(string templateName, bool setDirty = true)
if (setDirty && !EditorApplication.isPlaying) EditorUtility.SetDirty(this);
// from SetEmbeddings(int embeddingLength, bool embeddingsOnly)
this.embeddingsOnly = embeddingsOnly;
this.embeddingLength = embeddingLength;
if (!EditorApplication.isPlaying) EditorUtility.SetDirty(this);
string ReadFileContents(string path)
if (String.IsNullOrEmpty(path)) return "";
else if (!File.Exists(path))
return File.ReadAllText(path);
// from SetSSLCert(string path)
SSLCert = ReadFileContents(path);
// from SetSSLKey(string path)
SSLKey = ReadFileContents(path);
protected virtual string GetLlamaccpArguments()
if ((SSLCert != "" && SSLKey == "") || (SSLCert == "" && SSLKey != ""))
LLMUnitySetup.LogError("Both SSL certificate and key need to be provided!");
LLMUnitySetup.LogError("No model file provided!");
string modelPath = GetLLMManagerAssetRuntime(model);
if (!File.Exists(modelPath))
LLMUnitySetup.LogError($"File {modelPath} not found!");
string loraArgument = "";
string loraPath = GetLLMManagerAssetRuntime(lora);
if (!File.Exists(loraPath))
LLMUnitySetup.LogError($"File {loraPath} not found!");
loraArgument += $" --lora \"{loraPath}\"";
if (Application.platform == RuntimePlatform.Android && numThreads <= 0) numThreadsToUse = LLMUnitySetup.AndroidGetNumBigCores();
int slots = GetNumClients();
string arguments = $"-m \"{modelPath}\" -c {contextSize} -b {batchSize} --log-disable -np {slots}";
if (embeddingsOnly) arguments += " --embedding";
if (numThreadsToUse > 0) arguments += $" -t {numThreadsToUse}";
arguments += loraArgument;
arguments += $" -ngl {numGPULayers}";
if (LLMUnitySetup.FullLlamaLib && flashAttention) arguments += " --flash-attn";
arguments += $" --port {port} --host 0.0.0.0";
if (!String.IsNullOrEmpty(APIKey)) arguments += $" --api-key {APIKey}";
string serverCommand;
if (Application.platform == RuntimePlatform.WindowsEditor || Application.platform == RuntimePlatform.WindowsPlayer) serverCommand = "undreamai_server.exe";
else serverCommand = "./undreamai_server";
serverCommand += " " + arguments;
serverCommand += $" --template \"{chatTemplate}\"";
if (remote && SSLCert != "" && SSLKey != "") serverCommand += $" --ssl-cert-file {SSLCertPath} --ssl-key-file {SSLKeyPath}";
LLMUnitySetup.Log($"Deploy server command: {serverCommand}");
private void SetupLogging()
logStreamWrapper = ConstructStreamWrapper(LLMUnitySetup.LogWarning, true);
llmlib?.Logging(logStreamWrapper.GetStringWrapper());
private void StopLogging()
if (logStreamWrapper == null) return;
llmlib?.StopLogging();
DestroyStreamWrapper(logStreamWrapper);
private void StartLLMServer(string arguments)
foreach (string arch in LLMLib.PossibleArchitectures(useGPU))
InitService(arguments);
LLMUnitySetup.Log($"Using architecture: {arch}");
catch (LLMException e)
catch (DestroyException)
error = $"{e.GetType()}: {e.Message}";
LLMUnitySetup.Log($"Tried architecture: {arch}, " + error);
LLMUnitySetup.LogError("LLM service couldn't be created");
CallWithLock(StartService);
LLMUnitySetup.Log("LLM service created");
private void InitLib(string arch)
llmlib = new LLMLib(arch);
CheckLLMStatus(false);
void CallWithLock(EmptyCallback fn)
if (llmlib == null) throw new DestroyException();
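The lock body is elided above; a likely shape for CallWithLock, reconstructed as an assumption from the startLock field and the DestroyException check shown in this listing:

// Assumed shape: serialize native calls and bail out if the library was destroyed.
void CallWithLock(EmptyCallback fn)
{
    lock (startLock)
    {
        if (llmlib == null) throw new DestroyException();
        fn();
    }
}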
private void InitService(string arguments)
if (debug) CallWithLock(SetupLogging);
CallWithLock(() => { LLMObject = llmlib.LLM_Construct(arguments); });
CallWithLock(() => llmlib.LLM_SetTemplate(LLMObject, chatTemplate));
if (SSLCert != "" && SSLKey != "")
LLMUnitySetup.Log("Using SSL");
CallWithLock(() => llmlib.LLM_SetSSL(LLMObject, SSLCert, SSLKey));
CallWithLock(() => llmlib.LLM_StartServer(LLMObject));
CallWithLock(() => CheckLLMStatus(false));
private void StartService()
// LLM_Start blocks until the service terminates, so it runs on a dedicated thread
// that must be started before the busy-wait below can ever observe LLM_Started.
llmThread = new Thread(() => llmlib.LLM_Start(LLMObject));
llmThread.Start();
while (!llmlib.LLM_Started(LLMObject)) {}
// from Register(LLMCaller llmCaller)
clients.Add(llmCaller);
int index = clients.IndexOf(llmCaller);
protected int GetNumClients()
public delegate void LLMStatusCallback(IntPtr LLMObject, IntPtr stringWrapper);
public delegate void LLMNoInputReplyCallback(IntPtr LLMObject, IntPtr stringWrapper);
public delegate void LLMReplyCallback(IntPtr LLMObject, string json_data, IntPtr stringWrapper);
StreamWrapper ConstructStreamWrapper(Callback<string> streamCallback = null, bool clearOnUpdate = false)
StreamWrapper streamWrapper = new StreamWrapper(llmlib, streamCallback, clearOnUpdate);
streamWrappers.Add(streamWrapper);
return streamWrapper;
void DestroyStreamWrapper(StreamWrapper streamWrapper)
streamWrappers.Remove(streamWrapper);
streamWrapper.Destroy();
// from Update()
foreach (StreamWrapper streamWrapper in streamWrappers) streamWrapper.Update();
if (failed) error = "LLM service couldn't be created";
else if (!started) error = "LLM service not started";
throw new Exception(error);
void AssertNotStarted()
string error = "This method can't be called when the LLM has started";
LLMUnitySetup.LogError(error);
throw new Exception(error);
void CheckLLMStatus(bool log = true)
if (llmlib == null) { return; }
IntPtr stringWrapper = llmlib.StringWrapper_Construct();
int status = llmlib.LLM_Status(LLMObject, stringWrapper);
string result = llmlib.GetStringWrapperResult(stringWrapper);
llmlib.StringWrapper_Delete(stringWrapper);
string message = $"LLM {status}: {result}";
if (log) LLMUnitySetup.LogError(message);
throw new LLMException(message, status);
if (log) LLMUnitySetup.LogWarning(message);
async Task<string> LLMNoInputReply(LLMNoInputReplyCallback callback)
IntPtr stringWrapper = llmlib.StringWrapper_Construct();
await Task.Run(() => callback(LLMObject, stringWrapper));
string result = llmlib?.GetStringWrapperResult(stringWrapper);
llmlib?.StringWrapper_Delete(stringWrapper);
async Task<string> LLMReply(LLMReplyCallback callback, string json)
IntPtr stringWrapper = llmlib.StringWrapper_Construct();
await Task.Run(() => callback(LLMObject, json, stringWrapper));
string result = llmlib?.GetStringWrapperResult(stringWrapper);
llmlib?.StringWrapper_Delete(stringWrapper);
// from Tokenize(string json)
LLMReplyCallback callback = (IntPtr LLMObject, string jsonData, IntPtr strWrapper) =>
llmlib.LLM_Tokenize(LLMObject, jsonData, strWrapper);
return await LLMReply(callback, json);
// from Detokenize(string json)
LLMReplyCallback callback = (IntPtr LLMObject, string jsonData, IntPtr strWrapper) =>
llmlib.LLM_Detokenize(LLMObject, jsonData, strWrapper);
return await LLMReply(callback, json);
// from Embeddings(string json)
LLMReplyCallback callback = (IntPtr LLMObject, string jsonData, IntPtr strWrapper) =>
llmlib.LLM_Embeddings(LLMObject, jsonData, strWrapper);
return await LLMReply(callback, json);
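The three wrappers above differ only in the native call passed to LLMReply. A hedged call sketch; the content field follows the llama.cpp tokenize endpoint and is an assumption:

// Sketch: tokenize a string; the reply is a JSON list of token ids.
string tokensJson = await llm.Tokenize("{\"content\": \"hello world\"}");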
// from ApplyLoras()
LoraWeightRequestList loraWeightRequest = new LoraWeightRequestList();
loraWeightRequest.loraWeights = new List<LoraWeightRequest>();
for (int i = 0; i < weights.Length; i++)
loraWeightRequest.loraWeights.Add(new LoraWeightRequest() { id = i, scale = weights[i] });
string json = JsonUtility.ToJson(loraWeightRequest);
int startIndex = json.IndexOf("[");
int endIndex = json.LastIndexOf("]") + 1;
json = json.Substring(startIndex, endIndex - startIndex);
IntPtr stringWrapper = llmlib.StringWrapper_Construct();
llmlib.LLM_Lora_Weight(LLMObject, json, stringWrapper);
llmlib.StringWrapper_Delete(stringWrapper);
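The IndexOf/Substring step exists because JsonUtility cannot serialize a top-level JSON array: the list is wrapped in an object for ToJson and the bare array is cut back out, while ListLoras below performs the inverse wrap before FromJson. Illustration (values hypothetical):

// ToJson produces: {"loraWeights":[{"id":0,"scale":1.0},{"id":1,"scale":0.5}]}
// after trimming:  [{"id":0,"scale":1.0},{"id":1,"scale":0.5}]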
// from ListLoras()
LLMNoInputReplyCallback callback = (IntPtr LLMObject, IntPtr strWrapper) =>
llmlib.LLM_LoraList(LLMObject, strWrapper);
string json = await LLMNoInputReply(callback);
if (String.IsNullOrEmpty(json)) return null;
LoraWeightResultList loraRequest = JsonUtility.FromJson<LoraWeightResultList>("{\"loraWeights\": " + json + "}");
return loraRequest.loraWeights;
public async Task<string> Slot(string json)
LLMReplyCallback callback = (IntPtr LLMObject, string jsonData, IntPtr strWrapper) =>
llmlib.LLM_Slot(LLMObject, jsonData, strWrapper);
return await LLMReply(callback, json);
public async Task<string> Completion(string json, Callback<string> streamCallback = null)
if (streamCallback == null) streamCallback = (string s) => {};
StreamWrapper streamWrapper = ConstructStreamWrapper(streamCallback);
await Task.Run(() => llmlib.LLM_Completion(LLMObject, json, streamWrapper.GetStringWrapper()));
streamWrapper.Update();
string result = streamWrapper.GetString();
DestroyStreamWrapper(streamWrapper);
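A minimal streaming call against Completion above; the JSON field names follow llama.cpp's completion endpoint and are assumptions here:

// Sketch: stream partial output to the console, then use the final string.
string request = "{\"prompt\": \"Once upon a time\", \"n_predict\": 32}";
string reply = await llm.Completion(request, partial => Debug.Log(partial));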
public async Task SetBasePrompt(string base_prompt)
SystemPromptRequest request = new SystemPromptRequest() { system_prompt = base_prompt, prompt = " ", n_predict = 0 };
await Completion(JsonUtility.ToJson(request));
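Usage sketch: since n_predict is 0, the request above evaluates the base prompt without generating any tokens:

// Sketch: install a base prompt once the service is up.
await llm.SetBasePrompt("You are a terse, helpful NPC.");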
// from CancelRequest(int id_slot)
llmlib?.LLM_Cancel(LLMObject, id_slot);
// from Destroy()
if (LLMObject != IntPtr.Zero)
llmlib.LLM_Stop(LLMObject);
if (remote) llmlib.LLM_StopServer(LLMObject);
llmlib.LLM_Delete(LLMObject);
LLMObject = IntPtr.Zero;
Class implementing the skeleton of a chat template.
static string DefaultTemplate
the default template used when it can't be determined ("chatml")
Class implementing calling of LLM functions (local and remote).
Class implementing the LLM model manager.
static void Unregister(LLM llm)
Removes an LLM from the model manager.
static ModelEntry Get(string path)
Gets the model entry for a model path.
static Task< bool > Setup()
Setup of the models.
static void Register(LLM llm)
Registers an LLM with the model manager.
Class implementing helper functions for setup and process management.
Class implementing the LLM server.
int numGPULayers
number of model layers to offload to the GPU (0 = GPU not used). Use a large number i....
void ApplyLoras()
Sets the lora scale; only works after the LLM service has started.
async Task< string > Slot(string json)
Saves or restores the state of a slot.
void SetLoraWeights(Dictionary< string, float > loraToWeight)
Changes the weights (scale) of the LORA models in the LLM.
async Task< List< LoraWeightResult > > ListLoras()
Gets a list of the lora adapters.
static async Task< bool > WaitUntilModelSetup(Callback< float > downloadProgressCallback=null)
Waits until the LLM models are downloaded and ready.
string GetTemplate()
Returns the chat template of the LLM.
void SetLoraWeight(string path, float weight)
Changes the weight (scale) of a LORA model in the LLM.
void CancelRequest(int id_slot)
Cancels the requests in a specific slot of the LLM.
int parallelPrompts
number of prompts that can happen in parallel (-1 = number of LLMCaller objects)
bool debug
select to log the output of the LLM in the Unity Editor.
string basePrompt
a base prompt to use for all LLMCaller objects
async void Awake()
The Unity Awake function that starts the LLM server.
async Task< string > Detokenize(string json)
Detokenises the provided query.
void OnDestroy()
The Unity OnDestroy function called when the object is destroyed. The function StopProcess is called...
void SetLora(string path, float weight=1)
Sets a LORA model to use in the LLM. The model provided is copied to the Assets/StreamingAss...
void AddLora(string path, float weight=1)
Adds a LORA model to use in the LLM. The model provided is copied to the Assets/StreamingAss...
bool advancedOptions
toggle to show/hide advanced options in the GameObject
void RemoveLora(string path)
Removes a LORA model from the LLM. Supported models are in .gguf format.
static bool modelSetupFailed
Boolean set to true if the models were not downloaded successfully.
string lora
the paths of the LORA models being used (relative to the Assets/StreamingAssets folder)....
int contextSize
Size of the prompt context (0 = context size of the model). This is the number of tokens the model ca...
int numThreads
number of threads to use (-1 = all)
bool started
Boolean set to true if the server has started and is ready to receive requests, false otherwise.
void SetModel(string path)
Sets the model used by the LLM. The model provided is copied to the Assets/StreamingAssets f...
int port
port to use for the LLM server
bool remote
toggle to enable remote server functionality
void SetSSLCert(string path)
Use an SSL certificate for the LLM server.
void RemoveLoras()
Removes all LORA models from the LLM.
bool dontDestroyOnLoad
select to not destroy the LLM GameObject when loading a new Scene.
void SetEmbeddings(int embeddingLength, bool embeddingsOnly)
Set LLM Embedding parameters.
string model
the LLM model to use. Models with .gguf format are allowed.
string APIKey
API key to use for the server (optional)
async Task< string > Tokenize(string json)
Tokenises the provided query.
int batchSize
Batch size for prompt processing.
void SetSSLKey(string path)
Use an SSL key for the LLM server.
string chatTemplate
Chat template used for the model.
bool flashAttention
enable use of flash attention
async Task< string > Completion(string json, Callback< string > streamCallback=null)
Provides the chat and completion functionality of the LLM.
static bool modelSetupComplete
Boolean set to true once the model setup has completed, false otherwise.
async Task WaitUntilReady()
Waits until the LLM is ready.
void SetTemplate(string templateName, bool setDirty=true)
Set the chat template for the LLM.
void Destroy()
Stops and destroys the LLM.
void Update()
The Unity Update function. It is used to retrieve the LLM replies.
string loraWeights
the weights of the LORA models being used.
int Register(LLMCaller llmCaller)
Registers a local LLMCaller object. This binds the LLMCaller "client" to a specific slot of ...
bool failed
Boolean set to true if the server has failed to start.
async Task< string > Embeddings(string json)
Computes the embeddings of the provided query.
Class representing the LORA manager, used to convert and retrieve LORA assets as strings (for seria...
float[] GetWeights()
Gets the weights of the LORAs in the manager.
void Add(string path, float weight=1)
Adds a LORA with the defined weight.
void Remove(string path)
Removes a LORA based on its path.
void SetWeight(string path, float weight)
Modifies the weight of a LORA.
void FromStrings(string loraString, string loraWeightsString)
Converts strings with the lora paths and weights to entries in the LORA manager.
string[] GetLoras()
Gets the paths of the LORAs in the manager.
void Clear()
Clears the LORA assets.
Class implementing an LLM model entry.
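Pulling the pieces together, a minimal end-to-end sketch; the LLM component is assumed to be assigned in the Inspector, and the JSON payload fields follow the llama.cpp completion API (assumptions, not taken from this listing):

using UnityEngine;

// Minimal smoke test for the API documented above.
public class LLMSmokeTest : MonoBehaviour
{
    public LLM llm; // assign the LLM GameObject in the Inspector

    async void Start()
    {
        await llm.WaitUntilReady(); // waits until started == true
        string reply = await llm.Completion("{\"prompt\": \"Hello!\", \"n_predict\": 16}");
        Debug.Log(reply);
    }
}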