using System.Collections.Generic;
using System.Threading.Tasks;

[DefaultExecutionOrder(-1)]
public class LLM : MonoBehaviour
{
    [Tooltip("show/hide advanced options in the GameObject")]
    public bool advancedOptions;
    [Tooltip("enable remote server functionality")]
    [LocalRemote] public bool remote = false;
    [Tooltip("port to use for the remote LLM server")]
    [Remote] public int port = 13333;
    [Tooltip("number of threads to use (-1 = all)")]
    public int numThreads;
    [Tooltip("number of model layers to offload to the GPU (0 = GPU not used). If the user's GPU is not supported, the LLM will fall back to the CPU")]
    public int numGPULayers;
    [Tooltip("log the output of the LLM in the Unity Editor.")]
    public bool debug;
    [Tooltip("number of prompts that can happen in parallel (-1 = number of LLMCaller objects)")]
    public int parallelPrompts;
    [Tooltip("do not destroy the LLM GameObject when loading a new Scene.")]
    public bool dontDestroyOnLoad;
    [Tooltip("Size of the prompt context (0 = context size of the model). This is the number of tokens the model can take as input when generating responses.")]
    [DynamicRange("minContextLength", "maxContextLength", false), Model]
    public int contextSize = 8192;
    [Tooltip("Batch size for prompt processing.")]
    public int batchSize;

    public bool started { get; protected set; } = false;
    public bool failed { get; protected set; } = false;

    [Tooltip("LLM model to use (.gguf format)")]
    [ModelAdvanced] public string model = "";
    [Tooltip("Chat template for the model")]
    public string chatTemplate;
    [Tooltip("LORA models to use (.gguf format)")]
    [ModelAdvanced] public string lora = "";
    [Tooltip("the weights of the LORA models being used.")]
    public string loraWeights;
    [Tooltip("enable use of flash attention")]
    public bool flashAttention;
    [Tooltip("API key to use for the server")]
    public string APIKey;
    private string SSLCert = "";
    public string SSLCertPath = "";
    private string SSLKey = "";
    public string SSLKeyPath = "";

    public int minContextLength = 0;
    public int maxContextLength = 0;
    public string architecture => llmlib.architecture;

    IntPtr LLMObject = IntPtr.Zero;
    List<LLMCaller> clients = new List<LLMCaller>();
    LLMLib llmlib;
    LoraManager loraManager = new LoraManager();
    StreamWrapper logStreamWrapper = null;
    Thread llmThread = null;
    List<StreamWrapper> streamWrappers = new List<StreamWrapper>();
    private readonly object startLock = new object();
    static readonly object staticLock = new object();
    string loraWeightsPre = "";
    public bool embeddingsOnly = false;
    public int embeddingLength = 0;
    public async void Awake()
    {
        if (!enabled) return;
        // ...
        string arguments = GetLlamaccpArguments();
        if (arguments == null) return;
        await Task.Run(() => StartLLMServer(arguments));
        // ...
    }

    public async Task WaitUntilReady()
    {
        while (!started) await Task.Yield();
    }
    public static async Task<bool> WaitUntilModelSetup(Callback<float> downloadProgressCallback = null)
    {
        if (downloadProgressCallback != null) LLMManager.downloadProgressCallbacks.Add(downloadProgressCallback);
        while (!modelSetupComplete) await Task.Yield();
        return !modelSetupFailed;
    }
    public static string GetLLMManagerAsset(string path)
    {
        if (!EditorApplication.isPlaying) return GetLLMManagerAssetEditor(path);
        return GetLLMManagerAssetRuntime(path);
    }
    public static string GetLLMManagerAssetEditor(string path)
    {
        if (string.IsNullOrEmpty(path)) return path;
        // check the model manager first
        ModelEntry modelEntry = LLMManager.Get(path);
        if (modelEntry != null) return modelEntry.filename;
        // then the StreamingAssets folder
        string assetPath = LLMUnitySetup.GetAssetPath(path);
        string basePath = LLMUnitySetup.GetAssetPath();
        if (File.Exists(assetPath))
        {
            if (LLMUnitySetup.IsSubPath(assetPath, basePath)) return LLMUnitySetup.RelativePath(assetPath, basePath);
        }
        if (!File.Exists(assetPath))
        {
            LLMUnitySetup.LogError($"Model {path} was not found.");
        }
        else
        {
            string errorMessage = $"The model {path} was loaded locally. You can include it in the build in one of these ways:";
            errorMessage += "\n-Copy the model inside the StreamingAssets folder and use its StreamingAssets path";
            errorMessage += "\n-Load the model with the model manager inside the LLM GameObject and use its filename";
            LLMUnitySetup.LogWarning(errorMessage);
        }
        return path;
    }
    public static string GetLLMManagerAssetRuntime(string path)
    {
        if (string.IsNullOrEmpty(path)) return path;
        // try the model manager, then StreamingAssets, then the download location
        string managerPath = LLMManager.GetAssetPath(path);
        if (!string.IsNullOrEmpty(managerPath) && File.Exists(managerPath)) return managerPath;
        string assetPath = LLMUnitySetup.GetAssetPath(path);
        if (File.Exists(assetPath)) return assetPath;
        assetPath = LLMUnitySetup.GetDownloadAssetPath(path);
        if (File.Exists(assetPath)) return assetPath;
        return path;
    }
    public void SetModel(string path)
    {
        model = GetLLMManagerAsset(path);
        if (!string.IsNullOrEmpty(model))
        {
            ModelEntry modelEntry = LLMManager.Get(model);
            if (modelEntry == null) modelEntry = new ModelEntry(GetLLMManagerAssetRuntime(model));
            // ...
            maxContextLength = modelEntry.contextLength;
            // ...
            SetEmbeddings(modelEntry.embeddingLength, modelEntry.embeddingOnly);
            if (contextSize == 0 && modelEntry.contextLength > 32768)
            {
                LLMUnitySetup.LogWarning($"The model {path} has very large context size ({modelEntry.contextLength}), consider setting it to a smaller value (<=32768) to avoid filling up the RAM");
            }
        }
        if (!EditorApplication.isPlaying) EditorUtility.SetDirty(this);
    }
    public void SetLora(string path, float weight = 1)
    {
        // ...
    }

    public void AddLora(string path, float weight = 1)
    {
        loraManager.Add(path, weight);
        // ...
    }

    public void SetLoraWeights(Dictionary<string, float> loraToWeight)
    {
        foreach (KeyValuePair<string, float> entry in loraToWeight) loraManager.SetWeight(entry.Key, entry.Value);
        // ...
    }
    public void UpdateLoras()
    {
        // ...
        if (!EditorApplication.isPlaying) EditorUtility.SetDirty(this);
    }

    public void SetTemplate(string templateName, bool setDirty = true)
    {
        // ...
        if (setDirty && !EditorApplication.isPlaying) EditorUtility.SetDirty(this);
    }

    public void SetEmbeddings(int embeddingLength, bool embeddingsOnly)
    {
        this.embeddingsOnly = embeddingsOnly;
        this.embeddingLength = embeddingLength;
        if (!EditorApplication.isPlaying) EditorUtility.SetDirty(this);
    }
    string ReadFileContents(string path)
    {
        if (String.IsNullOrEmpty(path)) return "";
        else if (!File.Exists(path))
        {
            // ...
        }
        return File.ReadAllText(path);
    }

    public void SetSSLCert(string path)
    {
        SSLCert = ReadFileContents(path);
        // ...
    }

    public void SetSSLKey(string path)
    {
        SSLKey = ReadFileContents(path);
        // ...
    }
    protected virtual string GetLlamaccpArguments()
    {
        // SSL requires both a certificate and a key
        if ((SSLCert != "" && SSLKey == "") || (SSLCert == "" && SSLKey != ""))
        {
            LLMUnitySetup.LogError("Both SSL certificate and key need to be provided!");
            return null;
        }
        if (String.IsNullOrEmpty(model))
        {
            LLMUnitySetup.LogError("No model file provided!");
            return null;
        }
        string modelPath = GetLLMManagerAssetRuntime(model);
        if (!File.Exists(modelPath))
        {
            LLMUnitySetup.LogError($"File {modelPath} not found!");
            return null;
        }
        string loraArgument = "";
        // ...
        string loraPath = GetLLMManagerAssetRuntime(lora);
        if (!File.Exists(loraPath))
        {
            LLMUnitySetup.LogError($"File {loraPath} not found!");
            return null;
        }
        loraArgument += $" --lora \"{loraPath}\"";
        // ...
        int numThreadsToUse = numThreads;
        if (Application.platform == RuntimePlatform.Android && numThreads <= 0) numThreadsToUse = LLMUnitySetup.AndroidGetNumBigCores();

        int slots = GetNumClients();
        string arguments = $"-m \"{modelPath}\" -c {contextSize} -b {batchSize} --log-disable -np {slots}";
        if (embeddingsOnly) arguments += " --embedding";
        if (numThreadsToUse > 0) arguments += $" -t {numThreadsToUse}";
        arguments += loraArgument;
        if (numGPULayers > 0) arguments += $" -ngl {numGPULayers}";
        if (LLMUnitySetup.FullLlamaLib && flashAttention) arguments += " --flash-attn";
        if (remote)
        {
            arguments += $" --port {port} --host 0.0.0.0";
            if (!String.IsNullOrEmpty(APIKey)) arguments += $" --api-key {APIKey}";
        }

        string serverCommand;
        if (Application.platform == RuntimePlatform.WindowsEditor || Application.platform == RuntimePlatform.WindowsPlayer) serverCommand = "undreamai_server.exe";
        else serverCommand = "./undreamai_server";
        serverCommand += " " + arguments;
        serverCommand += $" --template \"{chatTemplate}\"";
        if (remote && SSLCert != "" && SSLKey != "") serverCommand += $" --ssl-cert-file {SSLCertPath} --ssl-key-file {SSLKeyPath}";
        LLMUnitySetup.Log($"Deploy server command: {serverCommand}");
        return arguments;
    }
    private void SetupLogging()
    {
        logStreamWrapper = ConstructStreamWrapper(LLMUnitySetup.LogWarning, true);
        llmlib?.Logging(logStreamWrapper.GetStringWrapper());
    }

    private void StopLogging()
    {
        if (logStreamWrapper == null) return;
        llmlib?.StopLogging();
        DestroyStreamWrapper(logStreamWrapper);
    }
    private void StartLLMServer(string arguments)
    {
        // try each candidate architecture until one initialises
        bool useGPU = numGPULayers > 0;
        string error = "";
        foreach (string arch in LLMLib.PossibleArchitectures(useGPU))
        {
            try
            {
                InitLib(arch);
                InitService(arguments);
                LLMUnitySetup.Log($"Using architecture: {arch}");
                break;
            }
            catch (LLMException e)
            {
                // ...
            }
            catch (DestroyException)
            {
                // ...
            }
            catch (Exception e)
            {
                error = $"{e.GetType()}: {e.Message}";
            }
            LLMUnitySetup.Log($"Tried architecture: {arch}, error: " + error);
        }
        if (llmlib == null)
        {
            LLMUnitySetup.LogError("LLM service couldn't be created");
            failed = true;
            return;
        }
        CallWithLock(StartService);
        LLMUnitySetup.Log("LLM service created");
    }
    private void InitLib(string arch)
    {
        llmlib = new LLMLib(arch);
        CheckLLMStatus(false);
    }

    void CallWithLock(EmptyCallback fn)
    {
        lock (startLock)
        {
            if (llmlib == null) throw new DestroyException();
            fn();
        }
    }
    private void InitService(string arguments)
    {
        // ...
        if (debug) CallWithLock(SetupLogging);
        CallWithLock(() => { LLMObject = llmlib.LLM_Construct(arguments); });
        CallWithLock(() => llmlib.LLM_SetTemplate(LLMObject, chatTemplate));
        // ...
        if (SSLCert != "" && SSLKey != "")
        {
            LLMUnitySetup.Log("Using SSL");
            CallWithLock(() => llmlib.LLM_SetSSL(LLMObject, SSLCert, SSLKey));
        }
        CallWithLock(() => llmlib.LLM_StartServer(LLMObject));
        // ...
        CallWithLock(() => CheckLLMStatus(false));
    }
    private void StartService()
    {
        llmThread = new Thread(() => llmlib.LLM_Start(LLMObject));
        llmThread.Start();
        while (!llmlib.LLM_Started(LLMObject)) {}
        // ...
    }

    public int Register(LLMCaller llmCaller)
    {
        clients.Add(llmCaller);
        int index = clients.IndexOf(llmCaller);
        // ...
    }

    protected int GetNumClients()
    {
        // ...
    }
    public delegate void LLMStatusCallback(IntPtr LLMObject, IntPtr stringWrapper);
    public delegate void LLMNoInputReplyCallback(IntPtr LLMObject, IntPtr stringWrapper);
    public delegate void LLMReplyCallback(IntPtr LLMObject, string json_data, IntPtr stringWrapper);

    StreamWrapper ConstructStreamWrapper(Callback<string> streamCallback = null, bool clearOnUpdate = false)
    {
        StreamWrapper streamWrapper = new StreamWrapper(llmlib, streamCallback, clearOnUpdate);
        streamWrappers.Add(streamWrapper);
        return streamWrapper;
    }

    void DestroyStreamWrapper(StreamWrapper streamWrapper)
    {
        streamWrappers.Remove(streamWrapper);
        streamWrapper.Destroy();
    }
    public void Update()
    {
        foreach (StreamWrapper streamWrapper in streamWrappers) streamWrapper.Update();
    }

    void AssertStarted()
    {
        string error = null;
        if (failed) error = "LLM service couldn't be created";
        else if (!started) error = "LLM service not started";
        if (error != null)
        {
            LLMUnitySetup.LogError(error);
            throw new Exception(error);
        }
    }

    void AssertNotStarted()
    {
        if (started)
        {
            string error = "This method can't be called when the LLM has started";
            LLMUnitySetup.LogError(error);
            throw new Exception(error);
        }
    }
    void CheckLLMStatus(bool log = true)
    {
        if (llmlib == null) { return; }
        IntPtr stringWrapper = llmlib.StringWrapper_Construct();
        int status = llmlib.LLM_Status(LLMObject, stringWrapper);
        string result = llmlib.GetStringWrapperResult(stringWrapper);
        llmlib.StringWrapper_Delete(stringWrapper);
        string message = $"LLM {status}: {result}";
        if (status > 0)
        {
            if (log) LLMUnitySetup.LogError(message);
            throw new LLMException(message, status);
        }
        else if (status < 0)
        {
            if (log) LLMUnitySetup.LogWarning(message);
        }
    }
    async Task<string> LLMNoInputReply(LLMNoInputReplyCallback callback)
    {
        // ...
        IntPtr stringWrapper = llmlib.StringWrapper_Construct();
        await Task.Run(() => callback(LLMObject, stringWrapper));
        string result = llmlib?.GetStringWrapperResult(stringWrapper);
        llmlib?.StringWrapper_Delete(stringWrapper);
        return result;
    }

    async Task<string> LLMReply(LLMReplyCallback callback, string json)
    {
        // ...
        IntPtr stringWrapper = llmlib.StringWrapper_Construct();
        await Task.Run(() => callback(LLMObject, json, stringWrapper));
        string result = llmlib?.GetStringWrapperResult(stringWrapper);
        llmlib?.StringWrapper_Delete(stringWrapper);
        return result;
    }
    public async Task<string> Tokenize(string json)
    {
        LLMReplyCallback callback = (IntPtr LLMObject, string jsonData, IntPtr strWrapper) =>
        {
            llmlib.LLM_Tokenize(LLMObject, jsonData, strWrapper);
        };
        return await LLMReply(callback, json);
    }

    public async Task<string> Detokenize(string json)
    {
        LLMReplyCallback callback = (IntPtr LLMObject, string jsonData, IntPtr strWrapper) =>
        {
            llmlib.LLM_Detokenize(LLMObject, jsonData, strWrapper);
        };
        return await LLMReply(callback, json);
    }

    public async Task<string> Embeddings(string json)
    {
        LLMReplyCallback callback = (IntPtr LLMObject, string jsonData, IntPtr strWrapper) =>
        {
            llmlib.LLM_Embeddings(LLMObject, jsonData, strWrapper);
        };
        return await LLMReply(callback, json);
    }
    public void ApplyLoras()
    {
        LoraWeightRequestList loraWeightRequest = new LoraWeightRequestList();
        loraWeightRequest.loraWeights = new List<LoraWeightRequest>();
        float[] weights = loraManager.GetWeights();
        if (weights.Length == 0) return;
        for (int i = 0; i < weights.Length; i++)
        {
            loraWeightRequest.loraWeights.Add(new LoraWeightRequest() { id = i, scale = weights[i] });
        }
        string json = JsonUtility.ToJson(loraWeightRequest);
        // strip the wrapper object so that only the bare JSON array is passed on
        int startIndex = json.IndexOf("[");
        int endIndex = json.LastIndexOf("]") + 1;
        json = json.Substring(startIndex, endIndex - startIndex);

        IntPtr stringWrapper = llmlib.StringWrapper_Construct();
        llmlib.LLM_LoraWeight(LLMObject, json, stringWrapper);
        llmlib.StringWrapper_Delete(stringWrapper);
    }
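    // For illustration (hypothetical scale value): JsonUtility.ToJson above yields
    // {"loraWeights":[{"id":0,"scale":0.5}]}, and the substring step trims it down
    // to the bare array [{"id":0,"scale":0.5}].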
    public async Task<List<LoraWeightResult>> ListLoras()
    {
        // ...
        LLMNoInputReplyCallback callback = (IntPtr LLMObject, IntPtr strWrapper) =>
        {
            llmlib.LLM_LoraList(LLMObject, strWrapper);
        };
        string json = await LLMNoInputReply(callback);
        if (String.IsNullOrEmpty(json)) return null;
        LoraWeightResultList loraRequest = JsonUtility.FromJson<LoraWeightResultList>("{\"loraWeights\": " + json + "}");
        return loraRequest.loraWeights;
    }
    public async Task<string> Slot(string json)
    {
        LLMReplyCallback callback = (IntPtr LLMObject, string jsonData, IntPtr strWrapper) =>
        {
            llmlib.LLM_Slot(LLMObject, jsonData, strWrapper);
        };
        return await LLMReply(callback, json);
    }
    public async Task<string> Completion(string json, Callback<string> streamCallback = null)
    {
        // ...
        if (streamCallback == null) streamCallback = (string s) => {};
        StreamWrapper streamWrapper = ConstructStreamWrapper(streamCallback);
        await Task.Run(() => llmlib.LLM_Completion(LLMObject, json, streamWrapper.GetStringWrapper()));
        // ...
        streamWrapper.Update();
        string result = streamWrapper.GetString();
        DestroyStreamWrapper(streamWrapper);
        return result;
    }

    public void CancelRequest(int id_slot)
    {
        // ...
        llmlib?.LLM_Cancel(LLMObject, id_slot);
    }
    public void Destroy()
    {
        // ...
        if (LLMObject != IntPtr.Zero)
        {
            llmlib.LLM_Stop(LLMObject);
            if (remote) llmlib.LLM_StopServer(LLMObject);
            // ...
            llmlib.LLM_Delete(LLMObject);
            LLMObject = IntPtr.Zero;
        }
        // ...
    }
}
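Taken together, a minimal usage sketch of the component looks like this (the model filename is hypothetical, and the model is assumed to be available through the manager or StreamingAssets):

    LLM llm = gameObject.AddComponent<LLM>();
    llm.SetModel("model.gguf");          // hypothetical .gguf filename
    llm.remote = false;                  // local, in-process service
    await llm.WaitUntilReady();          // resumes once the service has started
    Debug.Log(llm.started);              // true once ready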
Class implementing the skeleton of a chat template.
static string DefaultTemplate
The default template, used when the chat template can't be determined ("chatml").
Class implementing calling of LLM functions (local and remote).
Class implementing the LLM model manager.
static void Unregister(LLM llm)
Removes an LLM from the model manager.
static ModelEntry Get(string path)
Gets the model entry for a model path.
static Task< bool > Setup()
Setup of the models.
static void Register(LLM llm)
Registers an LLM with the model manager.
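A minimal sketch of how these manager calls fit together (the model filename is hypothetical; Setup and Get are the methods listed above):

    bool ok = await LLMManager.Setup();               // download/copy the managed models
    ModelEntry entry = LLMManager.Get("model.gguf");  // hypothetical filename
    if (entry != null) Debug.Log($"{entry.filename}: context length {entry.contextLength}");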
Class implementing helper functions for setup and process management.
Class implementing the LLM server.
int numGPULayers
number of model layers to offload to the GPU (0 = GPU not used). If the user's GPU is not supported, the LLM will fall back to the CPU.
void ApplyLoras()
Sets the LORA scales; only works after the LLM service has started.
async Task< string > Slot(string json)
Saves or restores the state of a slot.
void SetLoraWeights(Dictionary< string, float > loraToWeight)
Changes the weights (scales) of the LORA models in the LLM.
async Task< List< LoraWeightResult > > ListLoras()
Gets a list of the LORA adapters.
static async Task< bool > WaitUntilModelSetup(Callback< float > downloadProgressCallback=null)
Waits until the LLM models are downloaded and ready.
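For example, a caller can wait for the model setup while reporting progress (a sketch; the lambda matches the Callback<float> parameter above):

    bool ok = await LLM.WaitUntilModelSetup(progress => Debug.Log($"Download progress: {progress}"));
    if (!ok) Debug.LogError("Model setup failed");   // cf. modelSetupFailed below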
string GetTemplate()
Returns the chat template of the LLM.
void SetLoraWeight(string path, float weight)
Changes the weight (scale) of a LORA model in the LLM.
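A short sketch combining the LORA calls (hypothetical .gguf paths):

    llm.AddLora("style.gguf", 0.8f);
    llm.AddLora("persona.gguf");
    llm.SetLoraWeight("style.gguf", 0.5f);              // rescale a single LORA
    llm.SetLoraWeights(new Dictionary<string, float> {  // or several at once
        { "style.gguf", 0.5f }, { "persona.gguf", 1f } });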
void CancelRequest(int id_slot)
Cancels the requests in a specific slot of the LLM.
int parallelPrompts
number of prompts that can happen in parallel (-1 = number of LLMCaller objects)
bool debug
log the output of the LLM in the Unity Editor.
async void Awake()
The Unity Awake function that starts the LLM server.
async Task< string > Detokenize(string json)
Detokenises the provided query.
void OnDestroy()
The Unity OnDestroy function called when the object is destroyed. The function StopProcess is called to stop the LLM process.
void SetLora(string path, float weight=1)
Sets a LORA model to use in the LLM. The model provided is copied to the Assets/StreamingAssets folder. Models supported are in .gguf format.
void AddLora(string path, float weight=1)
Adds a LORA model to use in the LLM. The model provided is copied to the Assets/StreamingAssets folder. Models supported are in .gguf format.
bool advancedOptions
show/hide advanced options in the GameObject
void RemoveLora(string path)
Removes a LORA model from the LLM. Models supported are in .gguf format.
static bool modelSetupFailed
Boolean set to true if the models were not downloaded successfully.
string lora
LORA models to use (.gguf format)
int contextSize
Size of the prompt context (0 = context size of the model). This is the number of tokens the model can take as input when generating responses.
int numThreads
number of threads to use (-1 = all)
bool started
Boolean set to true if the server has started and is ready to receive requests, false otherwise.
void SetModel(string path)
Sets the model used by the LLM. The model provided is copied to the Assets/StreamingAssets folder. Models supported are in .gguf format.
int port
port to use for the remote LLM server
bool remote
enable remote server functionality
void SetSSLCert(string path)
Use an SSL certificate for the LLM server.
void RemoveLoras()
Removes all LORA models from the LLM.
bool dontDestroyOnLoad
do not destroy the LLM GameObject when loading a new Scene.
void SetEmbeddings(int embeddingLength, bool embeddingsOnly)
Sets the LLM embedding parameters.
string model
LLM model to use (.gguf format)
string APIKey
API key to use for the server.
async Task< string > Tokenize(string json)
Tokenises the provided query.
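The query is passed as a JSON string. A sketch, assuming a llama.cpp-style "content" field (the exact schema is an assumption, not confirmed by this page):

    string tokensJson = await llm.Tokenize("{\"content\": \"Hello world\"}");  // assumed schema
    string text = await llm.Detokenize(tokensJson);  // assumed to accept the token list back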
int batchSize
Batch size for prompt processing.
void SetSSLKey(string path)
Use an SSL key for the LLM server.
string chatTemplate
Chat template for the model.
bool flashAttention
enable use of flash attention
async Task< string > Completion(string json, Callback< string > streamCallback=null)
Provides the chat and completion functionality of the LLM.
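A streaming sketch, assuming a llama.cpp-style completion payload with a "prompt" field (the JSON schema is an assumption):

    string reply = await llm.Completion(
        "{\"prompt\": \"Once upon a time\", \"n_predict\": 32}",  // assumed schema
        partial => Debug.Log(partial));  // called with the text streamed so far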
static bool modelSetupComplete
Boolean set to true once the model setup has completed, false otherwise.
async Task WaitUntilReady()
Waits until the LLM is ready.
void SetTemplate(string templateName, bool setDirty=true)
Set the chat template for the LLM.
void Destroy()
Stops and destroys the LLM.
void Update()
The Unity Update function. It is used to retrieve the LLM replies.
string loraWeights
the weights of the LORA models being used.
int Register(LLMCaller llmCaller)
Registers a local LLMCaller object. This binds the LLMCaller "client" to a specific slot of the LLM.
bool failed
Boolean set to true if the server has failed to start.
async Task< string > Embeddings(string json)
Computes the embeddings of the provided query.
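A sketch (assumed "content" schema; requires a model with embedding support, cf. SetEmbeddings above):

    string embeddingsJson = await llm.Embeddings("{\"content\": \"Hello world\"}");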
Class representing the LORA manager, allowing LORA assets to be converted to and retrieved from strings (for serialisation).
float[] GetWeights()
Gets the weights of the LORAs in the manager.
void Add(string path, float weight=1)
Adds a LORA with the defined weight.
void Remove(string path)
Removes a LORA based on its path.
void SetWeight(string path, float weight)
Modifies the weight of a LORA.
void FromStrings(string loraString, string loraWeightsString)
Converts strings with the LORA paths and weights to entries in the LORA manager.
string[] GetLoras()
Gets the paths of the LORAs in the manager.
void Clear()
Clears the LORA assets.
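A sketch of the round-trip these methods suggest (the LoraManager class name is inferred from the description above; paths are hypothetical):

    LoraManager manager = new LoraManager();
    manager.Add("style.gguf", 0.8f);
    manager.SetWeight("style.gguf", 0.5f);
    float[] weights = manager.GetWeights();  // [0.5]
    string[] paths = manager.GetLoras();     // ["style.gguf"]
    manager.Clear();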
Class implementing an LLM model entry.