LLM for Unity  v2.4.1
Create characters in Unity with LLMs!
Loading...
Searching...
No Matches
LLMManager.cs
Go to the documentation of this file.
1
3using System;
4using System.Collections.Generic;
5using System.IO;
6using System.Threading.Tasks;
7using UnityEditor;
8using UnityEngine;
9
10namespace LLMUnity
11{
12 [Serializable]
17 public class ModelEntry
18 {
/// <summary>Display label of the entry; the constructor defaults it to the file name.</summary>
public string label;
/// <summary>File name of the model, or its path relative to the asset folder when it lives there.</summary>
public string filename;
/// <summary>Full path of the model file.</summary>
public string path;
/// <summary>Whether the entry is a LORA adapter rather than a base model.</summary>
public bool lora;
/// <summary>Chat template read from the GGUF file (left null for LORA and embedding-only models).</summary>
public string chatTemplate;
/// <summary>URL the model can be downloaded from (may be null).</summary>
public string url;
/// <summary>Whether the GGUF architecture is one of the embedding-only architectures.</summary>
public bool embeddingOnly;
/// <summary>Embedding length read from the GGUF metadata (0 when the architecture could not be read).</summary>
public int embeddingLength;
/// <summary>Whether the model is bundled with the build.</summary>
public bool includeInBuild;
/// <summary>Context length read from the GGUF metadata (-1 when the architecture could not be read).</summary>
public int contextLength;

// GGUF "general.architecture" values whose models are treated as embedding-only
static readonly List<string> embeddingOnlyArchs = new List<string> {"bert", "nomic-bert", "jina-bert-v2", "t5", "t5encoder"};
31
/// <summary>
/// Returns the path relative to the asset folder when <paramref name="path"/> resolves to an existing file
/// inside that folder; otherwise returns only the file name of <paramref name="path"/>.
/// </summary>
/// <param name="path">model path (file name, asset path or full path)</param>
/// <returns>relative asset path or file name</returns>
public static string GetFilenameOrRelativeAssetPath(string path)
{
    // GetAssetPath returns the path unchanged when a full path is passed
    string resolved = LLMUnitySetup.GetAssetPath(path);
    string assetRoot = LLMUnitySetup.GetAssetPath();
    bool insideAssetFolder = File.Exists(resolved) && LLMUnitySetup.IsSubPath(resolved, assetRoot);
    return insideAssetFolder ? LLMUnitySetup.RelativePath(resolved, assetRoot) : Path.GetFileName(path);
}
47
/// <summary>
/// Constructs a model entry. For non-LORA models the GGUF metadata is read to fill in the context length,
/// the embedding length, whether the model is embedding-only and its chat template.
/// </summary>
/// <param name="path">model path</param>
/// <param name="lora">whether the model is a LORA adapter</param>
/// <param name="label">display label (defaults to the file name)</param>
/// <param name="url">URL the model can be downloaded from</param>
public ModelEntry(string path, bool lora = false, string label = null, string url = null)
{
    filename = GetFilenameOrRelativeAssetPath(path);
    this.label = label ?? Path.GetFileName(filename);
    this.lora = lora;
    this.path = LLMUnitySetup.GetFullPath(path);
    this.url = url;
    includeInBuild = true;
    chatTemplate = null;
    contextLength = -1;
    embeddingOnly = false;
    embeddingLength = 0;

    // the GGUF metadata is only read for non-LORA models
    if (lora) return;

    GGUFReader gguf = new GGUFReader(this.path);
    string architecture = gguf.GetStringField("general.architecture");
    if (architecture != null)
    {
        contextLength = gguf.GetIntField($"{architecture}.context_length");
        embeddingLength = gguf.GetIntField($"{architecture}.embedding_length");
    }
    embeddingOnly = embeddingOnlyArchs.Contains(architecture);
    chatTemplate = embeddingOnly ? null : ChatTemplate.FromGGUF(gguf, this.path);
}
80
/// <summary>
/// Returns a copy of the entry with only the fields required for bundling the model in the build:
/// the label is cleared and the path is replaced by the file name (or relative asset path).
/// </summary>
/// <returns>stripped-down copy of the entry</returns>
public ModelEntry OnlyRequiredFields()
{
    // MemberwiseClone makes a shallow copy; every field is a string or value type, so the
    // assignments below do not affect the original entry
    ModelEntry entry = (ModelEntry)MemberwiseClone();
    entry.label = null;
    entry.path = entry.filename;
    return entry;
}
92 }
93
/// <summary>
/// Serialisable snapshot of the model manager state that is converted to/from JSON by the manager.
/// The fields keep their original order so that the generated JSON is unchanged.
/// </summary>
[Serializable]
public class LLMManagerStore
{
    /// <summary>Whether models that have a URL are downloaded on start.</summary>
    public bool downloadOnStart;

    /// <summary>Registered model and LORA entries.</summary>
    public List<ModelEntry> modelEntries;
}
102
103 [DefaultExecutionOrder(-2)]
108 public class LLMManager
109 {
/// <summary>Whether models that have a URL are downloaded at start instead of being bundled with the build.</summary>
public static bool downloadOnStart = false;
/// <summary>Registered model and LORA entries.</summary>
public static List<ModelEntry> modelEntries = new List<ModelEntry>();
// LLMs registered with the manager; they receive template changes and model removals
private static List<LLM> llms = new List<LLM>();

/// <summary>Overall progress of the model downloads (initialised to 1).</summary>
public static float downloadProgress = 1;
/// <summary>Callbacks that receive the overall download progress.</summary>
public static List<Callback<float>> downloadProgressCallbacks = new List<Callback<float>>();
// setup task, created once under lockObject
private static Task<bool> SetupTask;
private static readonly object lockObject = new object();
// byte counters used to compute the overall download progress
private static long totalSize;
private static long currFileSize;
private static long completedSize;
121
/// <summary>
/// Sets the overall model download progress and passes it to all registered callbacks.
/// </summary>
/// <param name="progress">progress of the file currently being downloaded (assumed to be in [0, 1])</param>
public static void SetDownloadProgress(float progress)
{
    // Weight the current file's progress by its size relative to the total size of all downloads.
    // If no size is known (totalSize is 0) the division would produce NaN/Infinity, so the
    // progress of the current file is reported instead.
    downloadProgress = totalSize > 0 ? (completedSize + progress * currFileSize) / totalSize : progress;
    foreach (Callback<float> downloadProgressCallback in downloadProgressCallbacks) downloadProgressCallback?.Invoke(downloadProgress);
}
131
/// <summary>
/// Starts the setup of the models on the first call and returns that same task on every call.
/// </summary>
/// <returns>task resolving to whether the setup succeeded</returns>
public static Task<bool> Setup()
{
    Task<bool> task;
    lock (lockObject)
    {
        if (SetupTask == null) SetupTask = SetupOnce();
        task = SetupTask;
    }
    return task;
}
144
/// <summary>
/// Performs the setup of the models:
/// extracts the manager file (Android) and loads the entries from disk;
/// checks that models without a download URL are present, extracting them if needed;
/// downloads the models that have a URL when downloadOnStart is enabled, reporting the
/// overall progress through SetDownloadProgress.
/// </summary>
/// <returns>false if downloading the models failed, true otherwise</returns>
public static async Task<bool> SetupOnce()
{
    await LLMUnitySetup.AndroidExtractAsset(LLMUnitySetup.LLMManagerPath, true);
    LoadFromDisk();

    List<StringPair> downloads = new List<StringPair>();
    foreach (ModelEntry modelEntry in modelEntries)
    {
        string target = LLMUnitySetup.GetAssetPath(modelEntry.filename);
        if (File.Exists(target)) continue;

        if (!downloadOnStart || string.IsNullOrEmpty(modelEntry.url))
        {
            // the model is not downloaded, so it has to be available locally
            await LLMUnitySetup.AndroidExtractFile(modelEntry.filename);
            if (!File.Exists(target)) LLMUnitySetup.LogError($"Model {modelEntry.filename} could not be found!");
        }
        else
        {
            target = LLMUnitySetup.GetDownloadAssetPath(modelEntry.filename);
            downloads.Add(new StringPair {source = modelEntry.url, target = target});
        }
    }
    if (downloads.Count == 0) return true;

    try
    {
        downloadProgress = 0;
        totalSize = 0;
        completedSize = 0;

        // query all file sizes first so that progress is reported over the total download size
        ResumingWebClient client = new ResumingWebClient();
        Dictionary<string, long> fileSizes = new Dictionary<string, long>();
        foreach (StringPair pair in downloads)
        {
            long size = client.GetURLFileSize(pair.source);
            fileSizes[pair.source] = size;
            totalSize += size;
        }

        foreach (StringPair pair in downloads)
        {
            currFileSize = fileSizes[pair.source];
            await LLMUnitySetup.DownloadFile(pair.source, pair.target, false, null, SetDownloadProgress);
            await LLMUnitySetup.AndroidExtractFile(Path.GetFileName(pair.target));
            completedSize += currFileSize;
        }

        // publish the final progress to the callbacks: (totalSize + 1 * 0) / totalSize = 1
        completedSize = totalSize;
        currFileSize = 0;
        SetDownloadProgress(1);
    }
    catch (Exception ex)
    {
        LLMUnitySetup.LogError($"Error downloading the models: {ex.Message}");
        return false;
    }
    return true;
}
206
/// <summary>
/// Sets the chat template for the model matching <paramref name="filename"/> and distributes it to all LLMs using it.
/// </summary>
/// <param name="filename">model file name or path</param>
/// <param name="chatTemplate">chat template name</param>
public static void SetTemplate(string filename, string chatTemplate) => SetTemplate(Get(filename), chatTemplate);
216
/// <summary>
/// Sets the chat template for a model entry and distributes it to all registered LLMs using that model.
/// In the editor the manager is saved afterwards.
/// </summary>
/// <param name="entry">model entry (ignored when null)</param>
/// <param name="chatTemplate">chat template name</param>
public static void SetTemplate(ModelEntry entry, string chatTemplate)
{
    if (entry == null) return;
    entry.chatTemplate = chatTemplate;
    foreach (LLM registered in llms)
    {
        // keep the == operator (not "is null") so any operator overload on LLM is honoured
        if (registered == null || registered.model != entry.filename) continue;
        registered.SetTemplate(chatTemplate);
    }
#if UNITY_EDITOR
    Save();
#endif
}
234
/// <summary>
/// Gets the model entry for a model path. An entry whose filename equals <paramref name="path"/> exactly
/// takes precedence; otherwise the first entry matching the file name of <paramref name="path"/> or its
/// full path is returned.
/// </summary>
/// <param name="path">model file name, relative asset path or full path</param>
/// <returns>the matching entry, or null if none matches</returns>
public static ModelEntry Get(string path)
{
    // entry.filename can be a path relative to the asset folder (possibly containing directories),
    // which Path.GetFileName(path) below strips. Match it exactly first, so that such an entry is found
    // by its own filename rather than missed or confused with a same-named file elsewhere.
    foreach (ModelEntry entry in modelEntries)
    {
        if (entry.filename == path) return entry;
    }

    string filename = Path.GetFileName(path);
    string fullPath = LLMUnitySetup.GetFullPath(path);
    foreach (ModelEntry entry in modelEntries)
    {
        if (entry.filename == filename || entry.path == fullPath) return entry;
    }
    return null;
}
250
/// <summary>
/// Gets the path of a model: its original full path in the editor, or its location in the asset folder in a build.
/// </summary>
/// <param name="filename">model file name or path</param>
/// <returns>model path, or an empty string if no entry matches</returns>
public static string GetAssetPath(string filename)
{
    ModelEntry match = Get(filename);
#if UNITY_EDITOR
    return match == null ? "" : match.path;
#else
    return match == null ? "" : LLMUnitySetup.GetAssetPath(match.filename);
#endif
}
266
/// <summary>
/// Counts the registered entries of one kind.
/// </summary>
/// <param name="lora">true to count LORA entries, false to count models</param>
/// <returns>number of matching entries</returns>
public static int Num(bool lora)
{
    int count = 0;
    for (int i = 0; i < modelEntries.Count; i++)
    {
        if (modelEntries[i].lora == lora) count++;
    }
    return count;
}
281
/// <summary>Returns the number of (non-LORA) models.</summary>
/// <returns>number of models</returns>
public static int NumModels() => Num(lora: false);
290
/// <summary>Returns the number of LORA models.</summary>
/// <returns>number of LORA models</returns>
public static int NumLoras() => Num(lora: true);
299
/// <summary>Registers an LLM with the model manager so it receives template changes and model removals.</summary>
/// <param name="llm">LLM to register</param>
public static void Register(LLM llm) => llms.Add(llm);
308
/// <summary>Removes an LLM from the model manager.</summary>
/// <param name="llm">LLM to unregister</param>
public static void Unregister(LLM llm) => llms.Remove(llm);
317
/// <summary>
/// Loads the model manager state from the build information file, if that file exists.
/// </summary>
public static void LoadFromDisk()
{
    if (!File.Exists(LLMUnitySetup.LLMManagerPath)) return;
    string json = File.ReadAllText(LLMUnitySetup.LLMManagerPath);
    LLMManagerStore stored = JsonUtility.FromJson<LLMManagerStore>(json);
    downloadOnStart = stored.downloadOnStart;
    modelEntries = stored.modelEntries;
}
328
329#if UNITY_EDITOR
// PlayerPrefs key under which the editor stores the manager state
private static readonly string LLMManagerPref = "LLMManager";

/// <summary>Progress of the current model download in the editor.</summary>
[HideInInspector] public static float modelProgress = 1;
/// <summary>Progress of the current LORA download in the editor.</summary>
[HideInInspector] public static float loraProgress = 1;
334
/// <summary>Loads the stored manager state; invoked by the editor through [InitializeOnLoadMethod].</summary>
[InitializeOnLoadMethod]
private static void InitializeOnLoad() => Load();
340
/// <summary>
/// Adds a model entry to the model manager and saves it.
/// LORA entries are appended at the end. Other entries go at the front when the list starts with a LORA,
/// and otherwise right after the last non-LORA entry.
/// </summary>
/// <param name="entry">entry to add</param>
/// <returns>the filename of the added entry</returns>
public static string AddEntry(ModelEntry entry)
{
    int position = modelEntries.Count;
    if (!entry.lora)
    {
        bool startsWithLora = modelEntries.Count > 0 && modelEntries[0].lora;
        if (startsWithLora)
        {
            position = 0;
        }
        else
        {
            int lastModel = modelEntries.FindLastIndex(existing => !existing.lora);
            if (lastModel >= 0) position = lastModel + 1;
        }
    }
    modelEntries.Insert(position, entry);
    Save();
    return entry.filename;
}
368
/// <summary>
/// Creates a model entry from a path and adds it to the model manager.
/// </summary>
/// <param name="path">model path</param>
/// <param name="lora">whether the model is a LORA adapter</param>
/// <param name="label">display label</param>
/// <param name="url">URL the model can be downloaded from</param>
/// <returns>the filename of the added entry</returns>
public static string AddEntry(string path, bool lora = false, string label = null, string url = null) =>
    AddEntry(new ModelEntry(path, lora, label, url));
381
/// <summary>
/// Downloads a model from a URL and adds a model entry to the model manager.
/// If an entry with the same URL or the same model file name already exists, nothing is downloaded
/// and the filename of the existing entry is returned.
/// </summary>
/// <param name="url">download URL (must not be null or empty)</param>
/// <param name="lora">whether the model is a LORA adapter</param>
/// <param name="log">whether to log when an existing entry is found</param>
/// <param name="label">display label of the entry</param>
/// <returns>the filename of the entry, or null if the URL is empty or the download failed</returns>
public static async Task<string> Download(string url, bool lora = false, bool log = false, string label = null)
{
    // Without this guard, an empty/null URL would match (and return) any entry that has no URL
    if (string.IsNullOrEmpty(url))
    {
        LLMUnitySetup.LogError("No URL provided to download the model from");
        return null;
    }

    foreach (ModelEntry entry in modelEntries)
    {
        if (entry.url == url)
        {
            if (log) LLMUnitySetup.Log($"Found existing entry for {url}");
            return entry.filename;
        }
    }

    // Strip the query string / fragment before extracting the file name, so that a '/' inside
    // them cannot end up determining the model name
    string modelName = Path.GetFileName(url.Split('?', '#')[0]);
    ModelEntry entryPath = Get(modelName);
    if (entryPath != null)
    {
        if (log) LLMUnitySetup.Log($"Found existing entry for {modelName}");
        return entryPath.filename;
    }

    string modelPath = Path.Combine(LLMUnitySetup.modelDownloadPath, modelName);
    // remember the progress values so they can be restored if the download fails
    float preModelProgress = modelProgress;
    float preLoraProgress = loraProgress;
    try
    {
        if (!lora)
        {
            modelProgress = 0;
            await LLMUnitySetup.DownloadFile(url, modelPath, false, null, SetModelProgress);
        }
        else
        {
            loraProgress = 0;
            await LLMUnitySetup.DownloadFile(url, modelPath, false, null, SetLoraProgress);
        }
    }
    catch (Exception ex)
    {
        modelProgress = preModelProgress;
        loraProgress = preLoraProgress;
        LLMUnitySetup.LogError($"Error downloading the model from URL '{url}': " + ex.Message);
        return null;
    }
    return AddEntry(modelPath, lora, label, url);
}
434
/// <summary>
/// Adds a model entry for a model file on disk, unless an entry for that path already exists.
/// </summary>
/// <param name="path">model path</param>
/// <param name="lora">whether the model is a LORA adapter</param>
/// <param name="log">whether to log when an existing entry is found</param>
/// <param name="label">display label of the entry</param>
/// <returns>the filename of the new or existing entry</returns>
public static string Load(string path, bool lora = false, bool log = false, string label = null)
{
    ModelEntry existing = Get(path);
    if (existing == null) return AddEntry(path, lora, label);

    if (log) LLMUnitySetup.Log($"Found existing entry for {existing.filename}");
    return existing.filename;
}
453
/// <summary>Downloads an LLM model from a URL and adds a model entry to the model manager.</summary>
/// <param name="url">download URL</param>
/// <param name="log">whether to log when an existing entry is found</param>
/// <param name="label">display label of the entry</param>
/// <returns>the filename of the entry, or null on failure</returns>
public static async Task<string> DownloadModel(string url, bool log = false, string label = null) =>
    await Download(url, false, log, label);
465
/// <summary>Downloads a LORA model from a URL and adds a model entry to the model manager.</summary>
/// <param name="url">download URL</param>
/// <param name="log">whether to log when an existing entry is found</param>
/// <param name="label">display label of the entry</param>
/// <returns>the filename of the entry, or null on failure</returns>
public static async Task<string> DownloadLora(string url, bool log = false, string label = null) =>
    await Download(url, true, log, label);
477
/// <summary>Adds a model entry for an LLM model on disk (or returns the existing one).</summary>
/// <param name="path">model path</param>
/// <param name="log">whether to log when an existing entry is found</param>
/// <param name="label">display label of the entry</param>
/// <returns>the filename of the entry</returns>
public static string LoadModel(string path, bool log = false, string label = null) => Load(path, false, log, label);
489
/// <summary>Adds a model entry for a LORA model on disk (or returns the existing one).</summary>
/// <param name="path">LORA path</param>
/// <param name="log">whether to log when an existing entry is found</param>
/// <param name="label">display label of the entry</param>
/// <returns>the filename of the entry</returns>
public static string LoadLora(string path, bool log = false, string label = null) => Load(path, true, log, label);
501
/// <summary>Sets the download URL of the model matching <paramref name="filename"/>.</summary>
/// <param name="filename">model file name or path</param>
/// <param name="url">download URL</param>
public static void SetURL(string filename, string url) => SetURL(Get(filename), url);
511
/// <summary>Sets the download URL of a model entry and saves the manager.</summary>
/// <param name="entry">model entry (ignored when null)</param>
/// <param name="url">download URL</param>
public static void SetURL(ModelEntry entry, string url)
{
    if (entry != null)
    {
        entry.url = url;
        Save();
    }
}
523
/// <summary>Sets whether the model matching <paramref name="filename"/> is included in the build.</summary>
/// <param name="filename">model file name or path</param>
/// <param name="includeInBuild">whether to include the model in the build</param>
public static void SetIncludeInBuild(string filename, bool includeInBuild) => SetIncludeInBuild(Get(filename), includeInBuild);
533
/// <summary>Sets whether a model entry is included in the build and saves the manager.</summary>
/// <param name="entry">model entry (ignored when null)</param>
/// <param name="includeInBuild">whether to include the model in the build</param>
public static void SetIncludeInBuild(ModelEntry entry, bool includeInBuild)
{
    if (entry != null)
    {
        entry.includeInBuild = includeInBuild;
        Save();
    }
}
545
/// <summary>
/// Sets whether models are downloaded on start and saves the manager. When enabling it, a warning is
/// logged if some entries have no URL (those are copied into the build instead).
/// </summary>
/// <param name="value">whether to download the models on start</param>
public static void SetDownloadOnStart(bool value)
{
    downloadOnStart = value;
    if (downloadOnStart)
    {
        int entriesWithoutUrl = 0;
        foreach (ModelEntry candidate in modelEntries)
        {
            if (string.IsNullOrEmpty(candidate.url)) entriesWithoutUrl++;
        }
        if (entriesWithoutUrl > 0)
        {
            LLMUnitySetup.LogWarning("Some models do not have a URL and will be copied in the build. To resolve this fill in the URL field in the expanded view of the LLM Model list.");
        }
    }
    Save();
}
564
/// <summary>Removes the model matching <paramref name="filename"/> from the model manager.</summary>
/// <param name="filename">model file name or path</param>
public static void Remove(string filename) => Remove(Get(filename));
573
/// <summary>
/// Removes a model entry from the model manager, saves the manager, and updates the registered LLMs:
/// a removed LORA is removed from every LLM, and LLMs using a removed model get an empty model.
/// </summary>
/// <param name="entry">model entry (ignored when null)</param>
public static void Remove(ModelEntry entry)
{
    if (entry == null) return;
    modelEntries.Remove(entry);
    Save();
    foreach (LLM llm in llms)
    {
        // skip null (e.g. destroyed) LLMs instead of dereferencing them; uses == so any
        // operator overload on LLM is honoured
        if (llm == null) continue;
        if (entry.lora) llm.RemoveLora(entry.filename);
        else if (llm.model == entry.filename) llm.model = "";
    }
}
589
/// <summary>Sets the progress of the current model download.</summary>
/// <param name="progress">download progress</param>
public static void SetModelProgress(float progress) => modelProgress = progress;
598
/// <summary>Sets the progress of the current LORA download.</summary>
/// <param name="progress">download progress</param>
public static void SetLoraProgress(float progress) => loraProgress = progress;
607
/// <summary>Serialises the model manager to JSON and stores it in the editor PlayerPrefs.</summary>
public static void Save()
{
    LLMManagerStore snapshot = new LLMManagerStore { modelEntries = modelEntries, downloadOnStart = downloadOnStart };
    PlayerPrefs.SetString(LLMManagerPref, JsonUtility.ToJson(snapshot, true));
    PlayerPrefs.Save();
}
617
/// <summary>Loads the model manager from the editor PlayerPrefs, if a saved state exists.</summary>
public static void Load()
{
    string json = PlayerPrefs.GetString(LLMManagerPref);
    if (string.IsNullOrEmpty(json)) return;
    LLMManagerStore stored = JsonUtility.FromJson<LLMManagerStore>(json);
    downloadOnStart = stored.downloadOnStart;
    modelEntries = stored.modelEntries;
}
629
/// <summary>
/// Writes the build information file: the entries marked for inclusion in the build, each reduced to
/// its required fields, together with the download-on-start setting.
/// </summary>
public static void SaveToDisk()
{
    List<ModelEntry> buildEntries = modelEntries
        .FindAll(candidate => candidate.includeInBuild)
        .ConvertAll(candidate => candidate.OnlyRequiredFields());
    LLMManagerStore snapshot = new LLMManagerStore { modelEntries = buildEntries, downloadOnStart = downloadOnStart };
    File.WriteAllText(LLMUnitySetup.LLMManagerPath, JsonUtility.ToJson(snapshot, true));
}
644
/// <summary>
/// Saves the build information file and copies into the asset folder every included model that is not
/// already there and will not be downloaded on start.
/// </summary>
/// <param name="copyCallback">callback invoked with (source path, target path) for each model to copy</param>
public static void Build(ActionCallback copyCallback)
{
    SaveToDisk();

    foreach (ModelEntry candidate in modelEntries)
    {
        string target = LLMUnitySetup.GetAssetPath(candidate.filename);
        bool skip = !candidate.includeInBuild || File.Exists(target);
        if (skip) continue;

        bool downloadedOnStart = downloadOnStart && !string.IsNullOrEmpty(candidate.url);
        if (!downloadedOnStart) copyCallback(candidate.path, target);
    }
}
660
661#endif
662 }
663}
Class implementing the skeleton of a chat template.
static string FromGGUF(string path)
Determines the chat template name from a GGUF file. It reads the GGUF file and then determines the ch...
Class implementing the GGUF reader.
Definition LLMGGUF.cs:55
int GetIntField(string key)
Retrieves an integer GGUF field.
Definition LLMGGUF.cs:157
string GetStringField(string key)
Retrieves a string GGUF field.
Definition LLMGGUF.cs:145
Class implementing the LLM model manager.
static void SetTemplate(ModelEntry entry, string chatTemplate)
Sets the chat template for a model and distributes it to all LLMs using it.
static void SaveToDisk()
Saves the model manager to disk for the build.
static void SetIncludeInBuild(ModelEntry entry, bool includeInBuild)
Sets whether to include a model in the build.
static void SetLoraProgress(float progress)
Sets the LORA download progress.
static void Unregister(LLM llm)
Removes an LLM from the model manager.
static async Task< string > Download(string url, bool lora=false, bool log=false, string label=null)
Downloads a model and adds a model entry to the model manager.
static string GetAssetPath(string filename)
Gets the asset path based on whether the application runs locally in the editor or in a build.
static int NumModels()
Returns the number of LLM models.
static ModelEntry Get(string path)
Gets the model entry for a model path.
static Task< bool > Setup()
Setup of the models.
static int Num(bool lora)
Returns the number of LLM/LORA models.
static void Register(LLM llm)
Registers an LLM with the model manager.
static void SetURL(string filename, string url)
Sets the URL for a model.
static void SetModelProgress(float progress)
Sets the LLM download progress.
static void Build(ActionCallback copyCallback)
Saves the model manager to disk for the build, along with the models that are not (or cannot be) downloaded...
static void Save()
Serialises and saves the model manager.
static string AddEntry(string path, bool lora=false, string label=null, string url=null)
Creates and adds a model entry to the model manager.
static void Load()
Deserialises and loads the model manager.
static void SetDownloadOnStart(bool value)
Sets whether to download files on start.
static void SetIncludeInBuild(string filename, bool includeInBuild)
Sets whether to include a model in the build.
static int NumLoras()
Returns the number of LORA models.
static async Task< string > DownloadModel(string url, bool log=false, string label=null)
Downloads an LLM model from a URL and adds a model entry to the model manager.
static string LoadModel(string path, bool log=false, string label=null)
Loads a LLM model from disk and adds a model entry to the model manager.
static void SetTemplate(string filename, string chatTemplate)
Sets the chat template for a model and distributes it to all LLMs using it.
static void Remove(string filename)
Removes a model from the model manager.
static void LoadFromDisk()
Loads the model manager from a file.
static void SetDownloadProgress(float progress)
Sets the model download progress in all registered callbacks.
static string Load(string path, bool lora=false, bool log=false, string label=null)
Loads a model from disk and adds a model entry to the model manager.
static string AddEntry(ModelEntry entry)
Adds a model entry to the model manager.
static async Task< bool > SetupOnce()
Task performing the setup of the models.
static async Task< string > DownloadLora(string url, bool log=false, string label=null)
Downloads a LORA model from a URL and adds a model entry to the model manager.
static void Remove(ModelEntry entry)
Removes a model from the model manager.
static string LoadLora(string path, bool log=false, string label=null)
Loads a LORA model from disk and adds a model entry to the model manager.
static void SetURL(ModelEntry entry, string url)
Sets the URL for a model.
Class implementing helper functions for setup and process management.
static string LLMManagerPath
Path of file with build information for runtime.
static string modelDownloadPath
Model download path.
Class implementing the LLM server.
Definition LLM.cs:19
void RemoveLora(string path)
Removes a LORA model from the LLM. Supported models are in .gguf format.
Definition LLM.cs:272
string model
The LLM model to use. Models in .gguf format are allowed.
Definition LLM.cs:54
void SetTemplate(string templateName, bool setDirty=true)
Set the chat template for the LLM.
Definition LLM.cs:325
Class implementing a LLM model entry.
Definition LLMManager.cs:18
ModelEntry(string path, bool lora=false, string label=null, string url=null)
Constructs a LLM model entry.
Definition LLMManager.cs:55
static string GetFilenameOrRelativeAssetPath(string path)
Returns the relative asset path if it is in the AssetPath folder (StreamingAssets or persistentPath),...
Definition LLMManager.cs:37
ModelEntry OnlyRequiredFields()
Returns only the required fields for bundling the model in the build.
Definition LLMManager.cs:85
Class implementing a resumable Web client.