LLM for Unity v2.2.5
Create characters in Unity with LLMs!
Loading...
Searching...
No Matches
LLMManager.cs
1using System;
2using System.Collections.Generic;
3using System.IO;
4using System.Threading.Tasks;
5using UnityEditor;
6using UnityEngine;
7
8namespace LLMUnity
9{
/// <summary>
/// Serializable description of one base model or LoRA adapter registered with <see cref="LLMManager"/>.
/// Field declaration order is kept stable because it determines the JSON layout written by JsonUtility.
/// </summary>
[Serializable]
public class ModelEntry
{
    /// <summary>Display label; defaults to the file name when none is supplied.</summary>
    public string label;
    /// <summary>File name, or the path relative to the asset folder when the file lives inside it.</summary>
    public string filename;
    /// <summary>Full path of the file on disk (replaced by <see cref="filename"/> in <see cref="OnlyRequiredFields"/> copies).</summary>
    public string path;
    /// <summary>True for a LoRA adapter, false for a base model.</summary>
    public bool lora;
    /// <summary>Chat template detected from the GGUF metadata; stays null for LoRA entries.</summary>
    public string chatTemplate;
    /// <summary>Optional URL the file can be downloaded from.</summary>
    public string url;
    /// <summary>Whether the entry is exported to builds (true on construction).</summary>
    public bool includeInBuild;
    /// <summary>Context length from the GGUF metadata, or -1 when it was not read.</summary>
    public int contextLength;

    /// <summary>
    /// Returns the asset-relative path of <paramref name="path"/> when the file exists inside the asset folder,
    /// otherwise just its file name.
    /// </summary>
    /// <param name="path">File name, asset-relative path or full path of the model.</param>
    public static string GetFilenameOrRelativeAssetPath(string path)
    {
        // GetAssetPath hands back a full path unchanged when one is passed in.
        string candidate = LLMUnitySetup.GetAssetPath(path);
        string assetRoot = LLMUnitySetup.GetAssetPath();
        bool insideAssets = File.Exists(candidate) && LLMUnitySetup.IsSubPath(candidate, assetRoot);
        return insideAssets ? LLMUnitySetup.RelativePath(candidate, assetRoot) : Path.GetFileName(path);
    }

    /// <summary>
    /// Creates an entry for the file at <paramref name="path"/>. For base models the GGUF header is read
    /// to detect the chat template and the architecture's context length.
    /// </summary>
    /// <param name="path">Path of the model file.</param>
    /// <param name="lora">True when the file is a LoRA adapter.</param>
    /// <param name="label">Optional display label; the file name is used when null.</param>
    /// <param name="url">Optional download URL.</param>
    public ModelEntry(string path, bool lora = false, string label = null, string url = null)
    {
        filename = GetFilenameOrRelativeAssetPath(path);
        this.label = label ?? Path.GetFileName(filename);
        this.lora = lora;
        this.path = LLMUnitySetup.GetFullPath(path);
        this.url = url;
        includeInBuild = true;
        chatTemplate = null;
        contextLength = -1;
        if (lora) return;

        // Base models only: inspect the GGUF metadata.
        GGUFReader reader = new GGUFReader(this.path);
        chatTemplate = ChatTemplate.FromGGUF(reader, this.path);
        string architecture = reader.GetStringField("general.architecture");
        if (architecture != null) contextLength = reader.GetIntField($"{architecture}.context_length");
    }

    /// <summary>
    /// Returns a shallow copy stripped down for the build: the label is cleared and the full path is
    /// replaced by the file name.
    /// </summary>
    public ModelEntry OnlyRequiredFields()
    {
        var copy = (ModelEntry)MemberwiseClone();
        copy.label = null;
        copy.path = copy.filename;
        return copy;
    }
}
60
/// <summary>
/// JSON container used by <see cref="LLMManager"/> to persist its settings and model list
/// (serialized with JsonUtility; field order defines the JSON layout).
/// </summary>
[Serializable]
public class LLMManagerStore
{
    /// <summary>Snapshot of <see cref="LLMManager.downloadOnStart"/>.</summary>
    public bool downloadOnStart;
    /// <summary>Snapshot of the registered model and LoRA entries.</summary>
    public List<ModelEntry> modelEntries;
}
67
/// <summary>
/// Static registry of the base models and LoRA adapters used by LLM for Unity.
/// In the editor the registry is kept in PlayerPrefs and exported for builds with SaveToDisk;
/// at runtime <see cref="Setup"/> reads the exported file and downloads or extracts missing model files.
/// </summary>
[DefaultExecutionOrder(-2)]
public class LLMManager
{
    /// <summary>When true, entries that have a URL are downloaded at startup instead of being copied into the build.</summary>
    public static bool downloadOnStart = false;
    /// <summary>Registered entries; AddEntry keeps base models in front of LoRA adapters.</summary>
    public static List<ModelEntry> modelEntries = new List<ModelEntry>();
    // LLM components that are updated when a template changes or an entry is removed.
    static List<LLM> llms = new List<LLM>();

    /// <summary>Overall progress of the startup downloads (1 when nothing is pending).</summary>
    public static float downloadProgress = 1;
    /// <summary>Callbacks invoked with <see cref="downloadProgress"/> each time it is updated.</summary>
    public static List<Callback<float>> downloadProgressCallbacks = new List<Callback<float>>();
    static Task<bool> SetupTask;
    static readonly object lockObject = new object();
    // Byte counts used to weight per-file progress into overall progress.
    static long totalSize;
    static long currFileSize;
    static long completedSize;

    /// <summary>
    /// Converts the progress of the file currently downloading into overall progress and notifies the callbacks.
    /// </summary>
    /// <param name="progress">Progress fraction of the current file.</param>
    public static void SetDownloadProgress(float progress)
    {
        // With no known size (total of 0) the weighted formula would divide by zero and yield NaN/Infinity,
        // so fall back to the current file's own progress.
        downloadProgress = totalSize > 0 ? (completedSize + progress * currFileSize) / totalSize : progress;
        foreach (Callback<float> downloadProgressCallback in downloadProgressCallbacks) downloadProgressCallback?.Invoke(downloadProgress);
    }

    /// <summary>
    /// Starts <see cref="SetupOnce"/> on the first call and returns that same task on every later call.
    /// </summary>
    /// <returns>Task resolving to false if downloading the models failed, true otherwise.</returns>
    public static Task<bool> Setup()
    {
        lock (lockObject)
        {
            if (SetupTask == null) SetupTask = SetupOnce();
            return SetupTask;
        }
    }

    /// <summary>
    /// Loads the exported build information and makes sure every entry's file exists: files are extracted
    /// (Android) or, when <see cref="downloadOnStart"/> is set and a URL is known, downloaded.
    /// </summary>
    /// <returns>False if a download failed, true otherwise. Missing files that could not be extracted are only logged.</returns>
    public static async Task<bool> SetupOnce()
    {
        await LLMUnitySetup.AndroidExtractAsset(LLMUnitySetup.LLMManagerPath, true);
        LoadFromDisk();

        List<StringPair> downloads = new List<StringPair>();
        foreach (ModelEntry modelEntry in modelEntries)
        {
            string target = LLMUnitySetup.GetAssetPath(modelEntry.filename);
            if (File.Exists(target)) continue;

            if (!downloadOnStart || string.IsNullOrEmpty(modelEntry.url))
            {
                await LLMUnitySetup.AndroidExtractFile(modelEntry.filename);
                if (!File.Exists(target)) LLMUnitySetup.LogError($"Model {modelEntry.filename} could not be found!");
            }
            else
            {
                downloads.Add(new StringPair {source = modelEntry.url, target = target});
            }
        }
        if (downloads.Count == 0) return true;

        try
        {
            downloadProgress = 0;
            totalSize = 0;
            completedSize = 0;

            // NOTE(review): the published listing skipped source line 127, leaving `client` undeclared.
            // Restored as LLMUnity's resumable web client — confirm against LLMUnitySetup.cs.
            ResumingWebClient client = new ResumingWebClient();
            Dictionary<string, long> fileSizes = new Dictionary<string, long>();
            foreach (StringPair pair in downloads)
            {
                // An unknown size is presumably reported as a negative value (-1); count it as 0 so it
                // cannot shrink the total or make progress go backwards.
                long size = Math.Max(0L, client.GetURLFileSize(pair.source));
                fileSizes[pair.source] = size;
                totalSize += size;
            }

            foreach (StringPair pair in downloads)
            {
                currFileSize = fileSizes[pair.source];
                await LLMUnitySetup.DownloadFile(pair.source, pair.target, false, null, SetDownloadProgress);
                await LLMUnitySetup.AndroidExtractFile(Path.GetFileName(pair.target));
                completedSize += currFileSize;
            }

            // Report completion: with currFileSize = 0 this evaluates to exactly 1 whether or not sizes were known.
            completedSize = totalSize;
            currFileSize = 0;
            SetDownloadProgress(1);
        }
        catch (Exception ex)
        {
            LLMUnitySetup.LogError($"Error downloading the models: {ex.Message}");
            return false;
        }
        return true;
    }

    /// <summary>Sets the chat template of the entry matching <paramref name="filename"/>.</summary>
    public static void SetTemplate(string filename, string chatTemplate)
    {
        SetTemplate(Get(filename), chatTemplate);
    }

    /// <summary>
    /// Sets the chat template of <paramref name="entry"/> and forwards it to every registered LLM using that model.
    /// Does nothing when <paramref name="entry"/> is null.
    /// </summary>
    public static void SetTemplate(ModelEntry entry, string chatTemplate)
    {
        if (entry == null) return;
        entry.chatTemplate = chatTemplate;
        foreach (LLM llm in llms)
        {
            if (llm != null && llm.model == entry.filename) llm.SetTemplate(chatTemplate);
        }
#if UNITY_EDITOR
        Save();
#endif
    }

    /// <summary>
    /// Finds the entry whose file name matches the file name of <paramref name="path"/>, or whose full path
    /// matches the full path of <paramref name="path"/>.
    /// </summary>
    /// <returns>The first matching entry, or null.</returns>
    public static ModelEntry Get(string path)
    {
        string filename = Path.GetFileName(path);
        string fullPath = LLMUnitySetup.GetFullPath(path);
        foreach (ModelEntry entry in modelEntries)
        {
            if (entry.filename == filename || entry.path == fullPath) return entry;
        }
        return null;
    }

    /// <summary>
    /// Returns where the file of the matching entry lives: its original path in the editor, the asset path in builds.
    /// </summary>
    /// <returns>The path, or "" when no entry matches.</returns>
    public static string GetAssetPath(string filename)
    {
        ModelEntry entry = Get(filename);
        if (entry == null) return "";
#if UNITY_EDITOR
        return entry.path;
#else
        return LLMUnitySetup.GetAssetPath(entry.filename);
#endif
    }

    /// <summary>Counts the LoRA entries (<paramref name="lora"/> true) or the base-model entries (false).</summary>
    public static int Num(bool lora)
    {
        int num = 0;
        foreach (ModelEntry entry in modelEntries)
        {
            if (entry.lora == lora) num++;
        }
        return num;
    }

    /// <summary>Number of base-model entries.</summary>
    public static int NumModels()
    {
        return Num(false);
    }

    /// <summary>Number of LoRA entries.</summary>
    public static int NumLoras()
    {
        return Num(true);
    }

    /// <summary>Registers an LLM so it receives template changes and entry removals.</summary>
    public static void Register(LLM llm)
    {
        llms.Add(llm);
    }

    /// <summary>Stops notifying <paramref name="llm"/>.</summary>
    public static void Unregister(LLM llm)
    {
        llms.Remove(llm);
    }

    /// <summary>
    /// Loads the settings and entries from the build information file, if it exists.
    /// </summary>
    public static void LoadFromDisk()
    {
        if (!File.Exists(LLMUnitySetup.LLMManagerPath)) return;
        LLMManagerStore store = JsonUtility.FromJson<LLMManagerStore>(File.ReadAllText(LLMUnitySetup.LLMManagerPath));
        // Guard against an empty/invalid file so modelEntries never becomes null (every caller iterates it).
        if (store == null) return;
        downloadOnStart = store.downloadOnStart;
        modelEntries = store.modelEntries ?? new List<ModelEntry>();
    }

#if UNITY_EDITOR
    static string LLMManagerPref = "LLMManager";

    /// <summary>Progress of the base-model download started by <see cref="Download"/>.</summary>
    [HideInInspector] public static float modelProgress = 1;
    /// <summary>Progress of the LoRA download started by <see cref="Download"/>.</summary>
    [HideInInspector] public static float loraProgress = 1;

    [InitializeOnLoadMethod]
    static void InitializeOnLoad()
    {
        Load();
    }

    /// <summary>
    /// Inserts <paramref name="entry"/> keeping base models in front of LoRA adapters, then saves.
    /// A base model goes right after the last base model (or at the front when the list starts with a LoRA);
    /// a LoRA is appended.
    /// </summary>
    /// <returns>The entry's file name.</returns>
    public static string AddEntry(ModelEntry entry)
    {
        int indexToInsert = modelEntries.Count;
        if (!entry.lora)
        {
            if (modelEntries.Count > 0 && modelEntries[0].lora) indexToInsert = 0;
            else
            {
                for (int i = modelEntries.Count - 1; i >= 0; i--)
                {
                    if (!modelEntries[i].lora)
                    {
                        indexToInsert = i + 1;
                        break;
                    }
                }
            }
        }
        modelEntries.Insert(indexToInsert, entry);
        Save();
        return entry.filename;
    }

    /// <summary>Creates an entry for the file at <paramref name="path"/> and adds it.</summary>
    /// <returns>The entry's file name.</returns>
    public static string AddEntry(string path, bool lora = false, string label = null, string url = null)
    {
        return AddEntry(new ModelEntry(path, lora, label, url));
    }

    /// <summary>
    /// Downloads a model or LoRA from <paramref name="url"/> into the model download folder and registers it.
    /// If an entry with the same URL or file name already exists, it is returned without downloading.
    /// </summary>
    /// <returns>The entry's file name, or null if the download failed.</returns>
    public static async Task<string> Download(string url, bool lora = false, bool log = false, string label = null)
    {
        foreach (ModelEntry entry in modelEntries)
        {
            if (entry.url == url)
            {
                if (log) LLMUnitySetup.Log($"Found existing entry for {url}");
                return entry.filename;
            }
        }

        // Strip the query string before taking the file name, so a '/' inside the query cannot
        // end up determining the model name.
        string modelName = Path.GetFileName(url.Split('?')[0]);
        ModelEntry entryPath = Get(modelName);
        if (entryPath != null)
        {
            if (log) LLMUnitySetup.Log($"Found existing entry for {modelName}");
            return entryPath.filename;
        }

        string modelPath = Path.Combine(LLMUnitySetup.modelDownloadPath, modelName);
        float preModelProgress = modelProgress;
        float preLoraProgress = loraProgress;
        try
        {
            if (!lora)
            {
                modelProgress = 0;
                await LLMUnitySetup.DownloadFile(url, modelPath, false, null, SetModelProgress);
            }
            else
            {
                loraProgress = 0;
                await LLMUnitySetup.DownloadFile(url, modelPath, false, null, SetLoraProgress);
            }
        }
        catch (Exception ex)
        {
            // Restore the progress indicators so a failed download does not leave them at 0.
            modelProgress = preModelProgress;
            loraProgress = preLoraProgress;
            LLMUnitySetup.LogError($"Error downloading the model from URL '{url}': " + ex.Message);
            return null;
        }
        return AddEntry(modelPath, lora, label, url);
    }

    /// <summary>Registers a local file, reusing an existing entry that matches <paramref name="path"/>.</summary>
    /// <returns>The entry's file name.</returns>
    public static string Load(string path, bool lora = false, bool log = false, string label = null)
    {
        ModelEntry entry = Get(path);
        if (entry != null)
        {
            if (log) LLMUnitySetup.Log($"Found existing entry for {entry.filename}");
            return entry.filename;
        }
        return AddEntry(path, lora, label);
    }

    /// <summary>Downloads and registers a base model.</summary>
    public static async Task<string> DownloadModel(string url, bool log = false, string label = null)
    {
        return await Download(url, false, log, label);
    }

    /// <summary>Downloads and registers a LoRA adapter.</summary>
    public static async Task<string> DownloadLora(string url, bool log = false, string label = null)
    {
        return await Download(url, true, log, label);
    }

    /// <summary>Registers a local base model.</summary>
    public static string LoadModel(string path, bool log = false, string label = null)
    {
        return Load(path, false, log, label);
    }

    /// <summary>Registers a local LoRA adapter.</summary>
    public static string LoadLora(string path, bool log = false, string label = null)
    {
        return Load(path, true, log, label);
    }

    /// <summary>Sets the download URL of the entry matching <paramref name="filename"/>.</summary>
    public static void SetURL(string filename, string url)
    {
        SetURL(Get(filename), url);
    }

    /// <summary>Sets the download URL of <paramref name="entry"/> and saves; ignores a null entry.</summary>
    public static void SetURL(ModelEntry entry, string url)
    {
        if (entry == null) return;
        entry.url = url;
        Save();
    }

    /// <summary>Sets whether the entry matching <paramref name="filename"/> is exported to builds.</summary>
    public static void SetIncludeInBuild(string filename, bool includeInBuild)
    {
        SetIncludeInBuild(Get(filename), includeInBuild);
    }

    /// <summary>Sets whether <paramref name="entry"/> is exported to builds and saves; ignores a null entry.</summary>
    public static void SetIncludeInBuild(ModelEntry entry, bool includeInBuild)
    {
        if (entry == null) return;
        entry.includeInBuild = includeInBuild;
        Save();
    }

    /// <summary>
    /// Sets <see cref="downloadOnStart"/> and saves. When enabling it, warns if some entries have no URL
    /// (those are still copied into the build).
    /// </summary>
    public static void SetDownloadOnStart(bool value)
    {
        downloadOnStart = value;
        if (downloadOnStart)
        {
            bool warn = false;
            foreach (ModelEntry entry in modelEntries)
            {
                if (string.IsNullOrEmpty(entry.url))
                {
                    warn = true;
                    break;
                }
            }
            if (warn) LLMUnitySetup.LogWarning("Some models do not have a URL and will be copied in the build. To resolve this fill in the URL field in the expanded view of the LLM Model list.");
        }
        Save();
    }

    /// <summary>Removes the entry matching <paramref name="filename"/>.</summary>
    public static void Remove(string filename)
    {
        Remove(Get(filename));
    }

    /// <summary>
    /// Removes <paramref name="entry"/>, saves, and detaches it from registered LLMs: a base model is cleared from
    /// LLMs using it, a LoRA is removed from every LLM. Ignores a null entry.
    /// </summary>
    public static void Remove(ModelEntry entry)
    {
        if (entry == null) return;
        modelEntries.Remove(entry);
        Save();
        foreach (LLM llm in llms)
        {
            // Skip destroyed/null components, as SetTemplate does, instead of throwing.
            if (llm == null) continue;
            if (entry.lora) llm.RemoveLora(entry.filename);
            else if (llm.model == entry.filename) llm.model = "";
        }
    }

    /// <summary>Download callback updating <see cref="modelProgress"/>.</summary>
    public static void SetModelProgress(float progress)
    {
        modelProgress = progress;
    }

    /// <summary>Download callback updating <see cref="loraProgress"/>.</summary>
    public static void SetLoraProgress(float progress)
    {
        loraProgress = progress;
    }

    /// <summary>Stores the settings and entries as JSON in the editor PlayerPrefs.</summary>
    public static void Save()
    {
        string json = JsonUtility.ToJson(new LLMManagerStore { modelEntries = modelEntries, downloadOnStart = downloadOnStart }, true);
        PlayerPrefs.SetString(LLMManagerPref, json);
        PlayerPrefs.Save();
    }

    /// <summary>Restores the settings and entries from the editor PlayerPrefs, if any were saved.</summary>
    public static void Load()
    {
        string pref = PlayerPrefs.GetString(LLMManagerPref);
        if (string.IsNullOrEmpty(pref)) return;
        LLMManagerStore store = JsonUtility.FromJson<LLMManagerStore>(pref);
        // Keep modelEntries non-null even if the stored JSON is malformed.
        if (store == null) return;
        downloadOnStart = store.downloadOnStart;
        modelEntries = store.modelEntries ?? new List<ModelEntry>();
    }

    /// <summary>
    /// Writes the build information file: entries marked for the build, reduced via
    /// <see cref="ModelEntry.OnlyRequiredFields"/>, plus <see cref="downloadOnStart"/>.
    /// </summary>
    public static void SaveToDisk()
    {
        List<ModelEntry> modelEntriesBuild = new List<ModelEntry>();
        foreach (ModelEntry modelEntry in modelEntries)
        {
            if (!modelEntry.includeInBuild) continue;
            modelEntriesBuild.Add(modelEntry.OnlyRequiredFields());
        }
        string json = JsonUtility.ToJson(new LLMManagerStore { modelEntries = modelEntriesBuild, downloadOnStart = downloadOnStart }, true);
        File.WriteAllText(LLMUnitySetup.LLMManagerPath, json);
    }

    /// <summary>
    /// Prepares a build: writes the build information file, then calls <paramref name="copyCallback"/> with
    /// (source path, target asset path) for every included entry whose file is not yet at its target and that
    /// will not be downloaded at startup.
    /// </summary>
    public static void Build(ActionCallback copyCallback)
    {
        SaveToDisk();

        foreach (ModelEntry modelEntry in modelEntries)
        {
            string target = LLMUnitySetup.GetAssetPath(modelEntry.filename);
            if (!modelEntry.includeInBuild || File.Exists(target)) continue;
            if (!downloadOnStart || string.IsNullOrEmpty(modelEntry.url)) copyCallback(modelEntry.path, target);
        }
    }

#endif
}
457}
Class implementing the skeleton of a chat template.
static string FromGGUF(string path)
Determines the chat template name from a GGUF file. It reads the GGUF file and then determines the ch...
Class implementing the GGUF reader.
Definition LLMGGUF.cs:55
int GetIntField(string key)
Retrieves an integer GGUF field.
Definition LLMGGUF.cs:152
string GetStringField(string key)
Retrieves a string GGUF field.
Definition LLMGGUF.cs:140
Class implementing helper functions for setup and process management.
static string LLMManagerPath
Path of the file holding the build information used at runtime.
static string modelDownloadPath
Model download path.
Class implementing the LLM server.
Definition LLM.cs:19
void RemoveLora(string path)
Removes a LoRA model from the LLM. Supported models are in .gguf format.
Definition LLM.cs:260
string model
The LLM model to use. Models in .gguf format are allowed.
Definition LLM.cs:56
void SetTemplate(string templateName, bool setDirty=true)
Set the chat template for the LLM.
Definition LLM.cs:313