// --- Shared static state of the model manager ---
/// <summary>
/// When true, SetupOnce queues entries that are missing on disk and have a URL for download,
/// and Build does not copy them; when false, model files are expected to be present locally.
/// </summary>
71 public static bool downloadOnStart =
false;
/// <summary>Registered model and LoRA entries (replaced wholesale by LoadFromDisk and Load).</summary>
72 public static List<ModelEntry> modelEntries =
new List<ModelEntry>();
/// <summary>
/// LLM instances that SetTemplate and Remove propagate changes to; presumably maintained by
/// Register/Unregister (their bodies are not visible here).
/// </summary>
73 static List<LLM> llms =
new List<LLM>();
/// <summary>Overall download progress; starts at 1, reset to 0 by SetupOnce before downloading, recomputed by SetDownloadProgress.</summary>
75 public static float downloadProgress = 1;
/// <summary>Listeners invoked by SetDownloadProgress with the updated downloadProgress.</summary>
76 public static List<Callback<float>> downloadProgressCallbacks =
new List<Callback<float>>();
/// <summary>Task of the single SetupOnce run, created lazily by Setup.</summary>
77 static Task<bool> SetupTask;
// Synchronization object; its use is not visible in this view (presumably guards SetupTask in Setup — confirm).
78 static readonly
object lockObject =
new object();
// Download size bookkeeping for SetDownloadProgress, in the units returned by GetURLFileSize (presumably bytes):
// totalSize = all queued downloads, currFileSize = file in progress, completedSize = files already finished.
79 static long totalSize;
80 static long currFileSize;
81 static long completedSize;
/// <summary>
/// Per-file progress handler passed to LLMUnitySetup.DownloadFile by SetupOnce. It converts the
/// fraction of the current file into overall progress across all queued downloads and notifies
/// every registered callback.
/// </summary>
/// <param name="progress">Fraction of the current file downloaded (presumably 0..1 — confirm against DownloadFile).</param>
83 public static void SetDownloadProgress(
float progress)
// overall = (finished files + downloaded part of the current file) / total.
// NOTE(review): this is float division; if totalSize is 0 the result is NaN or Infinity rather than an exception.
85 downloadProgress = (completedSize + progress * currFileSize) / totalSize;
// Null entries in the callback list are skipped by ?.Invoke.
86 foreach (Callback<float> downloadProgressCallback
in downloadProgressCallbacks) downloadProgressCallback?.Invoke(downloadProgress);
/// <summary>
/// Starts the one-time setup on first use. SetupOnce is invoked only while SetupTask is null,
/// so later calls reuse the same task.
/// </summary>
/// <returns>
/// The setup task. The return statement is not visible in this view (presumably SetupTask — confirm).
/// </returns>
89 public static Task<bool> Setup()
// NOTE(review): lines 90-92 are not visible; presumably this check runs inside lock (lockObject)
// so concurrent callers cannot start SetupOnce twice — confirm.
93 if (SetupTask ==
null) SetupTask = SetupOnce();
/// <summary>
/// One-time setup. Finds entries whose files are missing from their asset path, reports the ones
/// that cannot be downloaded, and downloads the rest sequentially while tracking overall progress.
/// </summary>
/// <returns>
/// true when nothing needs downloading. The other return paths are not visible in this view
/// (presumably true after a successful download and false after the error is logged — confirm).
/// </returns>
98 public static async Task<bool> SetupOnce()
// Collect (url, target path) pairs for the files that must be downloaded.
103 List<StringPair> downloads =
new List<StringPair>();
104 foreach (
ModelEntry modelEntry
in modelEntries)
106 string target =
LLMUnitySetup.GetAssetPath(modelEntry.filename);
// Already present at its asset path: nothing to do.
107 if (File.Exists(target))
continue;
// Not downloadable (download-on-start disabled or no URL): the file must already exist.
// Lines 110-111 are not visible; the re-check below presumably follows a step that may produce the file — confirm.
109 if (!downloadOnStart ||
string.IsNullOrEmpty(modelEntry.url))
112 if (!File.Exists(target))
LLMUnitySetup.LogError($
"Model {modelEntry.filename} could not be found!");
// Otherwise queue it for download (presumably the else branch; its header is not visible).
116 downloads.Add(
new StringPair {source = modelEntry.url, target = target});
119 if (downloads.Count == 0)
return true;
123 downloadProgress = 0;
// Query every file size up front so overall progress can be weighted by size.
// `client` is declared on a line not visible here. Any accumulation into totalSize is also
// not visible — confirm that totalSize is computed from these sizes.
128 Dictionary<string, long> fileSizes =
new Dictionary<string, long>();
129 foreach (StringPair pair
in downloads)
131 long size = client.GetURLFileSize(pair.source);
132 fileSizes[pair.source] = size;
// Download sequentially. SetDownloadProgress turns per-file progress into overall progress
// using currFileSize and completedSize.
136 foreach (StringPair pair
in downloads)
138 currFileSize = fileSizes[pair.source];
139 await
LLMUnitySetup.DownloadFile(pair.source, pair.target,
false,
null, SetDownloadProgress);
140 await
LLMUnitySetup.AndroidExtractFile(Path.GetFileName(pair.target));
141 completedSize += currFileSize;
// Pin the final value: with completedSize == totalSize, SetDownloadProgress(0) computes
// totalSize / totalSize (1 when totalSize > 0).
144 completedSize = totalSize;
145 SetDownloadProgress(0);
// Error path (its try/catch lines are not visible): log the exception message.
149 LLMUnitySetup.LogError($
"Error downloading the models: {ex.Message}");
/// <summary>
/// Sets the chat template of the entry found by Get(filename). When Get yields null, the
/// ModelEntry overload ignores the call.
/// </summary>
/// <param name="filename">Filename used to look up the entry via Get.</param>
/// <param name="chatTemplate">Chat template to store and propagate.</param>
155 public static void SetTemplate(
string filename,
string chatTemplate)
157 SetTemplate(Get(filename), chatTemplate);
/// <summary>
/// Stores chatTemplate on the entry and applies it to every registered, non-null LLM whose
/// model equals the entry's filename.
/// NOTE(review): lines 167-174 are not visible; there may be further statements after the loop.
/// </summary>
/// <param name="entry">Entry to update; null is ignored.</param>
/// <param name="chatTemplate">Chat template to store and propagate.</param>
160 public static void SetTemplate(
ModelEntry entry,
string chatTemplate)
162 if (entry ==
null)
return;
163 entry.chatTemplate = chatTemplate;
// Propagate to the LLMs currently using this model.
164 foreach (
LLM llm
in llms)
166 if (llm !=
null && llm.
model == entry.filename) llm.
SetTemplate(chatTemplate);
// Fragment of an entry lookup whose signature and surrounding lines are not visible here
// (presumably Get(string path), which the filename overloads call). An entry matches when its
// filename equals the file-name part of path, or when its path equals fullPath. fullPath is
// computed on a line not visible here — presumably a normalized full path; confirm.
175 string filename = Path.GetFileName(path);
179 if (entry.filename == filename || entry.path == fullPath)
return entry;
/// <summary>
/// Returns the asset path for the entry registered under filename, or "" when no entry is found.
/// The lookup and the non-null return (lines 185-186 and 188 onward) are not visible here —
/// presumably entry = Get(filename) and a path derived from it; confirm.
/// </summary>
/// <param name="filename">Filename of the registered entry.</param>
184 public static string GetAssetPath(
string filename)
187 if (entry ==
null)
return "";
/// <summary>
/// Counts registered entries whose lora flag equals the argument (loop header not visible;
/// presumably over modelEntries).
/// </summary>
/// <param name="lora">true to count LoRA entries, false to count models.</param>
195 public static int Num(
bool lora)
200 if (entry.lora == lora) num++;
/// <summary>Number of model (non-LoRA) entries. Body not visible here; presumably Num(false) — confirm.</summary>
205 public static int NumModels()
/// <summary>Number of LoRA entries. Body not visible here; presumably Num(true) — confirm.</summary>
210 public static int NumLoras()
/// <summary>
/// Registers an LLM. Body not visible here; presumably adds it to llms so that SetTemplate and
/// Remove can update it — confirm.
/// </summary>
/// <param name="llm">LLM instance to track.</param>
215 public static void Register(
LLM llm)
/// <summary>
/// Unregisters an LLM. Body not visible here; presumably removes it from llms — confirm.
/// </summary>
/// <param name="llm">LLM instance to stop tracking.</param>
220 public static void Unregister(
LLM llm)
/// <summary>
/// Restores downloadOnStart and modelEntries from a deserialized store. How store is obtained
/// (lines 226-228) is not visible — presumably from the file written by SaveToDisk; confirm.
/// </summary>
225 public static void LoadFromDisk()
229 downloadOnStart = store.downloadOnStart;
230 modelEntries = store.modelEntries;
// PlayerPrefs key under which Save and Load keep the manager state.
234 static string LLMManagerPref =
"LLMManager";
/// <summary>
/// Progress reported through SetModelProgress while Download fetches a model. Starts at 1 and is
/// restored to its previous value if the download fails.
/// </summary>
236 [HideInInspector]
public static float modelProgress = 1;
/// <summary>
/// Progress reported through SetLoraProgress while Download fetches a LoRA. Starts at 1 and is
/// restored to its previous value if the download fails.
/// </summary>
237 [HideInInspector]
public static float loraProgress = 1;
// Unity editor hook: [InitializeOnLoadMethod] runs this method when the editor loads scripts.
// Body (lines 241-244) not visible here; presumably it restores saved state via Load — confirm.
239 [InitializeOnLoadMethod]
240 static void InitializeOnLoad()
/// <summary>
/// Inserts entry into modelEntries and returns its filename. The visible insertion-point logic
/// places the entry ahead of LoRA entries. The guarding conditions (lines 248-249 and 251-252)
/// are not visible — presumably this applies only when entry is not a LoRA; confirm.
/// </summary>
/// <param name="entry">Entry to register.</param>
/// <returns>entry.filename.</returns>
245 public static string AddEntry(
ModelEntry entry)
// Default: append at the end.
247 int indexToInsert = modelEntries.Count;
// The list starts with a LoRA: insert at the front.
250 if (modelEntries.Count > 0 && modelEntries[0].lora) indexToInsert = 0;
// Otherwise insert right after a non-LoRA entry found by scanning from the end. This is the last
// such entry only if the loop stops at the first hit (line 258 not visible — presumably `break`).
253 for (
int i = modelEntries.Count - 1; i >= 0; i--)
255 if (!modelEntries[i].lora)
257 indexToInsert = i + 1;
// Line 264 (between the insert and the return) is not visible and may contain a further step.
263 modelEntries.Insert(indexToInsert, entry);
265 return entry.filename;
/// <summary>
/// Convenience overload: builds a ModelEntry from path, the lora flag, an optional label and an
/// optional URL, then adds it via AddEntry(ModelEntry).
/// </summary>
/// <param name="path">Path of the model/LoRA file.</param>
/// <param name="lora">Whether the file is a LoRA.</param>
/// <param name="label">Optional display label.</param>
/// <param name="url">Optional source URL.</param>
/// <returns>The new entry's filename.</returns>
268 public static string AddEntry(
string path,
bool lora =
false,
string label =
null,
string url =
null)
270 return AddEntry(
new ModelEntry(path, lora, label, url));
/// <summary>
/// Downloads a model or LoRA from url and registers it, reusing an existing entry when one
/// matches the URL or the derived file name.
/// </summary>
/// <param name="url">Source URL. The file name is its last path segment with any query string removed.</param>
/// <param name="lora">Whether to register the file as a LoRA.</param>
/// <param name="log">Log a message when an existing entry is reused.</param>
/// <param name="label">Optional display label for a new entry.</param>
/// <returns>Filename of the existing or newly added entry.</returns>
273 public static async Task<string> Download(
string url,
bool lora =
false,
bool log =
false,
string label =
null)
// First, reuse an entry already registered with the same URL (loop header not visible).
277 if (entry.url == url)
279 if (log)
LLMUnitySetup.Log($
"Found existing entry for {url}");
280 return entry.filename;
// Derive the file name: last path segment with the query string stripped.
// NOTE(review): Path.GetFileName splits on directory separators, so a query string that
// contains '/' would produce the wrong name.
284 string modelName = Path.GetFileName(url).Split(
"?")[0];
// Then reuse an entry matched by that file name (the lookup producing entryPath is not visible;
// presumably Get(modelName)).
286 if (entryPath !=
null)
288 if (log)
LLMUnitySetup.Log($
"Found existing entry for {modelName}");
289 return entryPath.filename;
// Remember the current progress values so they can be restored if the download fails.
293 float preModelProgress = modelProgress;
294 float preLoraProgress = loraProgress;
// Download into modelPath (computed on a line not visible). Progress goes to SetModelProgress
// or SetLoraProgress depending on a branch not visible here (presumably lora).
300 await
LLMUnitySetup.DownloadFile(url, modelPath,
false,
null, SetModelProgress);
305 await
LLMUnitySetup.DownloadFile(url, modelPath,
false,
null, SetLoraProgress);
// Failure path (catch not visible): restore the progress values and log the error.
// The return value on this path is not visible.
310 modelProgress = preModelProgress;
311 loraProgress = preLoraProgress;
312 LLMUnitySetup.LogError($
"Error downloading the model from URL '{url}': " + ex.Message);
// Register the downloaded file together with its source URL.
315 return AddEntry(modelPath, lora, label, url);
/// <summary>
/// Registers a local model or LoRA file, reusing an existing matching entry if there is one.
/// </summary>
/// <param name="path">Path of the local file.</param>
/// <param name="lora">Whether the file is a LoRA.</param>
/// <param name="log">Log a message when an existing entry is reused.</param>
/// <param name="label">Optional label for a new entry.</param>
/// <returns>Filename of the existing or newly added entry.</returns>
318 public static string Load(
string path,
bool lora =
false,
bool log =
false,
string label =
null)
// Existing entry found (the lookup on lines 319-322 is not visible; presumably Get(path)).
323 if (log)
LLMUnitySetup.Log($
"Found existing entry for {entry.filename}");
324 return entry.filename;
// No match: register a new entry without a URL.
326 return AddEntry(path, lora, label);
/// <summary>Downloads and registers a model (non-LoRA) from url; shorthand for Download(url, false, log, label).</summary>
/// <param name="url">Source URL.</param>
/// <param name="log">Log a message when an existing entry is reused.</param>
/// <param name="label">Optional display label.</param>
/// <returns>Filename of the existing or newly added entry.</returns>
329 public static async Task<string> DownloadModel(
string url,
bool log =
false,
string label =
null)
331 return await Download(url,
false, log, label);
/// <summary>Downloads and registers a LoRA from url; shorthand for Download(url, true, log, label).</summary>
/// <param name="url">Source URL.</param>
/// <param name="log">Log a message when an existing entry is reused.</param>
/// <param name="label">Optional display label.</param>
/// <returns>Filename of the existing or newly added entry.</returns>
334 public static async Task<string> DownloadLora(
string url,
bool log =
false,
string label =
null)
336 return await Download(url,
true, log, label);
/// <summary>Registers a local model (non-LoRA) file; shorthand for Load(path, false, log, label).</summary>
/// <param name="path">Path of the local file.</param>
/// <param name="log">Log a message when an existing entry is reused.</param>
/// <param name="label">Optional display label.</param>
/// <returns>Filename of the existing or newly added entry.</returns>
339 public static string LoadModel(
string path,
bool log =
false,
string label =
null)
341 return Load(path,
false, log, label);
/// <summary>Registers a local LoRA file; shorthand for Load(path, true, log, label).</summary>
/// <param name="path">Path of the local file.</param>
/// <param name="log">Log a message when an existing entry is reused.</param>
/// <param name="label">Optional display label.</param>
/// <returns>Filename of the existing or newly added entry.</returns>
344 public static string LoadLora(
string path,
bool log =
false,
string label =
null)
346 return Load(path,
true, log, label);
/// <summary>Sets the URL of the entry found by Get(filename); delegates to the ModelEntry overload, which ignores null.</summary>
/// <param name="filename">Filename used to look up the entry via Get.</param>
/// <param name="url">New source URL.</param>
349 public static void SetURL(
string filename,
string url)
351 SetURL(Get(filename), url);
/// <summary>
/// Sets the source URL of entry; null is ignored. The assignment and any follow-up
/// (lines 357-360) are not visible here.
/// </summary>
/// <param name="entry">Entry to update.</param>
/// <param name="url">New source URL.</param>
354 public static void SetURL(
ModelEntry entry,
string url)
356 if (entry ==
null)
return;
/// <summary>Sets includeInBuild on the entry found by Get(filename); delegates to the ModelEntry overload, which ignores null.</summary>
/// <param name="filename">Filename used to look up the entry via Get.</param>
/// <param name="includeInBuild">Whether SaveToDisk and Build should consider the entry.</param>
361 public static void SetIncludeInBuild(
string filename,
bool includeInBuild)
363 SetIncludeInBuild(Get(filename), includeInBuild);
/// <summary>
/// Marks whether entry is part of the build. Entries with includeInBuild false are skipped by
/// SaveToDisk and Build. null is ignored. Lines after the assignment (370 onward) are not visible here.
/// </summary>
/// <param name="entry">Entry to update.</param>
/// <param name="includeInBuild">New flag value.</param>
366 public static void SetIncludeInBuild(
ModelEntry entry,
bool includeInBuild)
368 if (entry ==
null)
return;
369 entry.includeInBuild = includeInBuild;
/// <summary>
/// Sets downloadOnStart and warns when some entries lack a URL, because those files are copied
/// into the build instead of downloaded. The loop and conditions around the check (lines 376-382)
/// are not visible — confirm which entries are inspected and whether the check runs only when
/// value is true.
/// </summary>
/// <param name="value">New downloadOnStart value.</param>
373 public static void SetDownloadOnStart(
bool value)
375 downloadOnStart = value;
// Equivalent to string.IsNullOrEmpty(entry.url), the form used in SetupOnce and Build.
381 if (entry.url ==
null || entry.url ==
"") warn =
true;
383 if (warn)
LLMUnitySetup.LogWarning(
"Some models do not have a URL and will be copied in the build. To resolve this fill in the URL field in the expanded view of the LLM Model list.");
/// <summary>Removes the entry found by Get(filename); delegates to the ModelEntry overload, which ignores null.</summary>
/// <param name="filename">Filename used to look up the entry via Get.</param>
388 public static void Remove(
string filename)
390 Remove(Get(filename));
// Body of Remove for a ModelEntry (the signature on line 394 is not visible; presumably
// Remove(ModelEntry entry)). null is ignored. The entry is removed, then detached from the
// registered LLMs:
// - for a model, any LLM using it has its model cleared;
// - for a LoRA, RemoveLora is called on every LLM.
// NOTE(review): unlike SetTemplate, llm is not null-checked here; a null element in llms would throw.
395 if (entry ==
null)
return;
396 modelEntries.Remove(entry);
398 foreach (
LLM llm
in llms)
400 if (!entry.lora && llm.
model == entry.filename) llm.model =
"";
401 else if (entry.lora) llm.
RemoveLora(entry.filename);
/// <summary>Progress callback passed to LLMUnitySetup.DownloadFile by Download; stores the value in modelProgress.</summary>
/// <param name="progress">Progress value reported by the download.</param>
405 public static void SetModelProgress(
float progress)
407 modelProgress = progress;
/// <summary>Progress callback passed to LLMUnitySetup.DownloadFile by Download; stores the value in loraProgress.</summary>
/// <param name="progress">Progress value reported by the download.</param>
410 public static void SetLoraProgress(
float progress)
412 loraProgress = progress;
/// <summary>
/// Serializes modelEntries and downloadOnStart to pretty-printed JSON (second ToJson argument)
/// and stores it in PlayerPrefs under LLMManagerPref, where Load reads it back.
/// </summary>
415 public static void Save()
417 string json = JsonUtility.ToJson(
new LLMManagerStore { modelEntries = modelEntries, downloadOnStart = downloadOnStart },
true);
418 PlayerPrefs.SetString(LLMManagerPref, json);
/// <summary>
/// Restores downloadOnStart and modelEntries from the JSON stored in PlayerPrefs under
/// LLMManagerPref. Does nothing when the stored string is null or empty. The deserialization
/// into store (line 426) is not visible here.
/// </summary>
422 public static void Load()
424 string pref = PlayerPrefs.GetString(LLMManagerPref);
// Equivalent to string.IsNullOrEmpty(pref).
425 if (pref ==
null || pref ==
"")
return;
427 downloadOnStart = store.downloadOnStart;
428 modelEntries = store.modelEntries;
/// <summary>
/// Serializes the entries marked includeInBuild, reduced via OnlyRequiredFields, together with
/// downloadOnStart to pretty-printed JSON. Where json is written (lines 440-442) is not visible
/// here — presumably the file read back by LoadFromDisk; confirm.
/// </summary>
431 public static void SaveToDisk()
433 List<ModelEntry> modelEntriesBuild =
new List<ModelEntry>();
434 foreach (
ModelEntry modelEntry
in modelEntries)
// Entries excluded from the build are not written.
436 if (!modelEntry.includeInBuild)
continue;
437 modelEntriesBuild.Add(modelEntry.OnlyRequiredFields());
439 string json = JsonUtility.ToJson(
new LLMManagerStore { modelEntries = modelEntriesBuild, downloadOnStart = downloadOnStart },
true);
/// <summary>
/// Build step. For every entry marked includeInBuild whose file is not already at its asset
/// path, the file is copied there via copyCallback, unless it will be downloaded on start
/// (downloadOnStart set and a URL present). Lines 444-446 and the rest of the method beyond
/// this view are not visible.
/// </summary>
/// <param name="copyCallback">Invoked with (source path, target asset path); presumably performs the copy.</param>
443 public static void Build(ActionCallback copyCallback)
447 foreach (
ModelEntry modelEntry
in modelEntries)
449 string target =
LLMUnitySetup.GetAssetPath(modelEntry.filename);
// Skip entries excluded from the build or already in place.
450 if (!modelEntry.includeInBuild || File.Exists(target))
continue;
451 if (!downloadOnStart ||
string.IsNullOrEmpty(modelEntry.url)) copyCallback(modelEntry.path, target);