LLM for Unity  v3.0.1
Create characters in Unity with LLMs!
Loading...
Searching...
No Matches
LLMManager.cs
Go to the documentation of this file.
1
3using System;
4using System.Collections.Generic;
5using System.IO;
6using System.Threading.Tasks;
7using UnityEditor;
8using UnityEngine;
9
10namespace LLMUnity
11{
/// <summary>
/// A single model or LoRA adapter registered with the <see cref="LLMManager"/>.
/// The public fields are serialised with Unity's JsonUtility, so their names act as JSON keys.
/// </summary>
[Serializable]
public class ModelEntry
{
    /// <summary>Display label; defaults to the file name of <see cref="filename"/>.</summary>
    public string label;
    /// <summary>Path relative to the asset folder when the file lives inside it, otherwise the bare file name.</summary>
    public string filename;
    /// <summary>Full path of the file (replaced by <see cref="filename"/> in the copy made for the build).</summary>
    public string path;
    /// <summary>True for a LoRA adapter, false for a base model.</summary>
    public bool lora;
    /// <summary>Download URL of the file, if known.</summary>
    public string url;
    /// <summary>True when the GGUF architecture is one of the embedding-only architectures listed below.</summary>
    public bool embeddingOnly;
    /// <summary>Embedding length read from the GGUF metadata; stays 0 for LoRA adapters or when no architecture is found.</summary>
    public int embeddingLength;
    /// <summary>Whether the entry is written to the build manifest (and copied into the build when needed).</summary>
    public bool includeInBuild;
    /// <summary>Context length read from the GGUF metadata; stays -1 for LoRA adapters or when no architecture is found.</summary>
    public int contextLength;

    // Values of the GGUF "general.architecture" field that are flagged as embedding-only.
    static readonly List<string> embeddingOnlyArchs = new List<string> { "bert", "nomic-bert", "jina-bert-v2", "t5", "t5encoder", "gemma-embedding" };

    /// <summary>
    /// Returns the path relative to the asset folder when <paramref name="path"/> points to an existing
    /// file inside that folder; otherwise returns only the file name.
    /// </summary>
    /// <param name="path">Relative or full path of the file.</param>
    /// <returns>Relative asset path or bare file name.</returns>
    public static string GetFilenameOrRelativeAssetPath(string path)
    {
        // GetAssetPath hands back the input unchanged when it is already a full path.
        string candidate = LLMUnitySetup.GetAssetPath(path);
        string assetRoot = LLMUnitySetup.GetAssetPath();
        bool insideAssetRoot = File.Exists(candidate) && LLMUnitySetup.IsSubPath(candidate, assetRoot);
        return insideAssetRoot ? LLMUnitySetup.RelativePath(candidate, assetRoot) : Path.GetFileName(path);
    }

    /// <summary>
    /// Constructs an LLM model entry. For base models (not LoRA) the GGUF metadata is read to fill in
    /// the context length, embedding length and embedding-only flag.
    /// </summary>
    /// <param name="path">Path of the model file.</param>
    /// <param name="lora">Whether the file is a LoRA adapter.</param>
    /// <param name="label">Display label; the file name is used when null.</param>
    /// <param name="url">Download URL of the file, if any.</param>
    public ModelEntry(string path, bool lora = false, string label = null, string url = null)
    {
        filename = GetFilenameOrRelativeAssetPath(path);
        this.label = label ?? Path.GetFileName(filename);
        this.lora = lora;
        this.path = LLMUnitySetup.GetFullPath(path);
        this.url = url;
        includeInBuild = true;
        contextLength = -1;
        embeddingOnly = false;
        embeddingLength = 0;
        if (lora) return;

        GGUFReader gguf = new GGUFReader(this.path);
        string architecture = gguf.GetStringField("general.architecture");
        if (architecture != null)
        {
            // Metadata keys are namespaced by the architecture name.
            contextLength = gguf.GetIntField(architecture + ".context_length");
            embeddingLength = gguf.GetIntField(architecture + ".embedding_length");
        }
        embeddingOnly = embeddingOnlyArchs.Contains(architecture);
    }

    /// <summary>
    /// Returns only the required fields for bundling the model in the build: a shallow copy with the
    /// label cleared and the path replaced by <see cref="filename"/>.
    /// </summary>
    /// <returns>The stripped copy; this instance is left untouched.</returns>
    public ModelEntry OnlyRequiredFields()
    {
        ModelEntry copy = (ModelEntry)MemberwiseClone();
        copy.label = null;
        copy.path = copy.filename;
        return copy;
    }
}
/// <summary>
/// Serialisable snapshot of the <see cref="LLMManager"/> state, written and read with JsonUtility.
/// The field names are the JSON keys, so they must not be renamed.
/// </summary>
[Serializable]
public class LLMManagerStore
{
    /// <summary>Whether entries with a URL are downloaded at startup instead of being copied into the build.</summary>
    public bool downloadOnStart;

    /// <summary>Registered model and LoRA entries.</summary>
    public List<ModelEntry> modelEntries;

    /// <summary>Debug mode, stored as the integer value of <c>LLMUnitySetup.DebugModeType</c>.</summary>
    public int debugMode;
}
/// <summary>
/// Class implementing the LLM model manager: a static registry of LLM models and LoRA adapters.
/// In the Editor the registry is persisted in PlayerPrefs (<c>Save</c>/<c>Load</c>); for builds it is
/// written to <c>LLMUnitySetup.LLMManagerPath</c> (<c>SaveToDisk</c>) and read back at runtime by <see cref="Setup"/>.
/// </summary>
[DefaultExecutionOrder(-2)]
public class LLMManager
{
    /// <summary>Whether entries with a URL are downloaded at startup instead of being copied into the build.</summary>
    public static bool downloadOnStart = false;
    /// <summary>Registered model and LoRA entries.</summary>
    public static List<ModelEntry> modelEntries = new List<ModelEntry>();
    // LLM components that are updated when an entry is removed.
    static List<LLM> llms = new List<LLM>();

    /// <summary>Overall progress (0 to 1) of the downloads performed during <see cref="Setup"/>.</summary>
    public static float downloadProgress = 1;
    /// <summary>Callbacks notified of the overall download progress during <see cref="Setup"/>.</summary>
    public static List<Action<float>> downloadProgressCallbacks = new List<Action<float>>();
    static Task<bool> SetupTask;
    static readonly object lockObject = new object();
    // Byte counts used to turn per-file progress into overall progress.
    static long totalSize;
    static long currFileSize;
    static long completedSize;

    /// <summary>
    /// Sets the model download progress in all registered callbacks, converting the progress of the
    /// file currently being downloaded into progress over the whole batch.
    /// </summary>
    /// <param name="progress">Progress of the current file, between 0 and 1.</param>
    public static void SetDownloadProgress(float progress)
    {
        // Without a positive total size the weighted formula would yield NaN/Infinity, so fall back to
        // the per-file progress. NOTE(review): confirm what GetURLFileSize returns when the size is unknown.
        float overall = totalSize > 0 ? (completedSize + progress * currFileSize) / (float)totalSize : progress;
        downloadProgress = Mathf.Clamp01(overall);
        NotifyDownloadProgress();
    }

    // Invokes every registered progress callback with the current downloadProgress.
    static void NotifyDownloadProgress()
    {
        foreach (Action<float> downloadProgressCallback in downloadProgressCallbacks) downloadProgressCallback?.Invoke(downloadProgress);
    }

    /// <summary>
    /// Setup of the models. The first call starts <see cref="SetupOnce"/>; later calls return the same task.
    /// </summary>
    /// <returns>Task resolving to false if downloading the models failed, true otherwise.</returns>
    public static Task<bool> Setup()
    {
        lock (lockObject)
        {
            if (SetupTask == null) SetupTask = SetupOnce();
        }
        return SetupTask;
    }

    /// <summary>
    /// Task performing the setup of the models: loads the build manifest, then makes every entry available
    /// either by extracting it from the build or, when <see cref="downloadOnStart"/> is set and a URL is known,
    /// by downloading it.
    /// </summary>
    /// <returns>False if an exception occurred while downloading, true otherwise.</returns>
    public static async Task<bool> SetupOnce()
    {
        await LLMUnitySetup.AndroidExtractAsset(LLMUnitySetup.LLMManagerPath, true);
        LoadFromDisk();

        List<StringPair> downloads = new List<StringPair>();
        foreach (ModelEntry modelEntry in modelEntries)
        {
            string target = LLMUnitySetup.GetAssetPath(modelEntry.filename);
            if (File.Exists(target)) continue;

            if (!downloadOnStart || string.IsNullOrEmpty(modelEntry.url))
            {
                // Not downloadable: the file is expected to have been copied into the build (see Build).
                await LLMUnitySetup.AndroidExtractFile(modelEntry.filename);
                if (!File.Exists(target)) LLMUnitySetup.LogError($"Model {modelEntry.filename} could not be found!");
            }
            else
            {
                target = LLMUnitySetup.GetDownloadAssetPath(modelEntry.filename);
                downloads.Add(new StringPair { source = modelEntry.url, target = target });
            }
        }
        if (downloads.Count == 0) return true;

        try
        {
            downloadProgress = 0;
            totalSize = 0;
            completedSize = 0;

            // Query every size up front so progress can be reported over the whole batch.
            ResumingWebClient client = new ResumingWebClient();
            Dictionary<string, long> fileSizes = new Dictionary<string, long>();
            foreach (StringPair pair in downloads)
            {
                long size = client.GetURLFileSize(pair.source);
                fileSizes[pair.source] = size;
                totalSize += size;
            }

            foreach (StringPair pair in downloads)
            {
                currFileSize = fileSizes[pair.source];
                await LLMUnitySetup.DownloadFile(pair.source, pair.target, false, null, SetDownloadProgress);
                await LLMUnitySetup.AndroidExtractFile(Path.GetFileName(pair.target));
                completedSize += currFileSize;
            }

            completedSize = totalSize;
            // Report completion explicitly so the callbacks always end at 1, whatever sizes were reported.
            downloadProgress = 1;
            NotifyDownloadProgress();
        }
        catch (Exception ex)
        {
            LLMUnitySetup.LogError($"Error downloading the models: {ex.Message}");
            return false;
        }
        return true;
    }

    /// <summary>
    /// Gets the model entry for a model path.
    /// </summary>
    /// <param name="path">Stored filename (possibly relative to the asset folder), bare file name, or full path.</param>
    /// <returns>The first matching entry, or null if none matches or the path is null/empty.</returns>
    public static ModelEntry Get(string path)
    {
        if (string.IsNullOrEmpty(path)) return null;
        string filename = Path.GetFileName(path);
        string fullPath = LLMUnitySetup.GetFullPath(path);
        foreach (ModelEntry entry in modelEntries)
        {
            // entry.filename can be a relative asset path (e.g. "sub/model.gguf"), which GetFileName would strip,
            // so the argument is also compared to it as given.
            if (entry.filename == path || entry.filename == filename || entry.path == fullPath) return entry;
        }
        return null;
    }

    /// <summary>
    /// Gets the asset path based on whether the application runs locally in the editor or in a build.
    /// </summary>
    /// <param name="filename">Model filename or path.</param>
    /// <returns>The entry path in the Editor, the asset path in a build, or "" if no entry matches.</returns>
    public static string GetAssetPath(string filename)
    {
        ModelEntry entry = Get(filename);
        if (entry == null) return "";
#if UNITY_EDITOR
        return entry.path;
#else
        return LLMUnitySetup.GetAssetPath(entry.filename);
#endif
    }

    /// <summary>
    /// Returns the number of LLM/LORA models.
    /// </summary>
    /// <param name="lora">True to count LoRA adapters, false to count base models.</param>
    /// <returns>Number of matching entries.</returns>
    public static int Num(bool lora)
    {
        int num = 0;
        foreach (ModelEntry entry in modelEntries)
        {
            if (entry.lora == lora) num++;
        }
        return num;
    }

    /// <summary>
    /// Returns the number of LLM models.
    /// </summary>
    /// <returns>Number of base-model entries.</returns>
    public static int NumModels()
    {
        return Num(false);
    }

    /// <summary>
    /// Returns the number of LORA models.
    /// </summary>
    /// <returns>Number of LoRA entries.</returns>
    public static int NumLoras()
    {
        return Num(true);
    }

    /// <summary>
    /// Registers an LLM with the model manager so it is updated when entries are removed.
    /// </summary>
    /// <param name="llm">LLM component to register.</param>
    public static void Register(LLM llm)
    {
        llms.Add(llm);
    }

    /// <summary>
    /// Removes an LLM from the model manager.
    /// </summary>
    /// <param name="llm">LLM component to unregister.</param>
    public static void Unregister(LLM llm)
    {
        llms.Remove(llm);
    }

    /// <summary>
    /// Loads the model manager from the build manifest at <c>LLMUnitySetup.LLMManagerPath</c>, if it exists.
    /// </summary>
    public static void LoadFromDisk()
    {
        if (!File.Exists(LLMUnitySetup.LLMManagerPath)) return;
        LLMManagerStore store = JsonUtility.FromJson<LLMManagerStore>(File.ReadAllText(LLMUnitySetup.LLMManagerPath));
        if (store == null) return;
        downloadOnStart = store.downloadOnStart;
        // Never leave the registry null: callers iterate it unconditionally.
        modelEntries = store.modelEntries ?? new List<ModelEntry>();
        LLMUnitySetup.DebugMode = (LLMUnitySetup.DebugModeType)store.debugMode;
    }

#if UNITY_EDITOR
    // PlayerPrefs key under which the Editor registry is stored.
    static string LLMManagerPref = "LLMManager";

    /// <summary>Progress (0 to 1) of the current Editor model download.</summary>
    [HideInInspector] public static float modelProgress = 1;
    /// <summary>Progress (0 to 1) of the current Editor LoRA download.</summary>
    [HideInInspector] public static float loraProgress = 1;

    // Restores the Editor registry from PlayerPrefs whenever the Editor loads scripts.
    [InitializeOnLoadMethod]
    static void InitializeOnLoad()
    {
        Load();
    }

    /// <summary>
    /// Adds a model entry to the model manager, keeping base models before LoRA adapters, and saves it.
    /// </summary>
    /// <param name="entry">Entry to add.</param>
    /// <returns>The entry filename.</returns>
    public static string AddEntry(ModelEntry entry)
    {
        int indexToInsert = modelEntries.Count;
        if (!entry.lora)
        {
            if (modelEntries.Count > 0 && modelEntries[0].lora) indexToInsert = 0;
            else
            {
                // Insert right after the last base model.
                for (int i = modelEntries.Count - 1; i >= 0; i--)
                {
                    if (!modelEntries[i].lora)
                    {
                        indexToInsert = i + 1;
                        break;
                    }
                }
            }
        }
        modelEntries.Insert(indexToInsert, entry);
        Save();
        return entry.filename;
    }

    /// <summary>
    /// Creates and adds a model entry to the model manager.
    /// </summary>
    /// <param name="path">Path of the model file.</param>
    /// <param name="lora">Whether the file is a LoRA adapter.</param>
    /// <param name="label">Display label; the file name is used when null.</param>
    /// <param name="url">Download URL of the file, if any.</param>
    /// <returns>The entry filename.</returns>
    public static string AddEntry(string path, bool lora = false, string label = null, string url = null)
    {
        return AddEntry(new ModelEntry(path, lora, label, url));
    }

    /// <summary>
    /// Downloads a model from a URL into <c>LLMUnitySetup.modelDownloadPath</c> and adds a model entry to the
    /// model manager. If an entry with the same URL or file name already exists, it is reused.
    /// </summary>
    /// <param name="url">Download URL.</param>
    /// <param name="lora">Whether the file is a LoRA adapter.</param>
    /// <param name="log">Whether to log when an existing entry is reused.</param>
    /// <param name="label">Display label; the file name is used when null.</param>
    /// <returns>The entry filename, or null if the download failed.</returns>
    public static async Task<string> Download(string url, bool lora = false, bool log = false, string label = null)
    {
        foreach (ModelEntry entry in modelEntries)
        {
            if (entry.url == url)
            {
                if (log) LLMUnitySetup.Log($"Found existing entry for {url}");
                return entry.filename;
            }
        }

        // Strip the query string before taking the file name, so a '/' inside the query
        // is not mistaken for a path separator.
        string modelName = Path.GetFileName(url.Split('?')[0]);
        ModelEntry entryPath = Get(modelName);
        if (entryPath != null)
        {
            if (log) LLMUnitySetup.Log($"Found existing entry for {modelName}");
            return entryPath.filename;
        }

        string modelPath = Path.Combine(LLMUnitySetup.modelDownloadPath, modelName);
        float preModelProgress = modelProgress;
        float preLoraProgress = loraProgress;
        try
        {
            if (!lora)
            {
                modelProgress = 0;
                await LLMUnitySetup.DownloadFile(url, modelPath, false, null, SetModelProgress);
            }
            else
            {
                loraProgress = 0;
                await LLMUnitySetup.DownloadFile(url, modelPath, false, null, SetLoraProgress);
            }
        }
        catch (Exception ex)
        {
            modelProgress = preModelProgress;
            loraProgress = preLoraProgress;
            LLMUnitySetup.LogError($"Error downloading the model from URL '{url}': " + ex.Message);
            return null;
        }
        return AddEntry(modelPath, lora, label, url);
    }

    /// <summary>
    /// Loads a model from disk and adds a model entry to the model manager, reusing an existing entry if any.
    /// </summary>
    /// <param name="path">Path of the model file.</param>
    /// <param name="lora">Whether the file is a LoRA adapter.</param>
    /// <param name="log">Whether to log when an existing entry is reused.</param>
    /// <param name="label">Display label; the file name is used when null.</param>
    /// <returns>The entry filename.</returns>
    public static string Load(string path, bool lora = false, bool log = false, string label = null)
    {
        ModelEntry entry = Get(path);
        if (entry != null)
        {
            if (log) LLMUnitySetup.Log($"Found existing entry for {entry.filename}");
            return entry.filename;
        }
        return AddEntry(path, lora, label);
    }

    /// <summary>
    /// Downloads an LLM model from a URL and adds a model entry to the model manager.
    /// </summary>
    /// <param name="url">Download URL.</param>
    /// <param name="log">Whether to log when an existing entry is reused.</param>
    /// <param name="label">Display label; the file name is used when null.</param>
    /// <returns>The entry filename, or null if the download failed.</returns>
    public static async Task<string> DownloadModel(string url, bool log = false, string label = null)
    {
        return await Download(url, false, log, label);
    }

    /// <summary>
    /// Downloads a LoRA model from a URL and adds a model entry to the model manager.
    /// </summary>
    /// <param name="url">Download URL.</param>
    /// <param name="log">Whether to log when an existing entry is reused.</param>
    /// <param name="label">Display label; the file name is used when null.</param>
    /// <returns>The entry filename, or null if the download failed.</returns>
    public static async Task<string> DownloadLora(string url, bool log = false, string label = null)
    {
        return await Download(url, true, log, label);
    }

    /// <summary>
    /// Loads an LLM model from disk and adds a model entry to the model manager.
    /// </summary>
    /// <param name="path">Path of the model file.</param>
    /// <param name="log">Whether to log when an existing entry is reused.</param>
    /// <param name="label">Display label; the file name is used when null.</param>
    /// <returns>The entry filename.</returns>
    public static string LoadModel(string path, bool log = false, string label = null)
    {
        return Load(path, false, log, label);
    }

    /// <summary>
    /// Loads a LoRA model from disk and adds a model entry to the model manager.
    /// </summary>
    /// <param name="path">Path of the LoRA file.</param>
    /// <param name="log">Whether to log when an existing entry is reused.</param>
    /// <param name="label">Display label; the file name is used when null.</param>
    /// <returns>The entry filename.</returns>
    public static string LoadLora(string path, bool log = false, string label = null)
    {
        return Load(path, true, log, label);
    }

    /// <summary>
    /// Sets the URL for a model; does nothing if no entry matches.
    /// </summary>
    /// <param name="filename">Model filename or path.</param>
    /// <param name="url">New URL.</param>
    public static void SetURL(string filename, string url)
    {
        SetURL(Get(filename), url);
    }

    /// <summary>
    /// Sets the URL for a model and saves the manager; does nothing for a null entry.
    /// </summary>
    /// <param name="entry">Entry to update.</param>
    /// <param name="url">New URL.</param>
    public static void SetURL(ModelEntry entry, string url)
    {
        if (entry == null) return;
        entry.url = url;
        Save();
    }

    /// <summary>
    /// Sets whether to include a model in the build; does nothing if no entry matches.
    /// </summary>
    /// <param name="filename">Model filename or path.</param>
    /// <param name="includeInBuild">Whether to include the model.</param>
    public static void SetIncludeInBuild(string filename, bool includeInBuild)
    {
        SetIncludeInBuild(Get(filename), includeInBuild);
    }

    /// <summary>
    /// Sets whether to include a model in the build and saves the manager; does nothing for a null entry.
    /// </summary>
    /// <param name="entry">Entry to update.</param>
    /// <param name="includeInBuild">Whether to include the model.</param>
    public static void SetIncludeInBuild(ModelEntry entry, bool includeInBuild)
    {
        if (entry == null) return;
        entry.includeInBuild = includeInBuild;
        Save();
    }

    /// <summary>
    /// Sets whether to download files on start, warning when some entries have no URL, and saves the manager.
    /// </summary>
    /// <param name="value">New value of <see cref="downloadOnStart"/>.</param>
    public static void SetDownloadOnStart(bool value)
    {
        downloadOnStart = value;
        if (downloadOnStart)
        {
            bool warn = false;
            foreach (ModelEntry entry in modelEntries)
            {
                if (string.IsNullOrEmpty(entry.url)) warn = true;
            }
            if (warn) LLMUnitySetup.LogWarning("Some models do not have a URL and will be copied in the build. To resolve this fill in the URL field in the expanded view of the LLM Model list.");
        }
        Save();
    }

    /// <summary>
    /// Removes a model from the model manager; does nothing if no entry matches.
    /// </summary>
    /// <param name="filename">Model filename or path.</param>
    public static void Remove(string filename)
    {
        Remove(Get(filename));
    }

    /// <summary>
    /// Removes a model from the model manager, saves it, and detaches the entry from every registered LLM
    /// (the model field is cleared for base models, the LoRA is removed for adapters).
    /// </summary>
    /// <param name="entry">Entry to remove; ignored when null.</param>
    public static void Remove(ModelEntry entry)
    {
        if (entry == null) return;
        modelEntries.Remove(entry);
        Save();
        foreach (LLM llm in llms)
        {
            if (!entry.lora && llm.model == entry.filename) llm.model = "";
            else if (entry.lora) llm.RemoveLora(entry.filename);
        }
    }

    /// <summary>
    /// Sets the LLM download progress.
    /// </summary>
    /// <param name="progress">Progress between 0 and 1.</param>
    public static void SetModelProgress(float progress)
    {
        modelProgress = progress;
    }

    /// <summary>
    /// Sets the LoRA download progress.
    /// </summary>
    /// <param name="progress">Progress between 0 and 1.</param>
    public static void SetLoraProgress(float progress)
    {
        loraProgress = progress;
    }

    /// <summary>
    /// Serialises the entries and <see cref="downloadOnStart"/> and saves them to PlayerPrefs.
    /// </summary>
    public static void Save()
    {
        string json = JsonUtility.ToJson(new LLMManagerStore
        {
            modelEntries = modelEntries,
            downloadOnStart = downloadOnStart,
        }, true);
        PlayerPrefs.SetString(LLMManagerPref, json);
        PlayerPrefs.Save();
    }

    /// <summary>
    /// Deserialises and loads the model manager from PlayerPrefs, if anything was saved.
    /// </summary>
    public static void Load()
    {
        string pref = PlayerPrefs.GetString(LLMManagerPref);
        if (string.IsNullOrEmpty(pref)) return;
        LLMManagerStore store = JsonUtility.FromJson<LLMManagerStore>(pref);
        if (store == null) return;
        downloadOnStart = store.downloadOnStart;
        // Never leave the registry null: callers iterate it unconditionally.
        modelEntries = store.modelEntries ?? new List<ModelEntry>();
    }

    /// <summary>
    /// Saves the model manager to disk for the build: entries marked for inclusion (stripped with
    /// <c>OnlyRequiredFields</c>), <see cref="downloadOnStart"/> and the debug mode are written to
    /// <c>LLMUnitySetup.LLMManagerPath</c>.
    /// </summary>
    public static void SaveToDisk()
    {
        List<ModelEntry> modelEntriesBuild = new List<ModelEntry>();
        foreach (ModelEntry modelEntry in modelEntries)
        {
            if (!modelEntry.includeInBuild) continue;
            modelEntriesBuild.Add(modelEntry.OnlyRequiredFields());
        }
        string json = JsonUtility.ToJson(new LLMManagerStore
        {
            modelEntries = modelEntriesBuild,
            downloadOnStart = downloadOnStart,
            debugMode = (int)LLMUnitySetup.DebugMode
        }, true);
        File.WriteAllText(LLMUnitySetup.LLMManagerPath, json);
    }

    /// <summary>
    /// Saves the model manager to disk, then calls <paramref name="copyCallback"/>(source, target) for every
    /// included model that is not already at its asset path and will not be downloaded at startup.
    /// </summary>
    /// <param name="copyCallback">Callback that copies a file from source to target.</param>
    public static void Build(Action<string, string> copyCallback)
    {
        SaveToDisk();

        foreach (ModelEntry modelEntry in modelEntries)
        {
            string target = LLMUnitySetup.GetAssetPath(modelEntry.filename);
            if (!modelEntry.includeInBuild || File.Exists(target)) continue;
            if (!downloadOnStart || string.IsNullOrEmpty(modelEntry.url)) copyCallback(modelEntry.path, target);
        }
    }

#endif
}
643}
Class implementing the GGUF reader.
Definition LLMGGUF.cs:55
int GetIntField(string key)
Allows to retrieve an integer GGUF field.
Definition LLMGGUF.cs:157
string GetStringField(string key)
Allows to retrieve a string GGUF field.
Definition LLMGGUF.cs:145
Class implementing the LLM model manager.
static void SaveToDisk()
Saves the model manager to disk for the build.
static void SetIncludeInBuild(ModelEntry entry, bool includeInBuild)
Sets whether to include a model to the build.
static void SetLoraProgress(float progress)
Sets the LORA download progress.
static void Unregister(LLM llm)
Removes an LLM from the model manager.
static async Task< string > Download(string url, bool lora=false, bool log=false, string label=null)
Downloads a model and adds a model entry to the model manager.
static string GetAssetPath(string filename)
Gets the asset path based on whether the application runs locally in the editor or in a build.
static int NumModels()
Returns the number of LLM models.
static ModelEntry Get(string path)
Gets the model entry for a model path.
static Task< bool > Setup()
Setup of the models.
static int Num(bool lora)
Returns the number of LLM/LORA models.
static void Register(LLM llm)
Registers an LLM with the model manager.
static void SetURL(string filename, string url)
Sets the URL for a model.
static void SetModelProgress(float progress)
Sets the LLM download progress.
static void Save()
Serialises and saves the model manager.
static string AddEntry(string path, bool lora=false, string label=null, string url=null)
Creates and adds a model entry to the model manager.
static void Load()
Deserialises and loads the model manager.
static void SetDownloadOnStart(bool value)
Sets whether to download files on start.
static void SetIncludeInBuild(string filename, bool includeInBuild)
Sets whether to include a model to the build.
static int NumLoras()
Returns the number of LORA models.
static async Task< string > DownloadModel(string url, bool log=false, string label=null)
Downloads an LLM model from a URL and adds a model entry to the model manager.
static string LoadModel(string path, bool log=false, string label=null)
Loads an LLM model from disk and adds a model entry to the model manager.
static void Remove(string filename)
Removes a model from the model manager.
static void LoadFromDisk()
Loads the model manager from a file.
static void SetDownloadProgress(float progress)
Sets the model download progress in all registered callbacks.
static string Load(string path, bool lora=false, bool log=false, string label=null)
Loads a model from disk and adds a model entry to the model manager.
static string AddEntry(ModelEntry entry)
Adds a model entry to the model manager.
static async Task< bool > SetupOnce()
Task performing the setup of the models.
static async Task< string > DownloadLora(string url, bool log=false, string label=null)
Downloads a LoRA model from a URL and adds a model entry to the model manager.
static void Remove(ModelEntry entry)
Removes a model from the model manager.
static string LoadLora(string path, bool log=false, string label=null)
Loads a LORA model from disk and adds a model entry to the model manager.
static void Build(Action< string, string > copyCallback)
Saves the model manager to disk for the build, along with the models that are not (or cannot be) downloaded...
static void SetURL(ModelEntry entry, string url)
Sets the URL for a model.
Class implementing helper functions for setup and process management.
static string LLMManagerPath
Path of file with build information for runtime.
static string modelDownloadPath
Model download path.
Unity MonoBehaviour component that manages a local LLM server instance. Handles model loading,...
Definition LLM.cs:21
void RemoveLora(string path)
Removes a specific LORA adapter.
Definition LLM.cs:685
string model
LLM model file path (.gguf format)
Definition LLM.cs:182
Class implementing an LLM model entry.
Definition LLMManager.cs:18
ModelEntry(string path, bool lora=false, string label=null, string url=null)
Constructs an LLM model entry.
Definition LLMManager.cs:54
static string GetFilenameOrRelativeAssetPath(string path)
Returns the relative asset path if it is in the AssetPath folder (StreamingAssets or persistentPath),...
Definition LLMManager.cs:36
ModelEntry OnlyRequiredFields()
Returns only the required fields for bundling the model in the build.
Definition LLMManager.cs:82
Class implementing a resumable Web client.