4using System.Collections.Generic;
6using System.IO.Compression;
8using System.Runtime.Serialization.Formatters.Binary;
9using System.Threading.Tasks;
20 [DefaultExecutionOrder(-2)]
/// <summary>
/// Retrieves the phrase associated with the given key.
/// </summary>
/// <param name="key">Key of the phrase to look up.</param>
/// <returns>The phrase stored for <paramref name="key"/>.</returns>
28 public abstract string Get(
int key);
/// <summary>
/// Adds a phrase to the search, optionally assigning it to a data group.
/// </summary>
/// <param name="inputString">Phrase to add.</param>
/// <param name="group">Data group to add the phrase to; defaults to the empty-string group.</param>
/// <returns>
/// A task producing an int.
/// NOTE(review): presumably the key assigned to the added phrase - confirm against the implementations.
/// </returns>
36 public abstract Task<int>
Add(
string inputString,
string group =
"");
/// <summary>
/// Removes a phrase from the search, looking it up by its text within a data group.
/// </summary>
/// <param name="inputString">Phrase to remove.</param>
/// <param name="group">Data group to remove the phrase from; defaults to the empty-string group.</param>
/// <returns>
/// An int.
/// NOTE(review): presumably the number of entries removed - confirm against the implementations.
/// </returns>
44 public abstract int Remove(
string inputString,
string group =
"");
/// <summary>
/// Removes the phrase stored under the given key from the search.
/// </summary>
/// <param name="key">Key of the phrase to remove.</param>
50 public abstract void Remove(
int key);
/// <summary>
/// Returns the number of phrases in a specific data group.
/// </summary>
/// <param name="group">Data group whose phrases are counted.</param>
/// <returns>The phrase count for <paramref name="group"/>.</returns>
63 public abstract int Count(
string group);
115 public async Task<(
string[],
float[])>
Search(
string queryString,
int k,
string group =
"")
118 (
string[] phrases,
float[] distances,
bool completed) =
IncrementalFetch(fetchKey, k);
120 return (phrases, distances);
140 string[] results =
new string[resultKeys.Length];
141 for (
int i = 0; i < resultKeys.Length; i++) results[i] =
Get(resultKeys[i]);
142 return (results, distances, completed);
149 public void Save(
string filePath)
154 ArchiveSaver.Save(path,
Save);
158 LLMUnitySetup.LogError($
"File {filePath} could not be saved due to {e.GetType()}: {e.Message}");
166 public async Task<bool>
Load(
string filePath)
172 if (!File.Exists(path))
return false;
173 ArchiveSaver.Load(path,
Load);
177 LLMUnitySetup.LogError($
"File {filePath} could not be loaded due to {e.GetType()}: {e.Message}");
/// <summary>
/// Writes the state of the search object into the provided zip archive.
/// Implemented by subclasses.
/// </summary>
/// <param name="archive">Archive that receives the saved entries.</param>
184 public abstract void Save(ZipArchive archive);
/// <summary>
/// Restores the state of the search object from the provided zip archive.
/// Implemented by subclasses.
/// </summary>
/// <param name="archive">Archive to read the saved entries from.</param>
185 public abstract void Load(ZipArchive archive);
186 public virtual string GetSavePath(
string name)
188 return Path.Combine(GetType().Name, name);
/// <summary>
/// Hook for subclasses that need to create or refresh components on the GameObject.
/// The base implementation intentionally does nothing.
/// </summary>
191 public virtual void UpdateGameObjects() {}
193 protected T ConstructComponent<T>(Type type, Action<T, T> copyAction =
null) where T : Component
195 T Construct(Type type)
197 if (type ==
null)
return null;
198 T newComponent = (T)gameObject.AddComponent(type);
199 if (newComponent is Searchable searchable) searchable.UpdateGameObjects();
203 T component = (T)gameObject.GetComponent(typeof(T));
205 if (component ==
null)
207 newComponent = Construct(type);
211 if (component.GetType() == type)
213 newComponent = component;
217 newComponent = Construct(type);
218 if (type !=
null) copyAction?.Invoke(component, newComponent);
220 DestroyImmediate(component);
229 public virtual void Awake()
235 public virtual void Reset()
237 if (!Application.isPlaying) EditorApplication.update += UpdateGameObjects;
240 public virtual void OnDestroy()
242 if (!Application.isPlaying) EditorApplication.update -= UpdateGameObjects;
// Key counter persisted with the search state.
// NOTE(review): presumably the key given to the next added phrase - confirm in Add.
257 protected int nextKey = 0;
// Counter for incremental searches; reset to 0 by Clear and persisted with the search state.
258 protected int nextIncrementalSearchKey = 0;
// Maps each phrase key to the phrase text.
259 protected SortedDictionary<int, string> data =
new SortedDictionary<int, string>();
// Maps each data group name to the keys of the phrases that belong to it.
260 protected SortedDictionary<string, List<int>> dataSplits =
new SortedDictionary<string, List<int>>();
/// <summary>
/// Implementation hook invoked by Add with the embedding of the phrase being added.
/// </summary>
/// <param name="key">Key of the phrase.</param>
/// <param name="embedding">Embedding vector computed for the phrase.</param>
264 protected abstract void AddInternal(
int key,
float[] embedding);
/// <summary>
/// Implementation hook invoked by RemoveEntry once the key has been removed from the phrase dictionary.
/// </summary>
/// <param name="key">Key of the removed phrase.</param>
265 protected abstract void RemoveInternal(
int key);
/// <summary>
/// Implementation hook for clearing the implementation-specific search state.
/// NOTE(review): presumably invoked from Clear - confirm against the caller.
/// </summary>
266 protected abstract void ClearInternal();
/// <summary>
/// Implementation hook invoked by Save after the phrase data, data groups and key counters
/// have been written to the archive.
/// </summary>
/// <param name="archive">Archive that receives the implementation-specific entries.</param>
267 protected abstract void SaveInternal(ZipArchive archive);
/// <summary>
/// Implementation hook invoked by Load after the phrase data, data groups and key counters
/// have been read from the archive.
/// </summary>
/// <param name="archive">Archive to read the implementation-specific entries from.</param>
268 protected abstract void LoadInternal(ZipArchive archive);
277 if (llmEmbedder !=
null) llmEmbedder.llm = llm;
293 public async Task<(
string[],
float[])>
SearchFromList(
string query,
string[] searchList)
295 float[] embedding = await Encode(query);
296 float[][] embeddingsList =
new float[searchList.Length][];
297 for (
int i = 0; i < searchList.Length; i++) embeddingsList[i] = await Encode(searchList[i]);
299 float[] unsortedDistances = InverseDotProduct(embedding, embeddingsList);
300 List<(string, float)> sortedLists = searchList.Zip(unsortedDistances, (first, second) => (first, second))
301 .OrderBy(item => item.Item2)
304 string[] results =
new string[sortedLists.Count];
305 float[] distances =
new float[sortedLists.Count];
306 for (
int i = 0; i < sortedLists.Count; i++)
308 results[i] = sortedLists[i].Item1;
309 distances[i] = sortedLists[i].Item2;
311 return (results.ToArray(), distances.ToArray());
315 public static float DotProduct(
float[] vector1,
float[] vector2)
317 if (vector1 ==
null || vector2 ==
null)
throw new ArgumentNullException(
"Vectors cannot be null");
318 if (vector1.Length != vector2.Length)
throw new ArgumentException(
"Vector lengths must be equal for dot product calculation");
320 for (
int i = 0; i < vector1.Length; i++)
322 result += vector1[i] * vector2[i];
327 public static float InverseDotProduct(
float[] vector1,
float[] vector2)
329 return 1 - DotProduct(vector1, vector2);
332 public static float[] InverseDotProduct(
float[] vector1,
float[][] vector2)
334 float[] results =
new float[vector2.Length];
335 for (
int i = 0; i < vector2.Length; i++)
337 results[i] = InverseDotProduct(vector1, vector2[i]);
342 public virtual async Task<float[]> Encode(
string inputString)
344 return (await llmEmbedder.
Embeddings(inputString)).ToArray();
347 public virtual async Task<List<int>> Tokenize(
string query, Callback<List<int>> callback =
null)
349 return await llmEmbedder.
Tokenize(query, callback);
352 public async Task<string> Detokenize(List<int> tokens, Callback<string> callback =
null)
354 return await llmEmbedder.
Detokenize(tokens, callback);
357 public override string Get(
int key)
359 if (data.TryGetValue(key, out
string result))
return result;
363 public override async Task<int>
Add(
string inputString,
string group =
"")
366 AddInternal(key, await Encode(inputString));
368 data[key] = inputString;
369 if (!dataSplits.ContainsKey(group)) dataSplits[group] =
new List<int>(){key};
370 else dataSplits[group].Add(key);
374 public override void Clear()
380 nextIncrementalSearchKey = 0;
383 protected bool RemoveEntry(
int key)
385 bool removed = data.Remove(key);
386 if (removed) RemoveInternal(key);
390 public override void Remove(
int key)
392 if (RemoveEntry(key))
394 foreach (var dataSplit
in dataSplits.Values) dataSplit.Remove(key);
398 public override int Remove(
string inputString,
string group =
"")
400 if (!dataSplits.TryGetValue(group, out List<int> dataSplit))
return 0;
401 List<int> removeIds =
new List<int>();
402 foreach (
int key
in dataSplit)
404 if (
Get(key) == inputString) removeIds.Add(key);
406 foreach (
int key
in removeIds)
408 if (RemoveEntry(key)) dataSplit.Remove(key);
410 return removeIds.Count;
413 public override int Count()
418 public override int Count(
string group)
420 if (!dataSplits.TryGetValue(group, out List<int> dataSplit))
return 0;
421 return dataSplit.Count;
424 public override async Task<int>
IncrementalSearch(
string queryString,
string group =
"")
429 public override void Save(ZipArchive archive)
431 ArchiveSaver.Save(archive, data, GetSavePath(
"data"));
432 ArchiveSaver.Save(archive, dataSplits, GetSavePath(
"dataSplits"));
433 ArchiveSaver.Save(archive, nextKey, GetSavePath(
"nextKey"));
434 ArchiveSaver.Save(archive, nextIncrementalSearchKey, GetSavePath(
"nextIncrementalSearchKey"));
435 SaveInternal(archive);
438 public override void Load(ZipArchive archive)
440 data = ArchiveSaver.Load<SortedDictionary<int, string>>(archive, GetSavePath(
"data"));
441 dataSplits = ArchiveSaver.Load<SortedDictionary<string, List<int>>>(archive, GetSavePath(
"dataSplits"));
442 nextKey = ArchiveSaver.Load<
int>(archive, GetSavePath(
"nextKey"));
443 nextIncrementalSearchKey = ArchiveSaver.Load<
int>(archive, GetSavePath(
"nextIncrementalSearchKey"));
444 LoadInternal(archive);
447 public override void UpdateGameObjects()
449 if (
this ==
null || llmEmbedder !=
null)
return;
450 llmEmbedder = ConstructComponent<LLMEmbedder>(typeof(LLMEmbedder), (previous, current) => current.llm = previous.llm);
471 this.search = search;
/// <summary>
/// Implementation hook invoked by Save after the wrapped search method has written its state to the archive.
/// </summary>
/// <param name="archive">Archive that receives the plugin-specific entries.</param>
475 protected abstract void SaveInternal(ZipArchive archive);
/// <summary>
/// Implementation hook invoked by Load after the wrapped search method has read its state from the archive.
/// </summary>
/// <param name="archive">Archive to read the plugin-specific entries from.</param>
476 protected abstract void LoadInternal(ZipArchive archive);
478 public override void Save(ZipArchive archive)
480 search.
Save(archive);
481 SaveInternal(archive);
484 public override void Load(ZipArchive archive)
486 search.
Load(archive);
487 LoadInternal(archive);
494 public class ArchiveSaver
496 public delegate
void ArchiveSaverCallback(ZipArchive archive);
498 public static void Save(
string filePath, ArchiveSaverCallback callback)
500 using (FileStream stream =
new FileStream(filePath, FileMode.Create))
501 using (ZipArchive archive =
new ZipArchive(stream, ZipArchiveMode.Create))
507 public static void Load(
string filePath, ArchiveSaverCallback callback)
509 using (FileStream stream =
new FileStream(filePath, FileMode.Open))
510 using (ZipArchive archive =
new ZipArchive(stream, ZipArchiveMode.Read))
516 public static void Save(ZipArchive archive,
object saveObject,
string name)
518 ZipArchiveEntry mainEntry = archive.CreateEntry(name);
519 using (Stream entryStream = mainEntry.Open())
521 BinaryFormatter formatter =
new BinaryFormatter();
522 formatter.Serialize(entryStream, saveObject);
526 public static T Load<T>(ZipArchive archive,
string name)
528 ZipArchiveEntry baseEntry = archive.GetEntry(name);
529 if (baseEntry ==
null)
throw new Exception($
"No entry with name {name} was found");
530 using (Stream entryStream = baseEntry.Open())
532 BinaryFormatter formatter =
new BinaryFormatter();
533 return (T)formatter.Deserialize(entryStream);
virtual async Task< List< int > > Tokenize(string query, Callback< List< int > > callback=null)
Tokenises the provided query.
virtual async Task< List< float > > Embeddings(string query, Callback< List< float > > callback=null)
Computes the embeddings of the provided input.
virtual async Task< string > Detokenize(List< int > tokens, Callback< string > callback=null)
Detokenises the provided tokens to a string.
Class implementing the LLM embedder.
Class implementing helper functions for setup and process management.
Class implementing the LLM server.
Class implementing the search method template.
void SetLLM(LLM llm)
Sets the LLM for encoding the search entries.
async Task<(string[], float[])> SearchFromList(string query, string[] searchList)
Orders the entries in the searchList according to their similarity to the provided query. Returns the ordered entries together with their distances (dissimilarity) to the query.
Class implementing the search plugin template used e.g. in chunking.
void SetSearch(SearchMethod search)
Sets the search method of the plugin.
Class implementing the search template.
async Task<(string[], float[])> Search(string queryString, int k, string group="")
Searches for results similar to the provided query. The most similar results and their distances (dissimilarity) to the query are returned.
int Remove(string inputString, string group="")
Removes every entry matching the phrase from the given data group and returns the number of entries removed.
int Count(string group)
Returns a count of the phrases in a specific data group.
Task< int > Add(string inputString, string group="")
Adds a phrase to the search.
int Count()
Returns a count of the phrases.
void Clear()
Clears the search object.
ValueTuple< int[], float[], bool > IncrementalFetchKeys(int fetchKey, int k)
Retrieves the most similar search results in batches (incremental search). The phrase keys and distances are returned, along with a flag indicating whether the search has completed.
virtual ValueTuple< string[], float[], bool > IncrementalFetch(int fetchKey, int k)
Retrieves the most similar search results in batches (incremental search). The most similar results and their distances are returned, along with a flag indicating whether the search has completed.
void IncrementalSearchComplete(int fetchKey)
Completes the search and clears the cached results for an incremental search.
string Get(int key)
Retrieves the phrase with the specified key.
void Save(string filePath)
Saves the state of the search object.
void Remove(int key)
Removes the phrase with the specified key from the search.
async Task< bool > Load(string filePath)
Loads the state of the search object.
Task< int > IncrementalSearch(string queryString, string group="")
Performs a search whose results can be retrieved in batches (incremental search).