LLM for Unity  v2.3.0
Create characters in Unity with LLMs!
Loading...
Searching...
No Matches
LLMManager.cs
Go to the documentation of this file.
1
3using System;
4using System.Collections.Generic;
5using System.IO;
6using System.Threading.Tasks;
7using UnityEditor;
8using UnityEngine;
9
10namespace LLMUnity
11{
/// <summary>
/// Class implementing an LLM model entry: a base model or LORA adapter registered with the
/// <see cref="LLMManager"/>, together with metadata read from its GGUF header.
/// State is kept in public fields so that Unity's JsonUtility can serialise it;
/// the field declaration order determines the key order of the produced JSON.
/// </summary>
[Serializable]
public class ModelEntry
{
    /// <summary>Display label (defaults to the file name).</summary>
    public string label;
    /// <summary>Path relative to the asset folder when the file lives inside it, otherwise the bare file name.</summary>
    public string filename;
    /// <summary>Full path of the model file.</summary>
    public string path;
    /// <summary>Whether this entry is a LORA adapter rather than a base model.</summary>
    public bool lora;
    /// <summary>Chat template detected from the GGUF file; null for LORA and embedding-only models.</summary>
    public string chatTemplate;
    /// <summary>URL the model can be downloaded from, if any.</summary>
    public string url;
    /// <summary>Whether the model architecture is one of the embedding-only architectures.</summary>
    public bool embeddingOnly;
    /// <summary>Embedding length from the GGUF header; stays 0 when the architecture field is missing.</summary>
    public int embeddingLength;
    /// <summary>Whether the entry is included in the build.</summary>
    public bool includeInBuild;
    /// <summary>Context length from the GGUF header; stays -1 when the architecture field is missing.</summary>
    public int contextLength;

    // Architectures treated as embedding-only. Readonly so the shared list cannot be reassigned.
    static readonly List<string> embeddingOnlyArchs = new List<string> {"bert", "nomic-bert", "jina-bert-v2", "t5", "t5encoder"};

    /// <summary>
    /// Returns the relative asset path if the file is inside the asset folder
    /// (as given by LLMUnitySetup.GetAssetPath()), otherwise just the file name.
    /// </summary>
    /// <param name="path">full or asset-relative path of the file</param>
    /// <returns>relative asset path or file name</returns>
    public static string GetFilenameOrRelativeAssetPath(string path)
    {
        string assetPath = LLMUnitySetup.GetAssetPath(path); // Note: this will return the full path if a full path is passed
        string basePath = LLMUnitySetup.GetAssetPath();
        if (File.Exists(assetPath) && LLMUnitySetup.IsSubPath(assetPath, basePath))
        {
            return LLMUnitySetup.RelativePath(assetPath, basePath);
        }
        return Path.GetFileName(path);
    }

    /// <summary>
    /// Constructs an LLM model entry. For base models (not LORA) the GGUF header is read to
    /// determine the context length, embedding length, embedding-only flag and chat template.
    /// </summary>
    /// <param name="path">path of the model file</param>
    /// <param name="lora">whether the file is a LORA adapter</param>
    /// <param name="label">display label; the file name is used when null</param>
    /// <param name="url">URL the model can be downloaded from</param>
    public ModelEntry(string path, bool lora = false, string label = null, string url = null)
    {
        filename = GetFilenameOrRelativeAssetPath(path);
        this.label = label ?? Path.GetFileName(filename);
        this.lora = lora;
        this.path = LLMUnitySetup.GetFullPath(path);
        this.url = url;
        includeInBuild = true;
        chatTemplate = null;
        contextLength = -1;
        embeddingOnly = false;
        embeddingLength = 0;
        if (!lora)
        {
            GGUFReader reader = new GGUFReader(this.path);
            string arch = reader.GetStringField("general.architecture");
            if (arch != null)
            {
                // GGUF metadata keys are namespaced by architecture, e.g. "llama.context_length".
                contextLength = reader.GetIntField($"{arch}.context_length");
                embeddingLength = reader.GetIntField($"{arch}.embedding_length");
            }
            embeddingOnly = embeddingOnlyArchs.Contains(arch);
            chatTemplate = embeddingOnly ? default : ChatTemplate.FromGGUF(reader, this.path);
        }
    }

    /// <summary>
    /// Returns only the required fields for bundling the model in the build:
    /// a shallow copy whose label is cleared and whose path is reduced to the filename.
    /// </summary>
    /// <returns>trimmed copy of this entry</returns>
    public ModelEntry OnlyRequiredFields()
    {
        ModelEntry entry = (ModelEntry)MemberwiseClone();
        entry.label = null;
        // The build must not leak editor-machine paths; the runtime resolves files from the filename.
        entry.path = entry.filename;
        return entry;
    }
}
93
/// <summary>
/// Serialisable snapshot of the LLMManager state, written to and read back from JSON
/// through Unity's JsonUtility. Fields stay in this order so the JSON key order is unchanged.
/// </summary>
[Serializable] public class LLMManagerStore
{
    /// <summary>Copy of LLMManager.downloadOnStart.</summary>
    public bool downloadOnStart;

    /// <summary>Copy of the model entries (only the build-included, trimmed entries when saved for a build).</summary>
    public List<ModelEntry> modelEntries;
}
102
/// <summary>
/// Class implementing the LLM model manager.
/// It keeps the list of registered model / LORA entries and persists it. In the editor the list
/// goes to PlayerPrefs; for builds it goes to the JSON file at LLMUnitySetup.LLMManagerPath.
/// At runtime it makes sure the model files are present by extracting or downloading them.
/// </summary>
[DefaultExecutionOrder(-2)]
public class LLMManager
{
    /// <summary>Whether models that have a URL are downloaded on start instead of being copied into the build.</summary>
    public static bool downloadOnStart = false;
    /// <summary>The registered model and LORA entries.</summary>
    public static List<ModelEntry> modelEntries = new List<ModelEntry>();
    // LLM instances notified when an entry changes (template update, removal).
    static readonly List<LLM> llms = new List<LLM>();

    /// <summary>Overall progress of the start-up downloads; starts at 1 and is reset to 0 when downloads begin.</summary>
    public static float downloadProgress = 1;
    /// <summary>Callbacks invoked with <see cref="downloadProgress"/> every time it is updated.</summary>
    public static List<Callback<float>> downloadProgressCallbacks = new List<Callback<float>>();
    static Task<bool> SetupTask;
    static readonly object lockObject = new object();
    // Byte counts used to convert per-file progress into overall progress.
    static long totalSize;
    static long currFileSize;
    static long completedSize;

    /// <summary>
    /// Sets the model download progress in all registered callbacks.
    /// </summary>
    /// <param name="progress">progress of the file currently being downloaded (presumably in [0, 1])</param>
    public static void SetDownloadProgress(float progress)
    {
        // Weight the current file's progress by its size. Without a positive total size
        // (e.g. the size lookups returned 0) fall back to the per-file progress instead of
        // dividing by zero, which would report NaN or Infinity to the callbacks.
        downloadProgress = totalSize > 0
            ? (completedSize + progress * currFileSize) / totalSize
            : progress;
        foreach (Callback<float> downloadProgressCallback in downloadProgressCallbacks) downloadProgressCallback?.Invoke(downloadProgress);
    }

    /// <summary>
    /// Setup of the models. The setup runs only once; all callers share the same task.
    /// </summary>
    /// <returns>task resolving to false if downloading the models failed, true otherwise</returns>
    public static Task<bool> Setup()
    {
        lock (lockObject)
        {
            if (SetupTask == null) SetupTask = SetupOnce();
        }
        return SetupTask;
    }

    /// <summary>
    /// Task performing the setup of the models. It loads the build manifest from disk, then
    /// handles every entry whose file is missing. When downloadOnStart is off or the entry has
    /// no URL, it calls LLMUnitySetup.AndroidExtractFile. Otherwise it downloads the file.
    /// </summary>
    /// <returns>false if downloading failed, true otherwise</returns>
    public static async Task<bool> SetupOnce()
    {
        await LLMUnitySetup.AndroidExtractAsset(LLMUnitySetup.LLMManagerPath, true);
        LoadFromDisk();

        List<StringPair> downloads = new List<StringPair>();
        foreach (ModelEntry modelEntry in modelEntries)
        {
            string target = LLMUnitySetup.GetAssetPath(modelEntry.filename);
            if (File.Exists(target)) continue;

            if (!downloadOnStart || string.IsNullOrEmpty(modelEntry.url))
            {
                await LLMUnitySetup.AndroidExtractFile(modelEntry.filename);
                if (!File.Exists(target)) LLMUnitySetup.LogError($"Model {modelEntry.filename} could not be found!");
            }
            else
            {
                downloads.Add(new StringPair {source = modelEntry.url, target = target});
            }
        }
        if (downloads.Count == 0) return true;

        try
        {
            downloadProgress = 0;
            totalSize = 0;
            completedSize = 0;

            // Query all sizes up front so progress can be reported over the total byte count.
            ResumingWebClient client = new ResumingWebClient();
            Dictionary<string, long> fileSizes = new Dictionary<string, long>();
            foreach (StringPair pair in downloads)
            {
                long size = client.GetURLFileSize(pair.source);
                fileSizes[pair.source] = size;
                totalSize += size;
            }

            foreach (StringPair pair in downloads)
            {
                currFileSize = fileSizes[pair.source];
                await LLMUnitySetup.DownloadFile(pair.source, pair.target, false, null, SetDownloadProgress);
                await LLMUnitySetup.AndroidExtractFile(Path.GetFileName(pair.target));
                completedSize += currFileSize;
            }

            // Report completion explicitly. With no file in flight this evaluates to exactly 1,
            // also when the file sizes could not be determined.
            completedSize = totalSize;
            currFileSize = 0;
            SetDownloadProgress(1);
        }
        catch (Exception ex)
        {
            LLMUnitySetup.LogError($"Error downloading the models: {ex.Message}");
            return false;
        }
        return true;
    }

    /// <summary>
    /// Sets the chat template for a model and distributes it to all LLMs using it.
    /// </summary>
    /// <param name="filename">model filename or path</param>
    /// <param name="chatTemplate">chat template name</param>
    public static void SetTemplate(string filename, string chatTemplate)
    {
        SetTemplate(Get(filename), chatTemplate);
    }

    /// <summary>
    /// Sets the chat template for a model and distributes it to all LLMs using it.
    /// Does nothing when the entry is null; saves the manager in the editor.
    /// </summary>
    /// <param name="entry">model entry</param>
    /// <param name="chatTemplate">chat template name</param>
    public static void SetTemplate(ModelEntry entry, string chatTemplate)
    {
        if (entry == null) return;
        entry.chatTemplate = chatTemplate;
        foreach (LLM llm in llms)
        {
            if (llm != null && llm.model == entry.filename) llm.SetTemplate(chatTemplate);
        }
#if UNITY_EDITOR
        Save();
#endif
    }

    /// <summary>
    /// Gets the model entry for a model path, matching either the file name or the full path.
    /// </summary>
    /// <param name="path">model path or filename</param>
    /// <returns>the matching entry, or null if none matches</returns>
    public static ModelEntry Get(string path)
    {
        string filename = Path.GetFileName(path);
        string fullPath = LLMUnitySetup.GetFullPath(path);
        foreach (ModelEntry entry in modelEntries)
        {
            if (entry.filename == filename || entry.path == fullPath) return entry;
        }
        return null;
    }

    /// <summary>
    /// Gets the asset path based on whether the application runs locally in the editor or in a build.
    /// </summary>
    /// <param name="filename">model filename or path</param>
    /// <returns>the entry's full path in the editor, its asset path in a build, or "" if not registered</returns>
    public static string GetAssetPath(string filename)
    {
        ModelEntry entry = Get(filename);
        if (entry == null) return "";
#if UNITY_EDITOR
        return entry.path;
#else
        return LLMUnitySetup.GetAssetPath(entry.filename);
#endif
    }

    /// <summary>
    /// Returns the number of LLM/LORA models.
    /// </summary>
    /// <param name="lora">true to count LORA entries, false to count base models</param>
    /// <returns>number of matching entries</returns>
    public static int Num(bool lora)
    {
        int num = 0;
        foreach (ModelEntry entry in modelEntries)
        {
            if (entry.lora == lora) num++;
        }
        return num;
    }

    /// <summary>
    /// Returns the number of LLM models.
    /// </summary>
    public static int NumModels()
    {
        return Num(false);
    }

    /// <summary>
    /// Returns the number of LORA models.
    /// </summary>
    public static int NumLoras()
    {
        return Num(true);
    }

    /// <summary>
    /// Registers an LLM with the model manager so it receives entry updates.
    /// </summary>
    /// <param name="llm">LLM to register</param>
    public static void Register(LLM llm)
    {
        llms.Add(llm);
    }

    /// <summary>
    /// Removes an LLM from the model manager.
    /// </summary>
    /// <param name="llm">LLM to remove</param>
    public static void Unregister(LLM llm)
    {
        llms.Remove(llm);
    }

    /// <summary>
    /// Loads the model manager from the build file at LLMUnitySetup.LLMManagerPath, if it exists.
    /// </summary>
    public static void LoadFromDisk()
    {
        if (!File.Exists(LLMUnitySetup.LLMManagerPath)) return;
        LLMManagerStore store = JsonUtility.FromJson<LLMManagerStore>(File.ReadAllText(LLMUnitySetup.LLMManagerPath));
        // Keep the current state if the file did not deserialise into a store.
        if (store == null) return;
        downloadOnStart = store.downloadOnStart;
        modelEntries = store.modelEntries ?? new List<ModelEntry>();
    }

#if UNITY_EDITOR
    static string LLMManagerPref = "LLMManager";

    /// <summary>Progress of the current model download started from the editor.</summary>
    [HideInInspector] public static float modelProgress = 1;
    /// <summary>Progress of the current LORA download started from the editor.</summary>
    [HideInInspector] public static float loraProgress = 1;

    // Invoked by the editor on load ([InitializeOnLoadMethod]) to restore the saved state.
    [InitializeOnLoadMethod]
    static void InitializeOnLoad()
    {
        Load();
    }

    /// <summary>
    /// Adds a model entry to the model manager and saves it.
    /// LORA entries are appended. Base models are placed before the LORA entries: at the front
    /// if the list starts with a LORA entry, otherwise right after the last base model.
    /// </summary>
    /// <param name="entry">entry to add</param>
    /// <returns>the entry's filename</returns>
    public static string AddEntry(ModelEntry entry)
    {
        int indexToInsert = modelEntries.Count;
        if (!entry.lora)
        {
            if (modelEntries.Count > 0 && modelEntries[0].lora) indexToInsert = 0;
            else
            {
                for (int i = modelEntries.Count - 1; i >= 0; i--)
                {
                    if (!modelEntries[i].lora)
                    {
                        indexToInsert = i + 1;
                        break;
                    }
                }
            }
        }
        modelEntries.Insert(indexToInsert, entry);
        Save();
        return entry.filename;
    }

    /// <summary>
    /// Creates and adds a model entry to the model manager.
    /// </summary>
    /// <param name="path">model path</param>
    /// <param name="lora">whether the model is a LORA adapter</param>
    /// <param name="label">display label</param>
    /// <param name="url">download URL</param>
    /// <returns>the entry's filename</returns>
    public static string AddEntry(string path, bool lora = false, string label = null, string url = null)
    {
        return AddEntry(new ModelEntry(path, lora, label, url));
    }

    /// <summary>
    /// Downloads a model and adds a model entry to the model manager.
    /// If the URL or the file name is already registered, the existing entry is reused.
    /// </summary>
    /// <param name="url">download URL</param>
    /// <param name="lora">whether the model is a LORA adapter</param>
    /// <param name="log">whether to log when an existing entry is reused</param>
    /// <param name="label">display label</param>
    /// <returns>the entry's filename, or null if the download failed</returns>
    public static async Task<string> Download(string url, bool lora = false, bool log = false, string label = null)
    {
        foreach (ModelEntry entry in modelEntries)
        {
            if (entry.url == url)
            {
                if (log) LLMUnitySetup.Log($"Found existing entry for {url}");
                return entry.filename;
            }
        }

        // Strip the query string before taking the file name, so a '/' inside the
        // query (e.g. "?file=a/b") cannot be mistaken for a path separator.
        string modelName = Path.GetFileName(url.Split('?')[0]);
        ModelEntry entryPath = Get(modelName);
        if (entryPath != null)
        {
            if (log) LLMUnitySetup.Log($"Found existing entry for {modelName}");
            return entryPath.filename;
        }

        string modelPath = Path.Combine(LLMUnitySetup.modelDownloadPath, modelName);
        float preModelProgress = modelProgress;
        float preLoraProgress = loraProgress;
        try
        {
            if (!lora)
            {
                modelProgress = 0;
                await LLMUnitySetup.DownloadFile(url, modelPath, false, null, SetModelProgress);
            }
            else
            {
                loraProgress = 0;
                await LLMUnitySetup.DownloadFile(url, modelPath, false, null, SetLoraProgress);
            }
        }
        catch (Exception ex)
        {
            // Restore the progress values so the editor UI does not stay stuck at 0.
            modelProgress = preModelProgress;
            loraProgress = preLoraProgress;
            LLMUnitySetup.LogError($"Error downloading the model from URL '{url}': " + ex.Message);
            return null;
        }
        return AddEntry(modelPath, lora, label, url);
    }

    /// <summary>
    /// Loads a model from disk and adds a model entry to the model manager,
    /// reusing an existing entry for the same path if there is one.
    /// </summary>
    /// <param name="path">model path</param>
    /// <param name="lora">whether the model is a LORA adapter</param>
    /// <param name="log">whether to log when an existing entry is reused</param>
    /// <param name="label">display label</param>
    /// <returns>the entry's filename</returns>
    public static string Load(string path, bool lora = false, bool log = false, string label = null)
    {
        ModelEntry entry = Get(path);
        if (entry != null)
        {
            if (log) LLMUnitySetup.Log($"Found existing entry for {entry.filename}");
            return entry.filename;
        }
        return AddEntry(path, lora, label);
    }

    /// <summary>
    /// Downloads an LLM model from a URL and adds a model entry to the model manager.
    /// </summary>
    public static async Task<string> DownloadModel(string url, bool log = false, string label = null)
    {
        return await Download(url, false, log, label);
    }

    /// <summary>
    /// Downloads a LORA model from a URL and adds a model entry to the model manager.
    /// </summary>
    public static async Task<string> DownloadLora(string url, bool log = false, string label = null)
    {
        return await Download(url, true, log, label);
    }

    /// <summary>
    /// Loads an LLM model from disk and adds a model entry to the model manager.
    /// </summary>
    public static string LoadModel(string path, bool log = false, string label = null)
    {
        return Load(path, false, log, label);
    }

    /// <summary>
    /// Loads a LORA model from disk and adds a model entry to the model manager.
    /// </summary>
    public static string LoadLora(string path, bool log = false, string label = null)
    {
        return Load(path, true, log, label);
    }

    /// <summary>
    /// Sets the URL for a model.
    /// </summary>
    public static void SetURL(string filename, string url)
    {
        SetURL(Get(filename), url);
    }

    /// <summary>
    /// Sets the URL for a model and saves the manager. Does nothing when the entry is null.
    /// </summary>
    public static void SetURL(ModelEntry entry, string url)
    {
        if (entry == null) return;
        entry.url = url;
        Save();
    }

    /// <summary>
    /// Sets whether to include a model in the build.
    /// </summary>
    public static void SetIncludeInBuild(string filename, bool includeInBuild)
    {
        SetIncludeInBuild(Get(filename), includeInBuild);
    }

    /// <summary>
    /// Sets whether to include a model in the build and saves the manager. Does nothing when the entry is null.
    /// </summary>
    public static void SetIncludeInBuild(ModelEntry entry, bool includeInBuild)
    {
        if (entry == null) return;
        entry.includeInBuild = includeInBuild;
        Save();
    }

    /// <summary>
    /// Sets whether to download files on start and saves the manager.
    /// When enabling, warns if some build-included models have no URL (they will be copied instead).
    /// </summary>
    /// <param name="value">whether to download on start</param>
    public static void SetDownloadOnStart(bool value)
    {
        downloadOnStart = value;
        if (downloadOnStart)
        {
            bool warn = false;
            foreach (ModelEntry entry in modelEntries)
            {
                // Entries excluded from the build are never copied (see Build), so they do not need a URL.
                if (entry.includeInBuild && string.IsNullOrEmpty(entry.url))
                {
                    warn = true;
                    break;
                }
            }
            if (warn) LLMUnitySetup.LogWarning("Some models do not have a URL and will be copied in the build. To resolve this fill in the URL field in the expanded view of the LLM Model list.");
        }
        Save();
    }

    /// <summary>
    /// Removes a model from the model manager.
    /// </summary>
    public static void Remove(string filename)
    {
        Remove(Get(filename));
    }

    /// <summary>
    /// Removes a model from the model manager, saves, and detaches it from the registered LLMs:
    /// LORA entries are removed from every LLM, base models are unset where they are in use.
    /// </summary>
    /// <param name="entry">entry to remove; ignored when null</param>
    public static void Remove(ModelEntry entry)
    {
        if (entry == null) return;
        modelEntries.Remove(entry);
        Save();
        foreach (LLM llm in llms)
        {
            // Skip destroyed/null registrations, as SetTemplate does.
            if (llm == null) continue;
            if (entry.lora) llm.RemoveLora(entry.filename);
            else if (llm.model == entry.filename) llm.model = "";
        }
    }

    /// <summary>
    /// Sets the LLM download progress.
    /// </summary>
    public static void SetModelProgress(float progress)
    {
        modelProgress = progress;
    }

    /// <summary>
    /// Sets the LORA download progress.
    /// </summary>
    public static void SetLoraProgress(float progress)
    {
        loraProgress = progress;
    }

    /// <summary>
    /// Serialises the model manager to JSON and saves it in the editor PlayerPrefs.
    /// </summary>
    public static void Save()
    {
        string json = JsonUtility.ToJson(new LLMManagerStore { modelEntries = modelEntries, downloadOnStart = downloadOnStart }, true);
        PlayerPrefs.SetString(LLMManagerPref, json);
        PlayerPrefs.Save();
    }

    /// <summary>
    /// Deserialises and loads the model manager from the editor PlayerPrefs, if present.
    /// </summary>
    public static void Load()
    {
        string pref = PlayerPrefs.GetString(LLMManagerPref);
        if (string.IsNullOrEmpty(pref)) return;
        LLMManagerStore store = JsonUtility.FromJson<LLMManagerStore>(pref);
        // Keep the current state if the preference did not deserialise into a store.
        if (store == null) return;
        downloadOnStart = store.downloadOnStart;
        modelEntries = store.modelEntries ?? new List<ModelEntry>();
    }

    /// <summary>
    /// Saves the model manager to disk for the build: only build-included entries, trimmed
    /// to their required fields, are written to LLMUnitySetup.LLMManagerPath.
    /// </summary>
    public static void SaveToDisk()
    {
        List<ModelEntry> modelEntriesBuild = new List<ModelEntry>();
        foreach (ModelEntry modelEntry in modelEntries)
        {
            if (!modelEntry.includeInBuild) continue;
            modelEntriesBuild.Add(modelEntry.OnlyRequiredFields());
        }
        string json = JsonUtility.ToJson(new LLMManagerStore { modelEntries = modelEntriesBuild, downloadOnStart = downloadOnStart }, true);
        File.WriteAllText(LLMUnitySetup.LLMManagerPath, json);
    }

    /// <summary>
    /// Saves the model manager to disk, then copies into the build every included model that is
    /// not already at its target and will not be downloaded on start.
    /// </summary>
    /// <param name="copyCallback">callback invoked with (source path, target path) for each file to copy</param>
    public static void Build(ActionCallback copyCallback)
    {
        SaveToDisk();

        foreach (ModelEntry modelEntry in modelEntries)
        {
            string target = LLMUnitySetup.GetAssetPath(modelEntry.filename);
            if (!modelEntry.includeInBuild || File.Exists(target)) continue;
            if (!downloadOnStart || string.IsNullOrEmpty(modelEntry.url)) copyCallback(modelEntry.path, target);
        }
    }

#endif
}
662}
Class implementing the skeleton of a chat template.
static string FromGGUF(string path)
Determines the chat template name from a GGUF file. It reads the GGUF file and then determines the ch...
Class implementing the GGUF reader.
Definition LLMGGUF.cs:55
int GetIntField(string key)
Retrieves an integer GGUF field.
Definition LLMGGUF.cs:157
string GetStringField(string key)
Retrieves a string GGUF field.
Definition LLMGGUF.cs:145
Class implementing the LLM model manager.
static void SetTemplate(ModelEntry entry, string chatTemplate)
Sets the chat template for a model and distributes it to all LLMs using it.
static void SaveToDisk()
Saves the model manager to disk for the build.
static void SetIncludeInBuild(ModelEntry entry, bool includeInBuild)
Sets whether to include a model to the build.
static void SetLoraProgress(float progress)
Sets the LORA download progress.
static void Unregister(LLM llm)
Removes an LLM from the model manager.
static async Task< string > Download(string url, bool lora=false, bool log=false, string label=null)
Downloads a model and adds a model entry to the model manager.
static string GetAssetPath(string filename)
Gets the asset path based on whether the application runs locally in the editor or in a build.
static int NumModels()
Returns the number of LLM models.
static ModelEntry Get(string path)
Gets the model entry for a model path.
static Task< bool > Setup()
Sets up the models.
static int Num(bool lora)
Returns the number of LLM/LORA models.
static void Register(LLM llm)
Registers an LLM with the model manager.
static void SetURL(string filename, string url)
Sets the URL for a model.
static void SetModelProgress(float progress)
Sets the LLM download progress.
static void Build(ActionCallback copyCallback)
Saves the model manager to disk for the build, together with the models that are not (or cannot be) downloaded on start...
static void Save()
Serialises and saves the model manager.
static string AddEntry(string path, bool lora=false, string label=null, string url=null)
Creates and adds a model entry to the model manager.
static void Load()
Deserialises and loads the model manager.
static void SetDownloadOnStart(bool value)
Sets whether to download files on start.
static void SetIncludeInBuild(string filename, bool includeInBuild)
Sets whether to include a model to the build.
static int NumLoras()
Returns the number of LORA models.
static async Task< string > DownloadModel(string url, bool log=false, string label=null)
Downloads an LLM model from a URL and adds a model entry to the model manager.
static string LoadModel(string path, bool log=false, string label=null)
Loads an LLM model from disk and adds a model entry to the model manager.
static void SetTemplate(string filename, string chatTemplate)
Sets the chat template for a model and distributes it to all LLMs using it.
static void Remove(string filename)
Removes a model from the model manager.
static void LoadFromDisk()
Loads the model manager from a file.
static void SetDownloadProgress(float progress)
Sets the model download progress in all registered callbacks.
static string Load(string path, bool lora=false, bool log=false, string label=null)
Loads a model from disk and adds a model entry to the model manager.
static string AddEntry(ModelEntry entry)
Adds a model entry to the model manager.
static async Task< bool > SetupOnce()
Task performing the setup of the models.
static async Task< string > DownloadLora(string url, bool log=false, string label=null)
Downloads a LORA model from a URL and adds a model entry to the model manager.
static void Remove(ModelEntry entry)
Removes a model from the model manager.
static string LoadLora(string path, bool log=false, string label=null)
Loads a LORA model from disk and adds a model entry to the model manager.
static void SetURL(ModelEntry entry, string url)
Sets the URL for a model.
Class implementing helper functions for setup and process management.
static string LLMManagerPath
Path of file with build information for runtime.
static string modelDownloadPath
Model download path.
Class implementing the LLM server.
Definition LLM.cs:19
void RemoveLora(string path)
Removes a LORA model from the LLM. Supported models are in .gguf format.
Definition LLM.cs:272
string model
The LLM model to use. Models in .gguf format are allowed.
Definition LLM.cs:56
void SetTemplate(string templateName, bool setDirty=true)
Set the chat template for the LLM.
Definition LLM.cs:325
Class implementing an LLM model entry.
Definition LLMManager.cs:18
ModelEntry(string path, bool lora=false, string label=null, string url=null)
Constructs an LLM model entry.
Definition LLMManager.cs:55
static string GetFilenameOrRelativeAssetPath(string path)
Returns the relative asset path if it is in the AssetPath folder (StreamingAssets or persistentPath),...
Definition LLMManager.cs:37
ModelEntry OnlyRequiredFields()
Returns only the required fields for bundling the model in the build.
Definition LLMManager.cs:85
Class implementing a resumable Web client.