LLM for Unity  v2.3.0
Create characters in Unity with LLMs!
Loading...
Searching...
No Matches
Search.cs
Go to the documentation of this file.
1
3using System;
4using System.Collections.Generic;
5using System.IO;
6using System.IO.Compression;
7using System.Linq;
8using System.Runtime.Serialization.Formatters.Binary;
9using System.Threading.Tasks;
10using UnityEditor;
11using UnityEngine;
12
14namespace LLMUnity
15{
20 [DefaultExecutionOrder(-2)]
21 public abstract class Searchable : MonoBehaviour
22 {
/// <summary>Retrieves the phrase with the specific id.</summary>
/// <param name="key">Id of the phrase.</param>
/// <returns>The phrase; the result for an unknown id is up to the implementation.</returns>
public abstract string Get(int key);

/// <summary>Adds a phrase to the search.</summary>
/// <param name="inputString">Phrase to add.</param>
/// <param name="group">Data group the phrase is added to.</param>
/// <returns>Task yielding an integer identifying the added phrase.</returns>
public abstract Task<int> Add(string inputString, string group = "");

/// <summary>Removes a phrase from the search.</summary>
/// <param name="inputString">Phrase to remove.</param>
/// <param name="group">Data group the phrase is removed from.</param>
/// <returns>Integer reported by the implementation (presumably the number of removed entries).</returns>
public abstract int Remove(string inputString, string group = "");

/// <summary>Removes a phrase from the search.</summary>
/// <param name="key">Id of the phrase to remove.</param>
public abstract void Remove(int key);

/// <summary>Returns a count of the phrases.</summary>
public abstract int Count();

/// <summary>Returns a count of the phrases in a specific data group.</summary>
/// <param name="group">Data group to count.</param>
public abstract int Count(string group);

/// <summary>Clears the search object.</summary>
public abstract void Clear();

/// <summary>Starts a search whose results can be retrieved in batches (incremental search).</summary>
/// <param name="queryString">Query to search for.</param>
/// <param name="group">Data group to search in.</param>
/// <returns>Task yielding the fetch key to pass to the incremental fetch methods.</returns>
public abstract Task<int> IncrementalSearch(string queryString, string group = "");

/// <summary>Retrieves the most similar search results in batches (incremental search), as phrase keys.</summary>
/// <param name="fetchKey">Fetch key returned by the incremental search.</param>
/// <param name="k">Number of results to retrieve.</param>
/// <returns>Tuple of (phrase keys, distances, flag telling whether the search is completed).</returns>
public abstract ValueTuple<int[], float[], bool> IncrementalFetchKeys(int fetchKey, int k);

/// <summary>Completes the search and clears the cached results for an incremental search.</summary>
/// <param name="fetchKey">Fetch key returned by the incremental search.</param>
public abstract void IncrementalSearchComplete(int fetchKey);
99
/// <summary>
/// Searches for results similar to the provided query and returns the
/// fetched phrases together with their distances.
/// </summary>
/// <param name="queryString">Query to search for.</param>
/// <param name="k">Number of results requested from the fetch.</param>
/// <param name="group">Data group to search in.</param>
/// <returns>Tuple of (phrases, distances).</returns>
public async Task<(string[], float[])> Search(string queryString, int k, string group = "")
{
    int searchId = await IncrementalSearch(queryString, group);
    ValueTuple<string[], float[], bool> batch = IncrementalFetch(searchId, k);

    // A search that the fetch did not report as completed has to be closed explicitly.
    bool finished = batch.Item3;
    if (!finished)
    {
        IncrementalSearchComplete(searchId);
    }
    return (batch.Item1, batch.Item2);
}
122
/// <summary>
/// Retrieves the most similar search results in batches (incremental search),
/// resolving every fetched key to its phrase through <see cref="Get"/>.
/// </summary>
/// <param name="fetchKey">Fetch key returned by the incremental search.</param>
/// <param name="k">Number of results to retrieve.</param>
/// <returns>Tuple of (phrases, distances, flag telling whether the search is completed).</returns>
public virtual ValueTuple<string[], float[], bool> IncrementalFetch(int fetchKey, int k)
{
    ValueTuple<int[], float[], bool> fetched = IncrementalFetchKeys(fetchKey, k);
    int[] keys = fetched.Item1;

    string[] phrases = new string[keys.Length];
    int slot = 0;
    foreach (int key in keys)
    {
        phrases[slot++] = Get(key);
    }
    return (phrases, fetched.Item2, fetched.Item3);
}
144
/// <summary>
/// Saves the state of the search object to a zip archive.
/// Failures are logged through LLMUnitySetup.LogError instead of being rethrown.
/// </summary>
/// <param name="filePath">File path, resolved with LLMUnitySetup.GetAssetPath.</param>
public void Save(string filePath)
{
    try
    {
        // The callback is the archive-level Save(ZipArchive) overload of this object.
        ArchiveSaver.ArchiveSaverCallback writeState = Save;
        ArchiveSaver.Save(LLMUnitySetup.GetAssetPath(filePath), writeState);
    }
    catch (Exception e)
    {
        LLMUnitySetup.LogError($"File {filePath} could not be saved due to {e.GetType()}: {e.Message}");
    }
}
161
/// <summary>
/// Loads the state of the search object from a zip archive.
/// Failures are logged through LLMUnitySetup.LogError instead of being rethrown.
/// </summary>
/// <param name="filePath">File path, resolved with LLMUnitySetup.GetAssetPath.</param>
/// <returns>True when the archive was found and loaded; false when the file is missing or loading failed.</returns>
public async Task<bool> Load(string filePath)
{
    try
    {
        await LLMUnitySetup.AndroidExtractAsset(filePath, true);
        string archivePath = LLMUnitySetup.GetAssetPath(filePath);
        if (File.Exists(archivePath))
        {
            // The callback is the archive-level Load(ZipArchive) overload of this object.
            ArchiveSaver.ArchiveSaverCallback readState = Load;
            ArchiveSaver.Load(archivePath, readState);
            return true;
        }
        return false;
    }
    catch (Exception e)
    {
        LLMUnitySetup.LogError($"File {filePath} could not be loaded due to {e.GetType()}: {e.Message}");
        return false;
    }
}
182
/// <summary>Writes the state of the search object into the given archive.</summary>
/// <param name="archive">Zip archive opened for writing.</param>
public abstract void Save(ZipArchive archive);

/// <summary>Reads the state of the search object from the given archive.</summary>
/// <param name="archive">Zip archive opened for reading.</param>
public abstract void Load(ZipArchive archive);
/// <summary>
/// Builds the archive entry path for a saved item by placing it under
/// a folder named after the runtime type of this object.
/// </summary>
/// <param name="name">Name of the saved item.</param>
/// <returns>The type name combined with <paramref name="name"/> via Path.Combine.</returns>
public virtual string GetSavePath(string name)
{
    string typeFolder = GetType().Name;
    return Path.Combine(typeFolder, name);
}
190
/// <summary>
/// Hook called from Awake and after a Searchable component is constructed.
/// The base implementation does nothing.
/// </summary>
public virtual void UpdateGameObjects()
{
}
192
/// <summary>
/// Makes sure the game object carries a component of the requested concrete type and returns it.
/// An existing component of type T with a different concrete type is replaced: the new component
/// is added, <paramref name="copyAction"/> is applied and the old component is destroyed.
/// </summary>
/// <typeparam name="T">Component type looked up on the game object.</typeparam>
/// <param name="type">
/// Concrete type to add. When null, nothing is added and null is returned
/// (an existing component of type T is still destroyed).
/// </param>
/// <param name="copyAction">Optional callback invoked as (old component, new component) on replacement.</param>
/// <returns>The existing component, the newly added one, or null when <paramref name="type"/> is null.</returns>
protected T ConstructComponent<T>(Type type, Action<T, T> copyAction = null) where T : Component
{
    // Adds a component of the requested type; yields null when no type was requested.
    T CreateRequested()
    {
        if (type == null) return null;
        T created = (T)gameObject.AddComponent(type);
        if (created is Searchable searchable) searchable.UpdateGameObjects();
        return created;
    }

    T existing = (T)gameObject.GetComponent(typeof(T));
    if (existing == null) return CreateRequested();
    if (existing.GetType() == type) return existing;

    // Different concrete type: build the replacement first so its state can be copied over.
    T replacement = CreateRequested();
    if (type != null) copyAction?.Invoke(existing, replacement);
#if UNITY_EDITOR
    DestroyImmediate(existing);
#else
    Destroy(existing);
#endif
    return replacement;
}
228
/// <summary>Unity Awake callback: runs the <see cref="UpdateGameObjects"/> hook.</summary>
public virtual void Awake() => UpdateGameObjects();
233
#if UNITY_EDITOR
/// <summary>
/// Editor-only Reset callback: outside play mode, subscribes <see cref="UpdateGameObjects"/>
/// to the editor update loop.
/// </summary>
public virtual void Reset()
{
    if (Application.isPlaying) return;

    // Reset may run more than once for the same component, while OnDestroy removes a single
    // subscription. Unsubscribe first (a no-op when not subscribed) so that at most one
    // handler is ever registered and none is left behind after destruction.
    EditorApplication.update -= UpdateGameObjects;
    EditorApplication.update += UpdateGameObjects;
}

/// <summary>
/// Editor-only OnDestroy callback: outside play mode, unsubscribes <see cref="UpdateGameObjects"/>
/// from the editor update loop.
/// </summary>
public virtual void OnDestroy()
{
    if (!Application.isPlaying) EditorApplication.update -= UpdateGameObjects;
}

#endif
247 }
248
253 public abstract class SearchMethod : Searchable
254 {
/// <summary>Embedder used to encode, tokenise and detokenise text.</summary>
public LLMEmbedder llmEmbedder;

/// <summary>Key assigned to the next added phrase.</summary>
protected int nextKey = 0;
/// <summary>Counter that is reset, saved and loaded here; presumably consumed by subclasses for fetch keys.</summary>
protected int nextIncrementalSearchKey = 0;
/// <summary>Phrases indexed by their key.</summary>
protected SortedDictionary<int, string> data = new SortedDictionary<int, string>();
/// <summary>Keys of the phrases belonging to each data group.</summary>
protected SortedDictionary<string, List<int>> dataSplits = new SortedDictionary<string, List<int>>();

/// <summary>LLM set through SetLLM.</summary>
protected LLM llm;

/// <summary>Stores the embedding of a newly added phrase under the given key.</summary>
protected abstract void AddInternal(int key, float[] embedding);
/// <summary>Removes the implementation-specific data of the given key.</summary>
protected abstract void RemoveInternal(int key);
/// <summary>Clears the implementation-specific data.</summary>
protected abstract void ClearInternal();
/// <summary>Saves the implementation-specific data to the archive.</summary>
protected abstract void SaveInternal(ZipArchive archive);
/// <summary>Loads the implementation-specific data from the archive.</summary>
protected abstract void LoadInternal(ZipArchive archive);
269
/// <summary>Sets the LLM for encoding the search entries.</summary>
/// <param name="llm">LLM to store and to forward to the embedder when one is assigned.</param>
public void SetLLM(LLM llm)
{
    this.llm = llm;
    if (llmEmbedder == null) return;
    llmEmbedder.llm = llm;
}
279
/// <summary>
/// Orders the entries in the searchList according to their similarity to the provided query.
/// The distance of an entry is one minus the dot product of its embedding with the query embedding.
/// </summary>
/// <param name="query">Query to compare against.</param>
/// <param name="searchList">Entries to order.</param>
/// <returns>Tuple of (entries ordered by ascending distance, their distances).</returns>
public async Task<(string[], float[])> SearchFromList(string query, string[] searchList)
{
    float[] queryEmbedding = await Encode(query);

    // Entries are encoded one after the other, in list order.
    float[][] candidateEmbeddings = new float[searchList.Length][];
    for (int i = 0; i < searchList.Length; i++)
    {
        candidateEmbeddings[i] = await Encode(searchList[i]);
    }

    float[] rawDistances = InverseDotProduct(queryEmbedding, candidateEmbeddings);

    // OrderBy is a stable sort, so entries with equal distance keep their list order.
    List<ValueTuple<string, float>> ranked = searchList
        .Zip(rawDistances, (phrase, distance) => ValueTuple.Create(phrase, distance))
        .OrderBy(pair => pair.Item2)
        .ToList();

    string[] results = new string[ranked.Count];
    float[] distances = new float[ranked.Count];
    int slot = 0;
    foreach (ValueTuple<string, float> pair in ranked)
    {
        results[slot] = pair.Item1;
        distances[slot] = pair.Item2;
        slot++;
    }
    return (results, distances);
}
313
/// <summary>Computes the dot product of two vectors.</summary>
/// <param name="vector1">First vector.</param>
/// <param name="vector2">Second vector; must have the same length as <paramref name="vector1"/>.</param>
/// <returns>Sum of the element-wise products (0 for empty vectors).</returns>
/// <exception cref="ArgumentNullException">Either vector is null.</exception>
/// <exception cref="ArgumentException">The vectors differ in length.</exception>
public static float DotProduct(float[] vector1, float[] vector2)
{
    // The single-string ArgumentNullException constructor takes the parameter name, not the
    // message, so name the offending parameter explicitly and pass the message separately.
    if (vector1 == null) throw new ArgumentNullException(nameof(vector1), "Vectors cannot be null");
    if (vector2 == null) throw new ArgumentNullException(nameof(vector2), "Vectors cannot be null");
    if (vector1.Length != vector2.Length) throw new ArgumentException("Vector lengths must be equal for dot product calculation");

    float result = 0;
    for (int i = 0; i < vector1.Length; i++)
    {
        result += vector1[i] * vector2[i];
    }
    return result;
}
326
/// <summary>Computes one minus the dot product of two vectors.</summary>
/// <param name="vector1">First vector.</param>
/// <param name="vector2">Second vector.</param>
/// <returns>1 - DotProduct(vector1, vector2).</returns>
public static float InverseDotProduct(float[] vector1, float[] vector2)
{
    float similarity = DotProduct(vector1, vector2);
    return 1 - similarity;
}
331
/// <summary>Computes one minus the dot product of a vector with every vector of a list.</summary>
/// <param name="vector1">Reference vector.</param>
/// <param name="vector2">Vectors to compare against the reference vector.</param>
/// <returns>One value per entry of <paramref name="vector2"/>, in the same order.</returns>
public static float[] InverseDotProduct(float[] vector1, float[][] vector2)
{
    float[] results = new float[vector2.Length];
    int slot = 0;
    foreach (float[] candidate in vector2)
    {
        results[slot++] = InverseDotProduct(vector1, candidate);
    }
    return results;
}
341
/// <summary>Computes the embedding of the input with the embedder.</summary>
/// <param name="inputString">Text to encode.</param>
/// <returns>Task yielding the embedding as an array.</returns>
public virtual async Task<float[]> Encode(string inputString)
{
    List<float> embedding = await llmEmbedder.Embeddings(inputString);
    return embedding.ToArray();
}
346
/// <summary>Tokenises the provided query with the embedder.</summary>
/// <param name="query">Text to tokenise.</param>
/// <param name="callback">Optional callback forwarded to the embedder.</param>
/// <returns>Task yielding the tokens.</returns>
public virtual async Task<List<int>> Tokenize(string query, Callback<List<int>> callback = null)
{
    List<int> tokens = await llmEmbedder.Tokenize(query, callback);
    return tokens;
}
351
/// <summary>Detokenises the provided tokens to a string with the embedder.</summary>
/// <param name="tokens">Tokens to convert.</param>
/// <param name="callback">Optional callback forwarded to the embedder.</param>
/// <returns>Task yielding the resulting text.</returns>
public async Task<string> Detokenize(List<int> tokens, Callback<string> callback = null)
{
    string text = await llmEmbedder.Detokenize(tokens, callback);
    return text;
}
356
/// <summary>Retrieves the phrase with the specific id.</summary>
/// <param name="key">Id of the phrase.</param>
/// <returns>The stored phrase, or null when no phrase is stored under <paramref name="key"/>.</returns>
public override string Get(int key)
{
    return data.TryGetValue(key, out string phrase) ? phrase : null;
}
362
/// <summary>Adds a phrase to the search.</summary>
/// <param name="inputString">Phrase to add.</param>
/// <param name="group">Data group the phrase is added to.</param>
/// <returns>Task yielding the key assigned to the phrase.</returns>
public override async Task<int> Add(string inputString, string group = "")
{
    // The key is reserved before the asynchronous encoding starts.
    int key = nextKey++;
    float[] embedding = await Encode(inputString);
    AddInternal(key, embedding);

    data[key] = inputString;
    if (dataSplits.TryGetValue(group, out List<int> groupKeys))
    {
        groupKeys.Add(key);
    }
    else
    {
        dataSplits[group] = new List<int>() { key };
    }
    return key;
}
373
/// <summary>
/// Clears the search object: drops all phrases and groups, clears the
/// implementation-specific data and resets both key counters to 0.
/// </summary>
public override void Clear()
{
    // Bookkeeping held by this class.
    data.Clear();
    dataSplits.Clear();

    // Data held by the concrete search method.
    ClearInternal();

    // Restart key numbering.
    nextKey = 0;
    nextIncrementalSearchKey = 0;
}
382
/// <summary>
/// Removes the phrase stored under the key and, when it existed,
/// its implementation-specific data. Group membership is not touched.
/// </summary>
/// <param name="key">Id of the phrase to remove.</param>
/// <returns>True when a phrase was stored under <paramref name="key"/>.</returns>
protected bool RemoveEntry(int key)
{
    if (!data.Remove(key)) return false;
    RemoveInternal(key);
    return true;
}
389
/// <summary>Removes a phrase from the search, including its membership in every data group.</summary>
/// <param name="key">Id of the phrase to remove.</param>
public override void Remove(int key)
{
    if (!RemoveEntry(key)) return;

    foreach (List<int> groupKeys in dataSplits.Values)
    {
        groupKeys.Remove(key);
    }
}
397
/// <summary>Removes a phrase from the search: every entry of the group whose text equals the phrase.</summary>
/// <param name="inputString">Phrase to remove.</param>
/// <param name="group">Data group the phrase is removed from.</param>
/// <returns>Number of matching entries found in the group; 0 when the group does not exist.</returns>
public override int Remove(string inputString, string group = "")
{
    if (!dataSplits.TryGetValue(group, out List<int> groupKeys)) return 0;

    // Collect the matches first: the group list is modified while removing.
    List<int> matches = groupKeys.FindAll(candidate => Get(candidate) == inputString);
    foreach (int match in matches)
    {
        if (RemoveEntry(match)) groupKeys.Remove(match);
    }
    return matches.Count;
}
412
/// <summary>Returns a count of the phrases.</summary>
public override int Count()
{
    int total = data.Count;
    return total;
}
417
/// <summary>Returns a count of the phrases in a specific data group.</summary>
/// <param name="group">Data group to count.</param>
/// <returns>Number of keys in the group; 0 when the group does not exist.</returns>
public override int Count(string group)
{
    return dataSplits.TryGetValue(group, out List<int> groupKeys) ? groupKeys.Count : 0;
}
423
/// <summary>
/// Starts an incremental search for the query: the query is encoded and
/// handed to the embedding-based IncrementalSearch overload.
/// </summary>
/// <param name="queryString">Query to search for.</param>
/// <param name="group">Data group to search in.</param>
/// <returns>Task yielding the fetch key of the search.</returns>
public override async Task<int> IncrementalSearch(string queryString, string group = "")
{
    float[] queryEmbedding = await Encode(queryString);
    return IncrementalSearch(queryEmbedding, group);
}
428
/// <summary>
/// Writes the phrases, the group splits and both key counters into the archive,
/// followed by the implementation-specific data.
/// </summary>
/// <param name="archive">Zip archive opened for writing.</param>
public override void Save(ZipArchive archive)
{
    string dataPath = GetSavePath("data");
    ArchiveSaver.Save(archive, data, dataPath);

    string splitsPath = GetSavePath("dataSplits");
    ArchiveSaver.Save(archive, dataSplits, splitsPath);

    string nextKeyPath = GetSavePath("nextKey");
    ArchiveSaver.Save(archive, nextKey, nextKeyPath);

    string nextSearchKeyPath = GetSavePath("nextIncrementalSearchKey");
    ArchiveSaver.Save(archive, nextIncrementalSearchKey, nextSearchKeyPath);

    SaveInternal(archive);
}
437
/// <summary>
/// Reads the phrases, the group splits and both key counters from the archive,
/// followed by the implementation-specific data.
/// </summary>
/// <param name="archive">Zip archive opened for reading.</param>
public override void Load(ZipArchive archive)
{
    string dataPath = GetSavePath("data");
    data = ArchiveSaver.Load<SortedDictionary<int, string>>(archive, dataPath);

    string splitsPath = GetSavePath("dataSplits");
    dataSplits = ArchiveSaver.Load<SortedDictionary<string, List<int>>>(archive, splitsPath);

    string nextKeyPath = GetSavePath("nextKey");
    nextKey = ArchiveSaver.Load<int>(archive, nextKeyPath);

    string nextSearchKeyPath = GetSavePath("nextIncrementalSearchKey");
    nextIncrementalSearchKey = ArchiveSaver.Load<int>(archive, nextSearchKeyPath);

    LoadInternal(archive);
}
446
/// <summary>
/// Makes sure an LLMEmbedder is assigned: when none is set, one is obtained through
/// ConstructComponent, carrying over the llm of a replaced embedder.
/// </summary>
public override void UpdateGameObjects()
{
    // NOTE(review): `this == null` presumably detects an already destroyed Unity object - confirm.
    if (this == null) return;
    if (llmEmbedder != null) return;

    llmEmbedder = ConstructComponent<LLMEmbedder>(
        typeof(LLMEmbedder),
        (previous, current) => { current.llm = previous.llm; }
    );
}
452
/// <summary>Starts an incremental search for an already encoded query.</summary>
/// <param name="embedding">Embedding of the query.</param>
/// <param name="group">Data group to search in.</param>
/// <returns>Fetch key of the search.</returns>
public abstract int IncrementalSearch(float[] embedding, string group = "");
455 }
456
/// <summary>
/// Class implementing the search plugin template used e.g. in chunking.
/// A plugin wraps a SearchMethod and stores its own state next to it.
/// </summary>
public abstract class SearchPlugin : Searchable
{
    /// <summary>Search method used by the plugin.</summary>
    protected SearchMethod search;

    /// <summary>Sets the search method of the plugin.</summary>
    /// <param name="search">Search method to use.</param>
    public void SetSearch(SearchMethod search) => this.search = search;

    /// <summary>Saves the plugin-specific data to the archive.</summary>
    protected abstract void SaveInternal(ZipArchive archive);

    /// <summary>Loads the plugin-specific data from the archive.</summary>
    protected abstract void LoadInternal(ZipArchive archive);

    /// <summary>Saves the wrapped search method first, then the plugin-specific data.</summary>
    /// <param name="archive">Zip archive opened for writing.</param>
    public override void Save(ZipArchive archive)
    {
        SearchMethod wrapped = search;
        wrapped.Save(archive);
        SaveInternal(archive);
    }

    /// <summary>Loads the wrapped search method first, then the plugin-specific data.</summary>
    /// <param name="archive">Zip archive opened for reading.</param>
    public override void Load(ZipArchive archive)
    {
        SearchMethod wrapped = search;
        wrapped.Load(archive);
        LoadInternal(archive);
    }
}
492
/// <summary>
/// Helper that stores objects as entries of a zip archive and reads them back.
/// Objects are serialised with BinaryFormatter.
/// </summary>
public class ArchiveSaver
{
    /// <summary>Callback receiving the opened archive to write to or read from.</summary>
    public delegate void ArchiveSaverCallback(ZipArchive archive);

    /// <summary>Creates (or overwrites) a zip file and lets the callback write its entries.</summary>
    /// <param name="filePath">Path of the zip file.</param>
    /// <param name="callback">Callback writing the entries.</param>
    public static void Save(string filePath, ArchiveSaverCallback callback)
    {
        using (FileStream stream = new FileStream(filePath, FileMode.Create))
        using (ZipArchive archive = new ZipArchive(stream, ZipArchiveMode.Create))
        {
            callback(archive);
        }
    }

    /// <summary>Opens an existing zip file for reading and lets the callback read its entries.</summary>
    /// <param name="filePath">Path of the zip file.</param>
    /// <param name="callback">Callback reading the entries.</param>
    public static void Load(string filePath, ArchiveSaverCallback callback)
    {
        // Request read access only: the two-argument FileStream constructor asks for
        // read/write access, which fails for read-only files although nothing is written.
        using (FileStream stream = new FileStream(filePath, FileMode.Open, FileAccess.Read))
        using (ZipArchive archive = new ZipArchive(stream, ZipArchiveMode.Read))
        {
            callback(archive);
        }
    }

    /// <summary>Serialises an object into a new entry of the archive.</summary>
    /// <param name="archive">Archive opened for writing.</param>
    /// <param name="saveObject">Object to serialise.</param>
    /// <param name="name">Name of the entry.</param>
    public static void Save(ZipArchive archive, object saveObject, string name)
    {
        ZipArchiveEntry mainEntry = archive.CreateEntry(name);
        using (Stream entryStream = mainEntry.Open())
        {
            BinaryFormatter formatter = new BinaryFormatter();
            formatter.Serialize(entryStream, saveObject);
        }
    }

    /// <summary>Deserialises the object stored in an entry of the archive.</summary>
    /// <typeparam name="T">Type the deserialised object is cast to.</typeparam>
    /// <param name="archive">Archive opened for reading.</param>
    /// <param name="name">Name of the entry.</param>
    /// <returns>The deserialised object.</returns>
    /// <exception cref="Exception">No entry with the given name exists.</exception>
    public static T Load<T>(ZipArchive archive, string name)
    {
        ZipArchiveEntry baseEntry = FindEntry(archive, name);
        if (baseEntry == null) throw new Exception($"No entry with name {name} was found");
        using (Stream entryStream = baseEntry.Open())
        {
            // NOTE(security): BinaryFormatter deserialisation is unsafe on untrusted data;
            // only load archives that come from a trusted source.
            BinaryFormatter formatter = new BinaryFormatter();
            return (T)formatter.Deserialize(entryStream);
        }
    }

    /// <summary>
    /// Looks up an entry by name, falling back to the same name written with the other
    /// directory separator. Entry names built with Path.Combine use the separator of the
    /// platform that saved the archive, which can differ from the one loading it.
    /// </summary>
    private static ZipArchiveEntry FindEntry(ZipArchive archive, string name)
    {
        ZipArchiveEntry entry = archive.GetEntry(name);
        if (entry != null) return entry;

        string forwardSlashes = name.Replace('\\', '/');
        if (forwardSlashes != name)
        {
            entry = archive.GetEntry(forwardSlashes);
            if (entry != null) return entry;
        }

        string backSlashes = name.Replace('/', '\\');
        if (backSlashes != name) entry = archive.GetEntry(backSlashes);
        return entry;
    }
}
538}
virtual async Task< List< int > > Tokenize(string query, Callback< List< int > > callback=null)
Tokenises the provided query.
Definition LLMCaller.cs:337
virtual async Task< List< float > > Embeddings(string query, Callback< List< float > > callback=null)
Computes the embeddings of the provided input.
Definition LLMCaller.cs:367
virtual async Task< string > Detokenize(List< int > tokens, Callback< string > callback=null)
Detokenises the provided tokens to a string.
Definition LLMCaller.cs:352
Class implementing the LLM embedder.
Class implementing helper functions for setup and process management.
Class implementing the LLM server.
Definition LLM.cs:19
Class implementing the search method template.
Definition Search.cs:254
void SetLLM(LLM llm)
Sets the LLM for encoding the search entries.
Definition Search.cs:274
async Task<(string[], float[])> SearchFromList(string query, string[] searchList)
Orders the entries in the searchList according to their similarity to the provided query....
Definition Search.cs:293
Class implementing the search plugin template used e.g. in chunking.
Definition Search.cs:462
void SetSearch(SearchMethod search)
Sets the search method of the plugin.
Definition Search.cs:469
Class implementing the search template.
Definition Search.cs:22
async Task<(string[], float[])> Search(string queryString, int k, string group="")
Search for similar results to the provided query. The most similar results and their distances (dissi...
Definition Search.cs:115
int Remove(string inputString, string group="")
Removes a phrase from the search.
int Count(string group)
Returns a count of the phrases in a specific data group.
Task< int > Add(string inputString, string group="")
Adds a phrase to the search.
int Count()
Returns a count of the phrases.
void Clear()
Clears the search object.
ValueTuple< int[], float[], bool > IncrementalFetchKeys(int fetchKey, int k)
Retrieves the most similar search results in batches (incremental search). The phrase keys and distan...
virtual ValueTuple< string[], float[], bool > IncrementalFetch(int fetchKey, int k)
Retrieves the most similar search results in batches (incremental search). The most similar results a...
Definition Search.cs:137
void IncrementalSearchComplete(int fetchKey)
Completes the search and clears the cached results for an incremental search.
string Get(int key)
Retrieves the phrase with the specific id.
void Save(string filePath)
Saves the state of the search object.
Definition Search.cs:149
void Remove(int key)
Removes a phrase from the search.
async Task< bool > Load(string filePath)
Loads the state of the search object.
Definition Search.cs:166
Task< int > IncrementalSearch(string queryString, string group="")
Allows searching and retrieving the results in batches (incremental search).