using System;
using System.Collections.Generic;
using System.IO;
using System.Threading;
using System.Threading.Tasks;
using UnityEditor;
using UnityEngine;

[DefaultExecutionOrder(-1)]
public class LLM : MonoBehaviour
{
    [Tooltip("show/hide advanced options in the GameObject")]
    public bool advancedOptions = false;
    [Tooltip("enable remote server functionality")]
    [LocalRemote]
    public bool remote = false;
    [Tooltip("port to use for the remote LLM server")]
    [Remote]
    public int port = 13333;
    [Tooltip("number of threads to use (-1 = all)")]
    public int numThreads = -1;
    [Tooltip("number of model layers to offload to the GPU (0 = GPU not used). If the user's GPU is not supported, the LLM will fall back to the CPU")]
    public int numGPULayers = 0;
    [Tooltip("log the output of the LLM in the Unity Editor.")]
    public bool debug = false;
    [Tooltip("number of prompts that can happen in parallel (-1 = number of LLMCaller objects)")]
    public int parallelPrompts = -1;
    [Tooltip("do not destroy the LLM GameObject when loading a new Scene.")]
    public bool dontDestroyOnLoad = true;
    [Tooltip("Size of the prompt context (0 = context size of the model). This is the number of tokens the model can take as input when generating responses.")]
    [DynamicRange("minContextLength", "maxContextLength", false), Model]
    public int contextSize = 8192;
    [Tooltip("Batch size for prompt processing.")]
    public int batchSize = 512;

    /// Boolean set to true if the server has started and is ready to receive requests, false otherwise.
    public bool started { get; protected set; } = false;
    /// Boolean set to true if the server has failed to start.
    public bool failed { get; protected set; } = false;

    [Tooltip("LLM model to use (.gguf format)")]
    [ModelAdvanced]
    public string model = "";
    [Tooltip("Chat template for the model")]
    public string chatTemplate = ChatTemplate.DefaultTemplate;
    [Tooltip("LORA models to use (.gguf format)")]
    [ModelAdvanced]
    public string lora = "";
    [Tooltip("the weights of the LORA models being used.")]
    public string loraWeights = "";
    [Tooltip("enable use of flash attention")]
    public bool flashAttention = false;
    [Tooltip("API key to use for the server")]
    public string APIKey;

    private string SSLCert = "";
    public string SSLCertPath = "";
    private string SSLKey = "";
    public string SSLKeyPath = "";

    public int minContextLength = 0;
    public int maxContextLength = 0;

    IntPtr LLMObject = IntPtr.Zero;
    List<LLMCaller> clients = new List<LLMCaller>();
    StreamWrapper logStreamWrapper = null;
    Thread llmThread = null;
    List<StreamWrapper> streamWrappers = new List<StreamWrapper>();
    private readonly object startLock = new object();
    static readonly object staticLock = new object();
    string loraWeightsPre = "";
    public bool embeddingsOnly = false;
    public int embeddingLength = 0;
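
    // Example (sketch): these fields are normally set in the inspector, but they can also be
    // configured from code before the service starts; the values below are illustrative:
    //   LLM llm = gameObject.AddComponent<LLM>();
    //   llm.SetModel("my-model.gguf");  // hypothetical filename registered with the model manager
    //   llm.contextSize = 4096;         // 0 = use the model's own context size
    //   llm.numGPULayers = 20;          // 0 disables GPU offload
    //   llm.parallelPrompts = 2;        // -1 = one slot per LLMCaller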

    /// The Unity Awake function that starts the LLM server.
    public async void Awake()
    {
        if (!enabled) return;
        // ...
        string arguments = GetLlamaccpArguments();
        if (arguments == null) return;
        // ...
        await Task.Run(() => StartLLMServer(arguments));
        // ...
    }

    /// Waits until the LLM is ready.
    public async Task WaitUntilReady()
    {
        while (!started) await Task.Yield();
    }

    /// Waits until the LLM models are downloaded and ready.
    public static async Task<bool> WaitUntilModelSetup(Callback<float> downloadProgressCallback = null)
    {
        if (downloadProgressCallback != null) LLMManager.downloadProgressCallbacks.Add(downloadProgressCallback);
        // ...
    }

    public static string GetLLMManagerAsset(string path)
    {
        if (!EditorApplication.isPlaying) return GetLLMManagerAssetEditor(path);
        return GetLLMManagerAssetRuntime(path);
    }

    public static string GetLLMManagerAssetEditor(string path)
    {
        if (string.IsNullOrEmpty(path)) return path;
        // check if it is a model registered with the model manager
        ModelEntry modelEntry = LLMManager.Get(path);
        if (modelEntry != null) return modelEntry.filename;
        // check if it is a file inside the assets (StreamingAssets) folder
        string assetPath = LLMUnitySetup.GetAssetPath(path);
        string basePath = LLMUnitySetup.GetAssetPath();
        if (File.Exists(assetPath))
        {
            if (LLMUnitySetup.IsSubPath(assetPath, basePath)) return LLMUnitySetup.RelativePath(assetPath, basePath);
        }
        if (!File.Exists(assetPath))
        {
            LLMUnitySetup.LogError($"Model {path} was not found.");
        }
        else
        {
            string errorMessage = $"The model {path} was loaded locally. You can include it in the build in one of these ways:";
            errorMessage += "\n-Copy the model inside the StreamingAssets folder and use its StreamingAssets path";
            errorMessage += "\n-Load the model with the model manager inside the LLM GameObject and use its filename";
            LLMUnitySetup.LogWarning(errorMessage);
        }
        return path;
    }

    public static string GetLLMManagerAssetRuntime(string path)
    {
        if (string.IsNullOrEmpty(path)) return path;
        // check if the model manager has a local copy
        string managerPath = LLMManager.GetAssetPath(path);
        if (!string.IsNullOrEmpty(managerPath) && File.Exists(managerPath)) return managerPath;
        // check the assets (StreamingAssets) folder
        string assetPath = LLMUnitySetup.GetAssetPath(path);
        if (File.Exists(assetPath)) return assetPath;
        // check the download folder
        assetPath = LLMUnitySetup.GetDownloadAssetPath(path);
        if (File.Exists(assetPath)) return assetPath;
        return path;
    }

    /// Sets the model used by the LLM.
    public void SetModel(string path)
    {
        model = GetLLMManagerAsset(path);
        if (!string.IsNullOrEmpty(model))
        {
            ModelEntry modelEntry = LLMManager.Get(model);
            if (modelEntry == null) modelEntry = new ModelEntry(GetLLMManagerAssetRuntime(model));
            maxContextLength = modelEntry.contextLength;
            // ...
            SetEmbeddings(modelEntry.embeddingLength, modelEntry.embeddingOnly);
            if (contextSize == 0 && modelEntry.contextLength > 32768)
            {
                LLMUnitySetup.LogWarning($"The model {path} has very large context size ({modelEntry.contextLength}), consider setting it to a smaller value (<=32768) to avoid filling up the RAM");
            }
        }
        if (!EditorApplication.isPlaying) EditorUtility.SetDirty(this);
    }

    /// Sets a LORA model to use in the LLM, replacing any existing LORAs.
    public void SetLora(string path, float weight = 1)
    {
        loraManager.Clear();
        AddLora(path, weight);
    }

    /// Adds a LORA model to use in the LLM.
    public void AddLora(string path, float weight = 1)
    {
        loraManager.Add(path, weight);
        // ...
    }

    /// Changes the weights (scales) of the LORA models in the LLM.
    public void SetLoraWeights(Dictionary<string, float> loraToWeight)
    {
        foreach (KeyValuePair<string, float> entry in loraToWeight) loraManager.SetWeight(entry.Key, entry.Value);
        // ...
    }

    public void UpdateLoras()
    {
        // ...
        if (!EditorApplication.isPlaying) EditorUtility.SetDirty(this);
    }

    /// Set the chat template for the LLM.
    public void SetTemplate(string templateName, bool setDirty = true)
    {
        chatTemplate = templateName;
        // ...
        if (setDirty && !EditorApplication.isPlaying) EditorUtility.SetDirty(this);
    }

    /// Set LLM Embedding parameters.
    public void SetEmbeddings(int embeddingLength, bool embeddingsOnly)
    {
        this.embeddingsOnly = embeddingsOnly;
        this.embeddingLength = embeddingLength;
        // ...
        if (!EditorApplication.isPlaying) EditorUtility.SetDirty(this);
    }

    string ReadFileContents(string path)
    {
        if (String.IsNullOrEmpty(path)) return "";
        else if (!File.Exists(path))
        {
            LLMUnitySetup.LogError($"File {path} not found!");
            return "";
        }
        return File.ReadAllText(path);
    }

    /// Use an SSL certificate for the LLM server.
    public void SetSSLCert(string path)
    {
        SSLCertPath = path;
        SSLCert = ReadFileContents(path);
    }

    /// Use an SSL key for the LLM server.
    public void SetSSLKey(string path)
    {
        SSLKeyPath = path;
        SSLKey = ReadFileContents(path);
    }
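
    // Example (sketch): enabling SSL for a remote server (certificate paths are hypothetical):
    //   llm.remote = true;
    //   llm.SetSSLCert("server-cert.pem");  // stores the certificate contents in SSLCert
    //   llm.SetSSLKey("server-key.pem");    // stores the key contents in SSLKey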

    protected virtual string GetLlamaccpArguments()
    {
        // the SSL certificate and key must come as a pair
        if ((SSLCert != "" && SSLKey == "") || (SSLCert == "" && SSLKey != ""))
        {
            LLMUnitySetup.LogError("Both SSL certificate and key need to be provided!");
            return null;
        }
        if (model == "")
        {
            LLMUnitySetup.LogError("No model file provided!");
            return null;
        }
        string modelPath = GetLLMManagerAssetRuntime(model);
        if (!File.Exists(modelPath))
        {
            LLMUnitySetup.LogError($"File {modelPath} not found!");
            return null;
        }
        string loraArgument = "";
        foreach (string loraName in loraManager.GetLoras())
        {
            string loraPath = GetLLMManagerAssetRuntime(loraName);
            if (!File.Exists(loraPath))
            {
                LLMUnitySetup.LogError($"File {loraPath} not found!");
                return null;
            }
            loraArgument += $" --lora \"{loraPath}\"";
        }

        int numThreadsToUse = numThreads;
        if (Application.platform == RuntimePlatform.Android && numThreads <= 0) numThreadsToUse = LLMUnitySetup.AndroidGetNumBigCores();

        int slots = GetNumClients();
        string arguments = $"-m \"{modelPath}\" -c {contextSize} -b {batchSize} --log-disable -np {slots}";
        if (embeddingsOnly) arguments += " --embedding";
        if (numThreadsToUse > 0) arguments += $" -t {numThreadsToUse}";
        arguments += loraArgument;
        arguments += $" -ngl {numGPULayers}";
        if (LLMUnitySetup.FullLlamaLib && flashAttention) arguments += " --flash-attn";
        if (remote)
        {
            arguments += $" --port {port} --host 0.0.0.0";
            if (!String.IsNullOrEmpty(APIKey)) arguments += $" --api-key {APIKey}";
        }

        // server command, used for the headless server binary
        string serverCommand;
        if (Application.platform == RuntimePlatform.WindowsEditor || Application.platform == RuntimePlatform.WindowsPlayer) serverCommand = "undreamai_server.exe";
        else serverCommand = "./undreamai_server";
        serverCommand += " " + arguments;
        serverCommand += $" --template \"{chatTemplate}\"";
        if (remote && SSLCert != "" && SSLKey != "") serverCommand += $" --ssl-cert-file {SSLCertPath} --ssl-key-file {SSLKeyPath}";
        LLMUnitySetup.Log($"Deploy server command: {serverCommand}");
        return arguments;
    }

    private void SetupLogging()
    {
        logStreamWrapper = ConstructStreamWrapper(LLMUnitySetup.LogWarning, true);
        llmlib?.Logging(logStreamWrapper.GetStringWrapper());
    }

    private void StopLogging()
    {
        if (logStreamWrapper == null) return;
        llmlib?.StopLogging();
        DestroyStreamWrapper(logStreamWrapper);
    }

    private void StartLLMServer(string arguments)
    {
        bool useGPU = numGPULayers > 0;
        foreach (string arch in LLMLib.PossibleArchitectures(useGPU))
        {
            string error = "";
            try
            {
                InitLib(arch);
                InitService(arguments);
                LLMUnitySetup.Log($"Using architecture: {arch}");
                break;
            }
            catch (LLMException e)
            {
                error = e.Message;
                Destroy();
            }
            catch (DestroyException)
            {
                break;
            }
            catch (Exception e)
            {
                error = $"{e.GetType()}: {e.Message}";
            }
            LLMUnitySetup.Log($"Tried architecture: {arch}, error: " + error);
        }
        if (llmlib == null)
        {
            LLMUnitySetup.LogError("LLM service couldn't be created");
            failed = true;
            return;
        }
        CallWithLock(StartService);
        LLMUnitySetup.Log("LLM service created");
    }

    private void InitLib(string arch)
    {
        llmlib = new LLMLib(arch);
        CheckLLMStatus(false);
    }

    void CallWithLock(EmptyCallback fn)
    {
        lock (startLock)
        {
            if (llmlib == null) throw new DestroyException();
            fn();
        }
    }

    private void InitService(string arguments)
    {
        if (debug) CallWithLock(SetupLogging);
        CallWithLock(() => { LLMObject = llmlib.LLM_Construct(arguments); });
        CallWithLock(() => llmlib.LLM_SetTemplate(LLMObject, chatTemplate));
        if (SSLCert != "" && SSLKey != "")
        {
            LLMUnitySetup.Log("Using SSL");
            CallWithLock(() => llmlib.LLM_SetSSL(LLMObject, SSLCert, SSLKey));
        }
        if (remote) CallWithLock(() => llmlib.LLM_StartServer(LLMObject));
        CallWithLock(() => CheckLLMStatus(false));
    }

    private void StartService()
    {
        llmThread = new Thread(() => llmlib.LLM_Start(LLMObject));
        llmThread.Start();
        while (!llmlib.LLM_Started(LLMObject)) {}
        // ...
        started = true;
    }

    /// Registers a local LLMCaller object, binding the "client" to a specific slot of the LLM.
    public int Register(LLMCaller llmCaller)
    {
        clients.Add(llmCaller);
        int index = clients.IndexOf(llmCaller);
        if (parallelPrompts != -1) return index % parallelPrompts;
        return index;
    }

    protected int GetNumClients()
    {
        return Math.Max(parallelPrompts == -1 ? clients.Count : parallelPrompts, 1);
    }

    public delegate void LLMStatusCallback(IntPtr LLMObject, IntPtr stringWrapper);
    public delegate void LLMNoInputReplyCallback(IntPtr LLMObject, IntPtr stringWrapper);
    public delegate void LLMReplyCallback(IntPtr LLMObject, string json_data, IntPtr stringWrapper);

    StreamWrapper ConstructStreamWrapper(Callback<string> streamCallback = null, bool clearOnUpdate = false)
    {
        StreamWrapper streamWrapper = new StreamWrapper(llmlib, streamCallback, clearOnUpdate);
        streamWrappers.Add(streamWrapper);
        return streamWrapper;
    }

    void DestroyStreamWrapper(StreamWrapper streamWrapper)
    {
        streamWrappers.Remove(streamWrapper);
        streamWrapper.Destroy();
    }

    /// The Unity Update function, used to retrieve the LLM replies.
    public void Update()
    {
        foreach (StreamWrapper streamWrapper in streamWrappers) streamWrapper.Update();
    }

    void AssertStarted()
    {
        string error = null;
        if (failed) error = "LLM service couldn't be created";
        else if (!started) error = "LLM service not started";
        if (error != null)
        {
            LLMUnitySetup.LogError(error);
            throw new Exception(error);
        }
    }

    void AssertNotStarted()
    {
        if (started)
        {
            string error = "This method can't be called when the LLM has started";
            LLMUnitySetup.LogError(error);
            throw new Exception(error);
        }
    }

    void CheckLLMStatus(bool log = true)
    {
        if (llmlib == null) { return; }
        IntPtr stringWrapper = llmlib.StringWrapper_Construct();
        int status = llmlib.LLM_Status(LLMObject, stringWrapper);
        string result = llmlib.GetStringWrapperResult(stringWrapper);
        llmlib.StringWrapper_Delete(stringWrapper);
        string message = $"LLM {status}: {result}";
        if (status > 0)
        {
            if (log) LLMUnitySetup.LogError(message);
            throw new LLMException(message, status);
        }
        else if (status < 0)
        {
            if (log) LLMUnitySetup.LogWarning(message);
        }
    }

    async Task<string> LLMNoInputReply(LLMNoInputReplyCallback callback)
    {
        AssertStarted();
        IntPtr stringWrapper = llmlib.StringWrapper_Construct();
        await Task.Run(() => callback(LLMObject, stringWrapper));
        string result = llmlib?.GetStringWrapperResult(stringWrapper);
        llmlib?.StringWrapper_Delete(stringWrapper);
        return result;
    }

    async Task<string> LLMReply(LLMReplyCallback callback, string json)
    {
        AssertStarted();
        IntPtr stringWrapper = llmlib.StringWrapper_Construct();
        await Task.Run(() => callback(LLMObject, json, stringWrapper));
        string result = llmlib?.GetStringWrapperResult(stringWrapper);
        llmlib?.StringWrapper_Delete(stringWrapper);
        return result;
    }

    /// Tokenises the provided query.
    public async Task<string> Tokenize(string json)
    {
        LLMReplyCallback callback = (IntPtr LLMObject, string jsonData, IntPtr strWrapper) =>
        {
            llmlib.LLM_Tokenize(LLMObject, jsonData, strWrapper);
        };
        return await LLMReply(callback, json);
    }

    /// Detokenises the provided query.
    public async Task<string> Detokenize(string json)
    {
        LLMReplyCallback callback = (IntPtr LLMObject, string jsonData, IntPtr strWrapper) =>
        {
            llmlib.LLM_Detokenize(LLMObject, jsonData, strWrapper);
        };
        return await LLMReply(callback, json);
    }

    /// Computes the embeddings of the provided query.
    public async Task<string> Embeddings(string json)
    {
        LLMReplyCallback callback = (IntPtr LLMObject, string jsonData, IntPtr strWrapper) =>
        {
            llmlib.LLM_Embeddings(LLMObject, jsonData, strWrapper);
        };
        return await LLMReply(callback, json);
    }

    /// Sets the LORA scales; only works after the LLM service has started.
    public void ApplyLoras()
    {
        LoraWeightRequestList loraWeightRequest = new LoraWeightRequestList();
        loraWeightRequest.loraWeights = new List<LoraWeightRequest>();
        float[] weights = loraManager.GetWeights();
        if (weights.Length == 0) return;
        for (int i = 0; i < weights.Length; i++)
        {
            loraWeightRequest.loraWeights.Add(new LoraWeightRequest() { id = i, scale = weights[i] });
        }

        string json = JsonUtility.ToJson(loraWeightRequest);
        int startIndex = json.IndexOf("[");
        int endIndex = json.LastIndexOf("]") + 1;
        json = json.Substring(startIndex, endIndex - startIndex);  // keep only the JSON array of weights

        IntPtr stringWrapper = llmlib.StringWrapper_Construct();
        llmlib.LLM_LoraWeight(LLMObject, json, stringWrapper);
        llmlib.StringWrapper_Delete(stringWrapper);
    }

    /// Gets a list of the LORA adapters.
    public async Task<List<LoraWeightResult>> ListLoras()
    {
        LLMNoInputReplyCallback callback = (IntPtr LLMObject, IntPtr strWrapper) =>
        {
            llmlib.LLM_LoraList(LLMObject, strWrapper);
        };
        string json = await LLMNoInputReply(callback);
        if (String.IsNullOrEmpty(json)) return null;
        LoraWeightResultList loraRequest = JsonUtility.FromJson<LoraWeightResultList>("{\"loraWeights\": " + json + "}");
        return loraRequest.loraWeights;
    }
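
    // Example (sketch): inspecting the loaded adapters after startup; the LoraWeightResult
    // field names below are assumptions:
    //   List<LoraWeightResult> loras = await llm.ListLoras();
    //   if (loras != null) foreach (var l in loras) Debug.Log($"lora {l.id}: {l.scale}");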

    /// Saves or restores the state of a slot.
    public async Task<string> Slot(string json)
    {
        LLMReplyCallback callback = (IntPtr LLMObject, string jsonData, IntPtr strWrapper) =>
        {
            llmlib.LLM_Slot(LLMObject, jsonData, strWrapper);
        };
        return await LLMReply(callback, json);
    }

    /// Uses the chat and completion functionality of the LLM.
    public async Task<string> Completion(string json, Callback<string> streamCallback = null)
    {
        AssertStarted();
        if (streamCallback == null) streamCallback = (string s) => {};
        StreamWrapper streamWrapper = ConstructStreamWrapper(streamCallback);
        await Task.Run(() => llmlib.LLM_Completion(LLMObject, json, streamWrapper.GetStringWrapper()));
        // ...
        streamWrapper.Update();
        string result = streamWrapper.GetString();
        DestroyStreamWrapper(streamWrapper);
        return result;
    }
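
    // Example (sketch): streaming a completion; the JSON payload below follows llama.cpp
    // server conventions and is an assumption, as LLMCaller/LLMCharacter normally build it:
    //   string result = await llm.Completion("{\"prompt\": \"Hello\", \"n_predict\": 32}",
    //                                        chunk => Debug.Log(chunk));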

    /// Cancels the requests in a specific slot of the LLM.
    public void CancelRequest(int id_slot)
    {
        llmlib?.LLM_Cancel(LLMObject, id_slot);
    }

    /// Stops and destroys the LLM.
    public void Destroy()
    {
        // ...
        if (LLMObject != IntPtr.Zero)
        {
            llmlib.LLM_Stop(LLMObject);
            if (remote) llmlib.LLM_StopServer(LLMObject);
            // ...
            llmlib.LLM_Delete(LLMObject);
            LLMObject = IntPtr.Zero;
        }
        // ...
    }
}
ChatTemplate: Class implementing the skeleton of a chat template.
static string DefaultTemplate
the default template used when the template can't be determined ("chatml")
LLMCaller: Class implementing calling of LLM functions (local and remote).
LLMManager: Class implementing the LLM model manager.
static void Unregister(LLM llm)
Removes an LLM from the model manager.
static ModelEntry Get(string path)
Gets the model entry for a model path.
static Task< bool > Setup()
Sets up the models.
static void Register(LLM llm)
Registers an LLM with the model manager.
LLMUnitySetup: Class implementing helper functions for setup and process management.
LLM: Class implementing the LLM server.
int numGPULayers
number of model layers to offload to the GPU (0 = GPU not used). If the user's GPU is not supported, the LLM will fall back to the CPU.
void ApplyLoras()
Sets the LORA scales; only works after the LLM service has started.
async Task< string > Slot(string json)
Saves or restores the state of a slot.
void SetLoraWeights(Dictionary< string, float > loraToWeight)
Changes the weights (scales) of the LORA models in the LLM.
async Task< List< LoraWeightResult > > ListLoras()
Gets a list of the LORA adapters.
static async Task< bool > WaitUntilModelSetup(Callback< float > downloadProgressCallback=null)
Waits until the LLM models are downloaded and ready.
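
For example, a component can block until the models are downloaded before using the LLM; a minimal sketch, assuming the returned boolean signals success:

using UnityEngine;
using LLMUnity;

public class ModelSetupExample : MonoBehaviour
{
    async void Start()
    {
        bool ok = await LLM.WaitUntilModelSetup(progress => Debug.Log($"Model setup progress: {progress}"));
        if (!ok) Debug.LogError("Model setup failed");  // mirrors LLM.modelSetupFailed
    }
}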
string GetTemplate()
Returns the chat template of the LLM.
void SetLoraWeight(string path, float weight)
Changes the weight (scale) of a LORA model in the LLM.
void CancelRequest(int id_slot)
Cancels the requests in a specific slot of the LLM.
int parallelPrompts
number of prompts that can happen in parallel (-1 = number of LLMCaller objects)
bool debug
log the output of the LLM in the Unity Editor.
async void Awake()
The Unity Awake function that starts the LLM server.
async Task< string > Detokenize(string json)
Detokenises the provided query.
void OnDestroy()
The Unity OnDestroy function called when the object is destroyed. The function StopProcess is called to stop the LLM process.
void SetLora(string path, float weight=1)
Sets a LORA model to use in the LLM. The model provided is copied to the Assets/StreamingAssets folder.
void AddLora(string path, float weight=1)
Adds a LORA model to use in the LLM. The model provided is copied to the Assets/StreamingAssets folder.
bool advancedOptions
show/hide advanced options in the GameObject
void RemoveLora(string path)
Removes a LORA model from the LLM. Supported models are in .gguf format.
static bool modelSetupFailed
Boolean set to true if the models were not downloaded successfully.
string lora
LORA models to use (.gguf format)
int contextSize
Size of the prompt context (0 = context size of the model). This is the number of tokens the model can take as input when generating responses.
int numThreads
number of threads to use (-1 = all)
bool started
Boolean set to true if the server has started and is ready to receive requests, false otherwise.
void SetModel(string path)
Sets the model used by the LLM. The model provided is copied to the Assets/StreamingAssets folder.
int port
port to use for the remote LLM server
bool remote
enable remote server functionality
void SetSSLCert(string path)
Use an SSL certificate for the LLM server.
void RemoveLoras()
Removes all LORA models from the LLM.
bool dontDestroyOnLoad
do not destroy the LLM GameObject when loading a new Scene.
void SetEmbeddings(int embeddingLength, bool embeddingsOnly)
Set LLM Embedding parameters.
string model
LLM model to use (.gguf format)
string APIKey
API key to use for the server.
async Task< string > Tokenize(string json)
Tokenises the provided query.
int batchSize
Batch size for prompt processing.
void SetSSLKey(string path)
Use an SSL key for the LLM server.
string chatTemplate
Chat template for the model.
bool flashAttention
enable use of flash attention
async Task< string > Completion(string json, Callback< string > streamCallback=null)
Uses the chat and completion functionality of the LLM.
static bool modelSetupComplete
Boolean set to true once the model setup has completed, false otherwise.
async Task WaitUntilReady()
Waits until the LLM is ready.
void SetTemplate(string templateName, bool setDirty=true)
Set the chat template for the LLM.
void Destroy()
Stops and destroys the LLM.
void Update()
The Unity Update function. It is used to retrieve the LLM replies.
string loraWeights
the weights of the LORA models being used.
int Register(LLMCaller llmCaller)
Registers a local LLMCaller object. This binds the LLMCaller "client" to a specific slot of the LLM.
bool failed
Boolean set to true if the server has failed to start.
async Task< string > Embeddings(string json)
Computes the embeddings of the provided query.
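
Putting the LLM members together, a typical pattern is to wait for the service and then issue a request; a rough sketch, where the JSON payload follows llama.cpp server conventions and is an assumption (LLMCaller normally builds it):

using UnityEngine;
using LLMUnity;

public class LLMUsageExample : MonoBehaviour
{
    public LLM llm;  // assigned in the inspector

    async void Start()
    {
        await llm.WaitUntilReady();  // resolves once started == true (error handling omitted)
        string result = await llm.Completion("{\"prompt\": \"Hello\", \"n_predict\": 32}",
                                             chunk => Debug.Log(chunk));  // streamed chunks
        Debug.Log("Full result: " + result);
    }
}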
LoraManager: Class representing the LORA manager, allowing conversion and retrieval of LORA assets as strings (for serialisation).
float[] GetWeights()
Gets the weights of the LORAs in the manager.
void Add(string path, float weight=1)
Adds a LORA with the defined weight.
void Remove(string path)
Removes a LORA based on its path.
void SetWeight(string path, float weight)
Modifies the weight of a LORA.
void FromStrings(string loraString, string loraWeightsString)
Converts strings with the LORA paths and weights to entries in the LORA manager.
string[] GetLoras()
Gets the paths of the LORAs in the manager.
void Clear()
Clears the LORA assets.
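
A sketch of round-tripping LORA assets through the manager, assuming LoraManager can be instantiated directly (paths are hypothetical):

LoraManager manager = new LoraManager();
manager.Add("adapter_a.gguf", 0.8f);     // register a LORA with its weight
manager.SetWeight("adapter_a.gguf", 0.5f);
string[] paths = manager.GetLoras();     // paths, e.g. for serialisation
float[] weights = manager.GetWeights();  // matching weights
manager.Clear();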
ModelEntry: Class implementing an LLM model entry.