LLM for Unity  v2.4.2
Create characters in Unity with LLMs!
Loading...
Searching...
No Matches
LLMManager.cs
Go to the documentation of this file.
1
3using System;
4using System.Collections.Generic;
5using System.IO;
6using System.Threading.Tasks;
7using UnityEditor;
8using UnityEngine;
9
10namespace LLMUnity
11{
/// <summary>
/// Class implementing a LLM model entry: the metadata the model manager keeps for one model or LoRA file.
/// </summary>
[Serializable]
public class ModelEntry
{
    // NOTE: field order is kept as-is because JsonUtility serialises fields in declaration order.

    /// <summary>Display label of the entry (defaults to the file name).</summary>
    public string label;
    /// <summary>File name, or the path relative to the asset folder when the file lives inside it.</summary>
    public string filename;
    /// <summary>Full path of the file (as returned by LLMUnitySetup.GetFullPath).</summary>
    public string path;
    /// <summary>Whether the entry is a LoRA adapter rather than a base model.</summary>
    public bool lora;
    /// <summary>Chat template read from the GGUF file; null for LoRA and embedding-only entries.</summary>
    public string chatTemplate;
    /// <summary>URL the file can be downloaded from, if any.</summary>
    public string url;
    /// <summary>Whether the GGUF architecture is one of the embedding-only architectures.</summary>
    public bool embeddingOnly;
    /// <summary>Embedding length; initialised to 0 and read from "&lt;arch&gt;.embedding_length" for non-LoRA entries.</summary>
    public int embeddingLength;
    /// <summary>Whether the file is bundled with the build.</summary>
    public bool includeInBuild;
    /// <summary>Context length; initialised to -1 and read from "&lt;arch&gt;.context_length" for non-LoRA entries.</summary>
    public int contextLength;

    /// <summary>
    /// Values of the GGUF "general.architecture" field that only support embeddings.
    /// Read-only set: it is never modified and is only used for membership tests.
    /// </summary>
    static readonly HashSet<string> embeddingOnlyArchs = new HashSet<string> {"bert", "nomic-bert", "jina-bert-v2", "t5", "t5encoder"};

    /// <summary>
    /// Returns the path relative to the asset folder when the file exists inside that folder,
    /// otherwise only the file name of the given path.
    /// </summary>
    /// <param name="path">relative or full path of the file</param>
    /// <returns>relative asset path or file name</returns>
    public static string GetFilenameOrRelativeAssetPath(string path)
    {
        string assetPath = LLMUnitySetup.GetAssetPath(path); // Note: this will return the full path if a full path is passed
        string basePath = LLMUnitySetup.GetAssetPath();
        if (File.Exists(assetPath) && LLMUnitySetup.IsSubPath(assetPath, basePath))
        {
            return LLMUnitySetup.RelativePath(assetPath, basePath);
        }
        return Path.GetFileName(path);
    }

    /// <summary>
    /// Constructs a LLM model entry. For non-LoRA entries the GGUF file at <paramref name="path"/> is read
    /// to fill in the context length, embedding length, embedding-only flag and chat template.
    /// </summary>
    /// <param name="path">path of the model file</param>
    /// <param name="lora">whether the file is a LoRA adapter</param>
    /// <param name="label">display label; defaults to the file name</param>
    /// <param name="url">download URL of the file, if any</param>
    public ModelEntry(string path, bool lora = false, string label = null, string url = null)
    {
        filename = GetFilenameOrRelativeAssetPath(path);
        this.label = label ?? Path.GetFileName(filename);
        this.lora = lora;
        this.path = LLMUnitySetup.GetFullPath(path);
        this.url = url;
        includeInBuild = true;
        chatTemplate = null;
        contextLength = -1;
        embeddingOnly = false;
        embeddingLength = 0;
        if (!lora)
        {
            GGUFReader reader = new GGUFReader(this.path);
            string arch = reader.GetStringField("general.architecture");
            if (arch != null)
            {
                // Architecture-specific keys are prefixed with the architecture name.
                contextLength = reader.GetIntField($"{arch}.context_length");
                embeddingLength = reader.GetIntField($"{arch}.embedding_length");
            }
            // HashSet.Contains(null) is false, so an unknown architecture is not embedding-only.
            embeddingOnly = embeddingOnlyArchs.Contains(arch);
            chatTemplate = embeddingOnly ? default : ChatTemplate.FromGGUF(reader, this.path);
        }
    }

    /// <summary>
    /// Returns only the required fields for bundling the model in the build:
    /// a shallow copy with the label cleared and the path replaced by the (relative) filename.
    /// </summary>
    /// <returns>stripped-down copy of this entry</returns>
    public ModelEntry OnlyRequiredFields()
    {
        ModelEntry entry = (ModelEntry)MemberwiseClone();
        entry.label = null;
        entry.path = entry.filename;
        return entry;
    }
}
93
/// <summary>
/// Serialisable snapshot of the model manager. It is written to the editor PlayerPrefs by Save
/// and to the build file by SaveToDisk.
/// </summary>
[Serializable]
public class LLMManagerStore
{
    /// <summary>Whether model files are downloaded when the application starts.</summary>
    public bool downloadOnStart;
    /// <summary>Model and LoRA entries of the manager.</summary>
    public List<ModelEntry> modelEntries;
    /// <summary>Integer value of LLMUnitySetup.DebugMode saved for the build.</summary>
    public int debugMode;
    /// <summary>Value of LLMUnitySetup.FullLlamaLib saved for the build.</summary>
    public bool fullLlamaLib;
}
104
105 [DefaultExecutionOrder(-2)]
110 public class LLMManager
111 {
/// <summary>Whether model files with a URL are downloaded at startup instead of being bundled.</summary>
public static bool downloadOnStart = false;
/// <summary>All model and LoRA entries known to the manager.</summary>
public static List<ModelEntry> modelEntries = new List<ModelEntry>();
/// <summary>LLMs added through Register; they are updated on template changes and removals.</summary>
private static List<LLM> llms = new List<LLM>();

/// <summary>Overall progress of the startup downloads (starts at 1, reset to 0 when downloads begin).</summary>
public static float downloadProgress = 1;
/// <summary>Callbacks invoked with downloadProgress whenever it is updated.</summary>
public static List<Callback<float>> downloadProgressCallbacks = new List<Callback<float>>();
/// <summary>Task of the one-time setup, created lazily by Setup.</summary>
private static Task<bool> SetupTask;
/// <summary>Guards the lazy creation of SetupTask.</summary>
private static readonly object lockObject = new object();
// Byte counts used to weight the download progress across the files being downloaded.
private static long totalSize;
private static long currFileSize;
private static long completedSize;
123
/// <summary>
/// Sets the overall model download progress from the progress of the file currently being downloaded
/// and passes it to all registered callbacks.
/// </summary>
/// <param name="progress">progress of the current file, as a fraction of currFileSize (presumably in [0, 1] — confirm with DownloadFile)</param>
public static void SetDownloadProgress(float progress)
{
    if (totalSize > 0)
    {
        // Weight the current file's progress by its size relative to all pending downloads.
        downloadProgress = (completedSize + progress * currFileSize) / totalSize;
    }
    else
    {
        // No usable total size (e.g. every reported file size was 0): the weighted formula would
        // divide by zero and report NaN/Infinity, so fall back to the current file's own progress.
        downloadProgress = progress;
    }
    foreach (Callback<float> downloadProgressCallback in downloadProgressCallbacks) downloadProgressCallback?.Invoke(downloadProgress);
}
133
/// <summary>
/// Setup of the models. The setup task is started on the first call; every call returns that same task.
/// </summary>
/// <returns>task resolving to whether the setup succeeded</returns>
public static Task<bool> Setup()
{
    Task<bool> task;
    lock (lockObject)
    {
        if (SetupTask == null)
        {
            SetupTask = SetupOnce();
        }
        task = SetupTask;
    }
    return task;
}
146
/// <summary>
/// Task performing the setup of the models: loads the manager from the build file and makes sure the
/// file of every model entry is available, downloading it when downloadOnStart is set and the entry has a URL.
/// </summary>
/// <returns>false if an exception is thrown while downloading, true otherwise (a missing bundled model is only logged)</returns>
151 public static async Task<bool> SetupOnce()
152 {
// Presumably extracts the manager file from the Android package when needed (confirm), then loads the entries from it.
153 await LLMUnitySetup.AndroidExtractAsset(LLMUnitySetup.LLMManagerPath, true);
154 LoadFromDisk();
155
// Collect a download for every entry whose file is not already in the asset folder.
156 List<StringPair> downloads = new List<StringPair>();
157 foreach (ModelEntry modelEntry in modelEntries)
158 {
159 string target = LLMUnitySetup.GetAssetPath(modelEntry.filename);
160 if (File.Exists(target)) continue;
161
// Entries that are not downloaded on start are expected to be bundled with the build:
// try extracting the file and log an error if it is still missing.
162 if (!downloadOnStart || string.IsNullOrEmpty(modelEntry.url))
163 {
164 await LLMUnitySetup.AndroidExtractFile(modelEntry.filename);
165 if (!File.Exists(target)) LLMUnitySetup.LogError($"Model {modelEntry.filename} could not be found!");
166 }
167 else
168 {
// Otherwise schedule a download of the URL to the download asset path.
169 target = LLMUnitySetup.GetDownloadAssetPath(modelEntry.filename);
170 downloads.Add(new StringPair {source = modelEntry.url, target = target});
171 }
172 }
// Nothing to download: the setup is complete.
173 if (downloads.Count == 0) return true;
174
175 try
176 {
177 downloadProgress = 0;
178 totalSize = 0;
179 completedSize = 0;
180
// NOTE(review): original source line 181 is missing from this listing; it must declare `client`
// used below — presumably an instance of the resumable web client class. Confirm against the real file.
// Query all file sizes up front so SetDownloadProgress can weight progress by file size.
182 Dictionary<string, long> fileSizes = new Dictionary<string, long>();
183 foreach (StringPair pair in downloads)
184 {
185 long size = client.GetURLFileSize(pair.source);
186 fileSizes[pair.source] = size;
187 totalSize += size;
188 }
189
// Download the files one after the other; completedSize accumulates the finished bytes.
190 foreach (StringPair pair in downloads)
191 {
192 currFileSize = fileSizes[pair.source];
193 await LLMUnitySetup.DownloadFile(pair.source, pair.target, false, null, SetDownloadProgress);
194 await LLMUnitySetup.AndroidExtractFile(Path.GetFileName(pair.target));
195 completedSize += currFileSize;
196 }
197
198 completedSize = totalSize;
// NOTE(review): original source line 199 is also missing from this listing — confirm its content.
200 }
201 catch (Exception ex)
202 {
// Any failure aborts the remaining downloads and reports the setup as failed.
203 LLMUnitySetup.LogError($"Error downloading the models: {ex.Message}");
204 return false;
205 }
206 return true;
207 }
208
/// <summary>
/// Sets the chat template of the model matching the given filename and distributes it to all LLMs using it.
/// Does nothing when no entry matches.
/// </summary>
/// <param name="filename">model filename or path</param>
/// <param name="chatTemplate">chat template to set</param>
public static void SetTemplate(string filename, string chatTemplate)
{
    ModelEntry entry = Get(filename);
    SetTemplate(entry, chatTemplate);
}
218
/// <summary>
/// Sets the chat template of a model entry and passes it to every registered LLM that uses the model.
/// In the editor the manager is saved afterwards. Does nothing for a null entry.
/// </summary>
/// <param name="entry">model entry</param>
/// <param name="chatTemplate">chat template to set</param>
public static void SetTemplate(ModelEntry entry, string chatTemplate)
{
    if (entry != null)
    {
        entry.chatTemplate = chatTemplate;
        foreach (LLM registered in llms)
        {
            // Skip missing LLMs and LLMs that run a different model.
            if (registered == null || registered.model != entry.filename) continue;
            registered.SetTemplate(chatTemplate);
        }
#if UNITY_EDITOR
        Save();
#endif
    }
}
236
/// <summary>
/// Gets the model entry for a model path or filename.
/// An entry matches when its filename equals the argument as given, when its filename equals the file name
/// part of the argument, or when its path equals the full path of the argument. The first match is returned.
/// </summary>
/// <param name="path">model filename, relative asset path or full path</param>
/// <returns>the matching entry, or null if there is none</returns>
public static ModelEntry Get(string path)
{
    string filename = Path.GetFileName(path);
    string fullPath = LLMUnitySetup.GetFullPath(path);
    foreach (ModelEntry entry in modelEntries)
    {
        // entry.filename can be a path relative to the asset folder (see GetFilenameOrRelativeAssetPath),
        // which Path.GetFileName(path) would never equal; compare it with the argument as given too.
        if (entry.filename == path || entry.filename == filename || entry.path == fullPath) return entry;
    }
    return null;
}
252
/// <summary>
/// Gets the asset path of a model: the entry's full path in the editor, or the asset path of its filename in a build.
/// </summary>
/// <param name="filename">model filename or path</param>
/// <returns>asset path, or an empty string if no entry matches</returns>
public static string GetAssetPath(string filename)
{
    ModelEntry entry = Get(filename);
    if (entry != null)
    {
#if UNITY_EDITOR
        return entry.path;
#else
        return LLMUnitySetup.GetAssetPath(entry.filename);
#endif
    }
    return "";
}
268
/// <summary>
/// Counts the entries that are LoRA adapters (lora = true) or base models (lora = false).
/// </summary>
/// <param name="lora">whether to count LoRA entries instead of models</param>
/// <returns>number of matching entries</returns>
public static int Num(bool lora)
{
    int count = 0;
    for (int i = 0; i < modelEntries.Count; i++)
    {
        if (modelEntries[i].lora == lora) count++;
    }
    return count;
}
283
/// <summary>Returns the number of LLM (non-LoRA) model entries.</summary>
/// <returns>number of model entries</returns>
public static int NumModels() => Num(lora: false);
292
/// <summary>Returns the number of LoRA entries.</summary>
/// <returns>number of LoRA entries</returns>
public static int NumLoras() => Num(lora: true);
301
/// <summary>Registers a LLM so it is updated on chat template changes and model removals.</summary>
/// <param name="llm">LLM to register</param>
public static void Register(LLM llm) => llms.Add(llm);
310
/// <summary>Removes a LLM from the model manager (one registration of it).</summary>
/// <param name="llm">LLM to unregister</param>
public static void Unregister(LLM llm) => llms.Remove(llm);
319
/// <summary>
/// Loads the model manager from the build file at LLMUnitySetup.LLMManagerPath, including the debug mode
/// and llama library settings. Does nothing when the file does not exist.
/// </summary>
public static void LoadFromDisk()
{
    if (File.Exists(LLMUnitySetup.LLMManagerPath))
    {
        string json = File.ReadAllText(LLMUnitySetup.LLMManagerPath);
        LLMManagerStore stored = JsonUtility.FromJson<LLMManagerStore>(json);
        downloadOnStart = stored.downloadOnStart;
        modelEntries = stored.modelEntries;
        LLMUnitySetup.DebugMode = (LLMUnitySetup.DebugModeType)stored.debugMode;
        LLMUnitySetup.FullLlamaLib = stored.fullLlamaLib;
    }
}
332
333#if UNITY_EDITOR
/// <summary>PlayerPrefs key under which the editor state of the manager is stored.</summary>
private static string LLMManagerPref = "LLMManager";

/// <summary>Progress of the model download started from the editor (starts at 1).</summary>
[HideInInspector] public static float modelProgress = 1;
/// <summary>Progress of the LoRA download started from the editor (starts at 1).</summary>
[HideInInspector] public static float loraProgress = 1;
338
/// <summary>Restores the editor state of the manager; invoked by Unity through [InitializeOnLoadMethod].</summary>
[InitializeOnLoadMethod]
private static void InitializeOnLoad() => Load();
344
/// <summary>
/// Adds a model entry to the model manager and saves it.
/// LoRA entries are appended. Model entries go first when the list starts with a LoRA entry, and otherwise
/// right after the last model entry.
/// </summary>
/// <param name="entry">model entry to add</param>
/// <returns>filename of the added entry</returns>
public static string AddEntry(ModelEntry entry)
{
    int insertAt = modelEntries.Count;
    if (!entry.lora)
    {
        bool startsWithLora = modelEntries.Count > 0 && modelEntries[0].lora;
        // FindLastIndex returns -1 for an empty list, giving index 0 (== Count).
        insertAt = startsWithLora ? 0 : modelEntries.FindLastIndex(existing => !existing.lora) + 1;
    }
    modelEntries.Insert(insertAt, entry);
    Save();
    return entry.filename;
}
372
/// <summary>Creates a model entry for a file and adds it to the model manager.</summary>
/// <param name="path">path of the model file</param>
/// <param name="lora">whether the file is a LoRA adapter</param>
/// <param name="label">display label of the entry</param>
/// <param name="url">download URL of the file</param>
/// <returns>filename of the added entry</returns>
public static string AddEntry(string path, bool lora = false, string label = null, string url = null)
{
    ModelEntry created = new ModelEntry(path, lora, label, url);
    return AddEntry(created);
}
385
/// <summary>
/// Downloads a model from a URL into the model download path and adds a model entry for it.
/// If an entry with the same URL or the same file name already exists, its filename is returned without downloading.
/// </summary>
/// <param name="url">URL of the model file</param>
/// <param name="lora">whether the file is a LoRA adapter</param>
/// <param name="log">whether to log when an existing entry is found</param>
/// <param name="label">display label of the entry</param>
/// <returns>filename of the entry, or null if the download failed</returns>
public static async Task<string> Download(string url, bool lora = false, bool log = false, string label = null)
{
    foreach (ModelEntry entry in modelEntries)
    {
        if (entry.url == url)
        {
            if (log) LLMUnitySetup.Log($"Found existing entry for {url}");
            return entry.filename;
        }
    }

    // Strip the query string before taking the file name: a query may itself contain '/'
    // (e.g. "...model.gguf?path=a/b"), which would make Path.GetFileName return a piece of the query.
    string modelName = Path.GetFileName(url.Split('?')[0]);
    ModelEntry entryPath = Get(modelName);
    if (entryPath != null)
    {
        if (log) LLMUnitySetup.Log($"Found existing entry for {modelName}");
        return entryPath.filename;
    }

    string modelPath = Path.Combine(LLMUnitySetup.modelDownloadPath, modelName);
    // Remember the editor progress values so they can be restored if the download fails.
    float preModelProgress = modelProgress;
    float preLoraProgress = loraProgress;
    try
    {
        if (!lora)
        {
            modelProgress = 0;
            await LLMUnitySetup.DownloadFile(url, modelPath, false, null, SetModelProgress);
        }
        else
        {
            loraProgress = 0;
            await LLMUnitySetup.DownloadFile(url, modelPath, false, null, SetLoraProgress);
        }
    }
    catch (Exception ex)
    {
        modelProgress = preModelProgress;
        loraProgress = preLoraProgress;
        LLMUnitySetup.LogError($"Error downloading the model from URL '{url}': " + ex.Message);
        return null;
    }
    return AddEntry(modelPath, lora, label, url);
}
438
/// <summary>
/// Loads a model from disk and adds a model entry for it, unless an entry for the path already exists.
/// </summary>
/// <param name="path">path of the model file</param>
/// <param name="lora">whether the file is a LoRA adapter</param>
/// <param name="log">whether to log when an existing entry is found</param>
/// <param name="label">display label of the entry</param>
/// <returns>filename of the existing or new entry</returns>
public static string Load(string path, bool lora = false, bool log = false, string label = null)
{
    ModelEntry existing = Get(path);
    if (existing == null) return AddEntry(path, lora, label);
    if (log) LLMUnitySetup.Log($"Found existing entry for {existing.filename}");
    return existing.filename;
}
457
/// <summary>Downloads an LLM model from a URL and adds a model entry to the model manager.</summary>
/// <param name="url">URL of the model file</param>
/// <param name="log">whether to log when an existing entry is found</param>
/// <param name="label">display label of the entry</param>
/// <returns>filename of the entry, or null if the download failed</returns>
public static async Task<string> DownloadModel(string url, bool log = false, string label = null)
    => await Download(url, lora: false, log: log, label: label);
469
/// <summary>Downloads a LoRA model from a URL and adds a model entry to the model manager.</summary>
/// <param name="url">URL of the LoRA file</param>
/// <param name="log">whether to log when an existing entry is found</param>
/// <param name="label">display label of the entry</param>
/// <returns>filename of the entry, or null if the download failed</returns>
public static async Task<string> DownloadLora(string url, bool log = false, string label = null)
    => await Download(url, lora: true, log: log, label: label);
481
/// <summary>Loads an LLM model from disk and adds a model entry to the model manager.</summary>
/// <param name="path">path of the model file</param>
/// <param name="log">whether to log when an existing entry is found</param>
/// <param name="label">display label of the entry</param>
/// <returns>filename of the existing or new entry</returns>
public static string LoadModel(string path, bool log = false, string label = null)
    => Load(path, lora: false, log: log, label: label);
493
/// <summary>Loads a LoRA model from disk and adds a model entry to the model manager.</summary>
/// <param name="path">path of the LoRA file</param>
/// <param name="log">whether to log when an existing entry is found</param>
/// <param name="label">display label of the entry</param>
/// <returns>filename of the existing or new entry</returns>
public static string LoadLora(string path, bool log = false, string label = null)
    => Load(path, lora: true, log: log, label: label);
505
/// <summary>Sets the download URL of the model matching the given filename (no-op when none matches).</summary>
/// <param name="filename">model filename or path</param>
/// <param name="url">download URL</param>
public static void SetURL(string filename, string url)
{
    ModelEntry entry = Get(filename);
    SetURL(entry, url);
}
515
/// <summary>Sets the download URL of a model entry and saves the manager (no-op for a null entry).</summary>
/// <param name="entry">model entry</param>
/// <param name="url">download URL</param>
public static void SetURL(ModelEntry entry, string url)
{
    if (entry != null)
    {
        entry.url = url;
        Save();
    }
}
527
/// <summary>Sets whether the model matching the given filename is bundled with the build.</summary>
/// <param name="filename">model filename or path</param>
/// <param name="includeInBuild">whether to include the model in the build</param>
public static void SetIncludeInBuild(string filename, bool includeInBuild)
{
    ModelEntry entry = Get(filename);
    SetIncludeInBuild(entry, includeInBuild);
}
537
/// <summary>Sets whether a model entry is bundled with the build and saves the manager (no-op for a null entry).</summary>
/// <param name="entry">model entry</param>
/// <param name="includeInBuild">whether to include the model in the build</param>
public static void SetIncludeInBuild(ModelEntry entry, bool includeInBuild)
{
    if (entry != null)
    {
        entry.includeInBuild = includeInBuild;
        Save();
    }
}
549
/// <summary>
/// Sets whether model files are downloaded on start and saves the manager.
/// When enabling it, warns if some model included in the build has no URL and will therefore be copied
/// into the build instead.
/// </summary>
/// <param name="value">whether to download files on start</param>
public static void SetDownloadOnStart(bool value)
{
    downloadOnStart = value;
    if (downloadOnStart)
    {
        // Build only copies entries with includeInBuild set; entries excluded from the build are never
        // copied, so they must not trigger the "will be copied in the build" warning.
        bool warn = modelEntries.Exists(entry => entry.includeInBuild && string.IsNullOrEmpty(entry.url));
        if (warn) LLMUnitySetup.LogWarning("Some models do not have a URL and will be copied in the build. To resolve this fill in the URL field in the expanded view of the LLM Model list.");
    }
    Save();
}
568
/// <summary>Removes the model matching the given filename from the model manager (no-op when none matches).</summary>
/// <param name="filename">model filename or path</param>
public static void Remove(string filename)
{
    ModelEntry entry = Get(filename);
    Remove(entry);
}
577
/// <summary>
/// Removes a model entry from the model manager, saves it, and detaches the model from registered LLMs:
/// a removed LoRA is removed from every LLM, and LLMs using a removed model get an empty model.
/// Does nothing for a null entry.
/// </summary>
/// <param name="entry">model entry to remove</param>
public static void Remove(ModelEntry entry)
{
    if (entry == null) return;
    modelEntries.Remove(entry);
    Save();
    foreach (LLM llm in llms)
    {
        // Skip missing LLMs, consistent with the null check in SetTemplate.
        if (llm == null) continue;
        if (entry.lora) llm.RemoveLora(entry.filename);
        else if (llm.model == entry.filename) llm.model = "";
    }
}
593
/// <summary>Sets the progress of the model download started from the editor.</summary>
/// <param name="progress">download progress</param>
public static void SetModelProgress(float progress) => modelProgress = progress;
602
/// <summary>Sets the progress of the LoRA download started from the editor.</summary>
/// <param name="progress">download progress</param>
public static void SetLoraProgress(float progress) => loraProgress = progress;
611
/// <summary>
/// Serialises the model entries and the download-on-start flag to JSON and stores them in PlayerPrefs.
/// </summary>
public static void Save()
{
    LLMManagerStore store = new LLMManagerStore
    {
        modelEntries = modelEntries,
        downloadOnStart = downloadOnStart,
    };
    string json = JsonUtility.ToJson(store, true);
    PlayerPrefs.SetString(LLMManagerPref, json);
    PlayerPrefs.Save();
}
625
/// <summary>
/// Restores the model entries and the download-on-start flag from PlayerPrefs; does nothing if nothing is stored.
/// </summary>
public static void Load()
{
    string pref = PlayerPrefs.GetString(LLMManagerPref);
    if (!string.IsNullOrEmpty(pref))
    {
        LLMManagerStore stored = JsonUtility.FromJson<LLMManagerStore>(pref);
        downloadOnStart = stored.downloadOnStart;
        modelEntries = stored.modelEntries;
    }
}
637
/// <summary>
/// Writes the build file at LLMUnitySetup.LLMManagerPath. It contains the entries included in the build
/// (stripped to their required fields), the download-on-start flag, the debug mode and the llama library setting.
/// </summary>
public static void SaveToDisk()
{
    List<ModelEntry> buildEntries = modelEntries
        .FindAll(entry => entry.includeInBuild)
        .ConvertAll(entry => entry.OnlyRequiredFields());
    LLMManagerStore store = new LLMManagerStore
    {
        modelEntries = buildEntries,
        downloadOnStart = downloadOnStart,
        debugMode = (int)LLMUnitySetup.DebugMode,
        fullLlamaLib = LLMUnitySetup.FullLlamaLib
    };
    string json = JsonUtility.ToJson(store, true);
    File.WriteAllText(LLMUnitySetup.LLMManagerPath, json);
}
658
/// <summary>
/// Saves the model manager to disk for the build. It also copies (through copyCallback) every entry
/// included in the build whose file is not yet in the asset folder and will not be downloaded on start.
/// </summary>
/// <param name="copyCallback">callback called with the source path and target asset path of each file to copy</param>
public static void Build(ActionCallback copyCallback)
{
    SaveToDisk();

    foreach (ModelEntry candidate in modelEntries)
    {
        string target = LLMUnitySetup.GetAssetPath(candidate.filename);
        bool alreadyHandled = !candidate.includeInBuild || File.Exists(target);
        bool bundled = !downloadOnStart || string.IsNullOrEmpty(candidate.url);
        if (!alreadyHandled && bundled)
        {
            copyCallback(candidate.path, target);
        }
    }
}
674
675#endif
676 }
677}
Class implementing the skeleton of a chat template.
static string FromGGUF(string path)
Determines the chat template name from a GGUF file. It reads the GGUF file and then determines the ch...
Class implementing the GGUF reader.
Definition LLMGGUF.cs:55
int GetIntField(string key)
Retrieves an integer GGUF field.
Definition LLMGGUF.cs:157
string GetStringField(string key)
Retrieves a string GGUF field.
Definition LLMGGUF.cs:145
Class implementing the LLM model manager.
static void SetTemplate(ModelEntry entry, string chatTemplate)
Sets the chat template for a model and distributes it to all LLMs using it.
static void SaveToDisk()
Saves the model manager to disk for the build.
static void SetIncludeInBuild(ModelEntry entry, bool includeInBuild)
Sets whether to include a model to the build.
static void SetLoraProgress(float progress)
Sets the LORA download progress.
static void Unregister(LLM llm)
Removes a LLM from the model manager.
static async Task< string > Download(string url, bool lora=false, bool log=false, string label=null)
Downloads a model and adds a model entry to the model manager.
static string GetAssetPath(string filename)
Gets the asset path based on whether the application runs locally in the editor or in a build.
static int NumModels()
Returns the number of LLM models.
static ModelEntry Get(string path)
Gets the model entry for a model path.
static Task< bool > Setup()
Setup of the models.
static int Num(bool lora)
Returns the number of LLM/LORA models.
static void Register(LLM llm)
Registers a LLM to the model manager.
static void SetURL(string filename, string url)
Sets the URL for a model.
static void SetModelProgress(float progress)
Sets the LLM download progress.
static void Build(ActionCallback copyCallback)
Saves the model manager to disk, along with the models that are not (or cannot be) downloaded, for the build...
static void Save()
Serialises and saves the model manager.
static string AddEntry(string path, bool lora=false, string label=null, string url=null)
Creates and adds a model entry to the model manager.
static void Load()
Deserialises and loads the model manager.
static void SetDownloadOnStart(bool value)
Sets whether to download files on start.
static void SetIncludeInBuild(string filename, bool includeInBuild)
Sets whether to include a model to the build.
static int NumLoras()
Returns the number of LORA models.
static async Task< string > DownloadModel(string url, bool log=false, string label=null)
Downloads an LLM model from a URL and adds a model entry to the model manager.
static string LoadModel(string path, bool log=false, string label=null)
Loads a LLM model from disk and adds a model entry to the model manager.
static void SetTemplate(string filename, string chatTemplate)
Sets the chat template for a model and distributes it to all LLMs using it.
static void Remove(string filename)
Removes a model from the model manager.
static void LoadFromDisk()
Loads the model manager from a file.
static void SetDownloadProgress(float progress)
Sets the model download progress in all registered callbacks.
static string Load(string path, bool lora=false, bool log=false, string label=null)
Loads a model from disk and adds a model entry to the model manager.
static string AddEntry(ModelEntry entry)
Adds a model entry to the model manager.
static async Task< bool > SetupOnce()
Task performing the setup of the models.
static async Task< string > DownloadLora(string url, bool log=false, string label=null)
Downloads a LoRA model from a URL and adds a model entry to the model manager.
static void Remove(ModelEntry entry)
Removes a model from the model manager.
static string LoadLora(string path, bool log=false, string label=null)
Loads a LORA model from disk and adds a model entry to the model manager.
static void SetURL(ModelEntry entry, string url)
Sets the URL for a model.
Class implementing helper functions for setup and process management.
static string LLMManagerPath
Path of file with build information for runtime.
static string modelDownloadPath
Model download path.
Class implementing the LLM server.
Definition LLM.cs:19
void RemoveLora(string path)
Removes a LoRA model from the LLM. Supported models are in .gguf format.
Definition LLM.cs:284
string model
LLM model to use (.gguf format)
Definition LLM.cs:62
void SetTemplate(string templateName, bool setDirty=true)
Set the chat template for the LLM.
Definition LLM.cs:337
Class implementing a LLM model entry.
Definition LLMManager.cs:18
ModelEntry(string path, bool lora=false, string label=null, string url=null)
Constructs a LLM model entry.
Definition LLMManager.cs:55
static string GetFilenameOrRelativeAssetPath(string path)
Returns the relative asset path if it is in the AssetPath folder (StreamingAssets or persistentPath),...
Definition LLMManager.cs:37
ModelEntry OnlyRequiredFields()
Returns only the required fields for bundling the model in the build.
Definition LLMManager.cs:85
Class implementing a resumable Web client.