LLM for Unity  v2.2.5
Create characters in Unity with LLMs!
Loading...
Searching...
No Matches
LLMLib.cs
Go to the documentation of this file.
1
4using System;
5using System.Collections.Generic;
6using System.IO;
7using System.Runtime.InteropServices;
8using UnityEngine;
9
10namespace LLMUnity
11{
/// <summary>
/// Managed view over a native string wrapper object created through <see cref="LLMLib"/>.
/// The wrapper is polled with <see cref="Update"/>; whenever its content is non-empty and
/// differs from the content passed to the last callback invocation, the callback is invoked
/// with the new content.
/// </summary>
public class StreamWrapper
{
    LLMLib llmlib;
    Callback<string> callback;
    IntPtr stringWrapper;
    string previousString = "";
    string previousCalledString = "";
    int previousBufferSize = 0;
    bool clearOnUpdate;

    /// <summary>
    /// Creates the native string wrapper (when a library is provided) and stores the callback.
    /// </summary>
    /// <param name="llmlib">library used for all native calls; when null the wrapper handle stays zero and the object is inert</param>
    /// <param name="callback">invoked from <see cref="Update"/> with the new content; may be null</param>
    /// <param name="clearOnUpdate">value passed as the <c>clear</c> flag on every read done by <see cref="Update"/></param>
    public StreamWrapper(LLMLib llmlib, Callback<string> callback, bool clearOnUpdate = false)
    {
        this.llmlib = llmlib;
        this.callback = callback;
        this.clearOnUpdate = clearOnUpdate;
        stringWrapper = (llmlib?.StringWrapper_Construct()).GetValueOrDefault();
    }

    /// <summary>
    /// Reads the current content of the native string wrapper.
    /// </summary>
    /// <param name="clear">forwarded to the native read as its <c>clear</c> flag</param>
    /// <returns>the content, or an empty string when there is no content or no live wrapper</returns>
    public string GetString(bool clear = false)
    {
        // Never hand a zero handle to native code (no library, or the wrapper was destroyed).
        if (llmlib == null || stringWrapper == IntPtr.Zero) return "";

        string result;
        int bufferSize = llmlib.StringWrapper_GetStringSize(stringWrapper);
        if (bufferSize <= 1)
        {
            // A reported size of 0 or 1 is treated as "no content".
            result = "";
        }
        else if (clear || previousBufferSize != bufferSize)
        {
            // The cached string is only reused when the reported size is unchanged AND the read is
            // non-clearing. A clearing read must always reach the native side: otherwise new content
            // that happens to have the same size as the previous read would yield the stale cached
            // string and leave the native buffer uncleared.
            IntPtr buffer = Marshal.AllocHGlobal(bufferSize);
            try
            {
                llmlib.StringWrapper_GetString(stringWrapper, buffer, bufferSize, clear);
                result = Marshal.PtrToStringAnsi(buffer);
            }
            finally
            {
                Marshal.FreeHGlobal(buffer);
            }
            previousString = result;
        }
        else
        {
            result = previousString;
        }
        previousBufferSize = bufferSize;
        return result;
    }

    /// <summary>
    /// Polls the wrapper and invokes the callback when the content is non-empty and different
    /// from the content of the previous callback invocation.
    /// </summary>
    public void Update()
    {
        if (stringWrapper == IntPtr.Zero) return;
        string result = GetString(clearOnUpdate);
        if (result != "" && previousCalledString != result)
        {
            callback?.Invoke(result);
            previousCalledString = result;
        }
    }

    /// <summary>
    /// Returns the native wrapper handle (zero when no library was given or after <see cref="Destroy"/>).
    /// </summary>
    public IntPtr GetStringWrapper()
    {
        return stringWrapper;
    }

    /// <summary>
    /// Deletes the native wrapper. Safe to call more than once.
    /// </summary>
    public void Destroy()
    {
        if (stringWrapper == IntPtr.Zero) return;
        llmlib?.StringWrapper_Delete(stringWrapper);
        // Forget the handle so that later Update/GetString/Destroy calls cannot touch freed memory.
        stringWrapper = IntPtr.Zero;
    }
}
81
/// <summary>
/// Thin cross-platform layer over the operating system's dynamic loader:
/// loads a native library, resolves exported symbols and unloads the library,
/// dispatching on <c>Application.platform</c>.
/// </summary>
static class LibraryLoader
{
    // LibraryLoader is adapted from SkiaForUnity:
    // https://github.com/ammariqais/SkiaForUnity/blob/f43322218c736d1c41f3a3df9355b90db4259a07/SkiaUnity/Assets/SkiaSharp/SkiaSharp-Bindings/SkiaSharp.HarfBuzz.Shared/HarfBuzzSharp.Shared/LibraryLoader.cs

    /// <summary>
    /// Resolves an exported symbol and wraps it in a delegate of type <typeparamref name="T"/>.
    /// </summary>
    /// <exception cref="EntryPointNotFoundException">the symbol could not be resolved</exception>
    public static T GetSymbolDelegate<T>(IntPtr library, string name) where T : Delegate
    {
        IntPtr address = GetSymbol(library, name);
        if (address != IntPtr.Zero)
        {
            return Marshal.GetDelegateForFunctionPointer<T>(address);
        }

        throw new EntryPointNotFoundException($"Unable to load symbol '{name}'.");
    }

    /// <summary>
    /// Loads a native library with the loader of the current platform.
    /// </summary>
    /// <returns>the handle returned by the platform loader (may be zero when loading failed)</returns>
    /// <exception cref="ArgumentNullException">the library name is null or empty</exception>
    /// <exception cref="PlatformNotSupportedException">the current platform is not handled</exception>
    public static IntPtr LoadLibrary(string libraryName)
    {
        if (string.IsNullOrEmpty(libraryName))
        {
            throw new ArgumentNullException(nameof(libraryName));
        }

        switch (Application.platform)
        {
            case RuntimePlatform.WindowsEditor:
            case RuntimePlatform.WindowsPlayer:
                return Win32.LoadLibrary(libraryName);

            case RuntimePlatform.LinuxEditor:
            case RuntimePlatform.LinuxPlayer:
                return Linux.dlopen(libraryName);

            case RuntimePlatform.OSXEditor:
            case RuntimePlatform.OSXPlayer:
                return Mac.dlopen(libraryName);

            case RuntimePlatform.Android:
                return Android.dlopen(libraryName);

            default:
                throw new PlatformNotSupportedException($"Current platform is unknown, unable to load library '{libraryName}'.");
        }
    }

    /// <summary>
    /// Looks up an exported symbol in a previously loaded library.
    /// </summary>
    /// <returns>the symbol address as returned by the platform loader (may be zero)</returns>
    /// <exception cref="ArgumentNullException">the symbol name is null or empty</exception>
    /// <exception cref="PlatformNotSupportedException">the current platform is not handled</exception>
    public static IntPtr GetSymbol(IntPtr library, string symbolName)
    {
        if (string.IsNullOrEmpty(symbolName))
        {
            throw new ArgumentNullException(nameof(symbolName));
        }

        switch (Application.platform)
        {
            case RuntimePlatform.WindowsEditor:
            case RuntimePlatform.WindowsPlayer:
                return Win32.GetProcAddress(library, symbolName);

            case RuntimePlatform.LinuxEditor:
            case RuntimePlatform.LinuxPlayer:
                return Linux.dlsym(library, symbolName);

            case RuntimePlatform.OSXEditor:
            case RuntimePlatform.OSXPlayer:
                return Mac.dlsym(library, symbolName);

            case RuntimePlatform.Android:
                return Android.dlsym(library, symbolName);

            default:
                throw new PlatformNotSupportedException($"Current platform is unknown, unable to load symbol '{symbolName}' from library {library}.");
        }
    }

    /// <summary>
    /// Unloads a library; a zero handle is ignored.
    /// </summary>
    /// <exception cref="PlatformNotSupportedException">the current platform is not handled</exception>
    public static void FreeLibrary(IntPtr library)
    {
        if (library == IntPtr.Zero)
        {
            return;
        }

        switch (Application.platform)
        {
            case RuntimePlatform.WindowsEditor:
            case RuntimePlatform.WindowsPlayer:
                Win32.FreeLibrary(library);
                break;

            case RuntimePlatform.LinuxEditor:
            case RuntimePlatform.LinuxPlayer:
                Linux.dlclose(library);
                break;

            case RuntimePlatform.OSXEditor:
            case RuntimePlatform.OSXPlayer:
                Mac.dlclose(library);
                break;

            case RuntimePlatform.Android:
                Android.dlclose(library);
                break;

            default:
                throw new PlatformNotSupportedException($"Current platform is unknown, unable to close library '{library}'.");
        }
    }

    /// <summary>macOS bindings (dlopen family imported from libSystem).</summary>
    private static class Mac
    {
        private const string SystemLibrary = "/usr/lib/libSystem.dylib";

        private const int RTLD_LAZY = 1;
        private const int RTLD_NOW = 2;

        // Convenience overload: maps the boolean onto the RTLD_* mode constants.
        public static IntPtr dlopen(string path, bool lazy = true)
        {
            int mode = lazy ? RTLD_LAZY : RTLD_NOW;
            return dlopen(path, mode);
        }

        [DllImport(SystemLibrary)]
        public static extern IntPtr dlopen(string path, int mode);

        [DllImport(SystemLibrary)]
        public static extern IntPtr dlsym(IntPtr handle, string symbol);

        [DllImport(SystemLibrary)]
        public static extern void dlclose(IntPtr handle);
    }

    /// <summary>
    /// Linux bindings. The dl* functions are first looked up in "libdl.so.2";
    /// if that library cannot be found, "libdl.so" is used from then on.
    /// </summary>
    private static class Linux
    {
        private const string SystemLibrary = "libdl.so";
        private const string SystemLibrary2 = "libdl.so.2"; // newer Linux distros use this

        private const int RTLD_LAZY = 1;
        private const int RTLD_NOW = 2;

        // Switched to false by dlopen once "libdl.so.2" turned out to be unavailable.
        private static bool UseSystemLibrary2 = true;

        public static IntPtr dlopen(string path, bool lazy = true)
        {
            int mode = lazy ? RTLD_LAZY : RTLD_NOW;
            try
            {
                return dlopen2(path, mode);
            }
            catch (DllNotFoundException)
            {
                // Remember the fallback so dlsym/dlclose use the same library afterwards.
                UseSystemLibrary2 = false;
                return dlopen1(path, mode);
            }
        }

        public static IntPtr dlsym(IntPtr handle, string symbol)
        {
            if (UseSystemLibrary2)
            {
                return dlsym2(handle, symbol);
            }
            return dlsym1(handle, symbol);
        }

        public static void dlclose(IntPtr handle)
        {
            if (!UseSystemLibrary2)
            {
                dlclose1(handle);
                return;
            }
            dlclose2(handle);
        }

        [DllImport(SystemLibrary, EntryPoint = "dlopen")]
        private static extern IntPtr dlopen1(string path, int mode);

        [DllImport(SystemLibrary, EntryPoint = "dlsym")]
        private static extern IntPtr dlsym1(IntPtr handle, string symbol);

        [DllImport(SystemLibrary, EntryPoint = "dlclose")]
        private static extern void dlclose1(IntPtr handle);

        [DllImport(SystemLibrary2, EntryPoint = "dlopen")]
        private static extern IntPtr dlopen2(string path, int mode);

        [DllImport(SystemLibrary2, EntryPoint = "dlsym")]
        private static extern IntPtr dlsym2(IntPtr handle, string symbol);

        [DllImport(SystemLibrary2, EntryPoint = "dlclose")]
        private static extern void dlclose2(IntPtr handle);
    }

    /// <summary>Windows bindings (Kernel32).</summary>
    private static class Win32
    {
        private const string SystemLibrary = "Kernel32.dll";

        [DllImport(SystemLibrary, SetLastError = true, CharSet = CharSet.Ansi)]
        public static extern IntPtr LoadLibrary(string lpFileName);

        [DllImport(SystemLibrary, SetLastError = true, CharSet = CharSet.Ansi)]
        public static extern IntPtr GetProcAddress(IntPtr hModule, string lpProcName);

        [DllImport(SystemLibrary, SetLastError = true, CharSet = CharSet.Ansi)]
        public static extern void FreeLibrary(IntPtr hModule);
    }

    /// <summary>
    /// Android bindings. The real imports only exist when compiling with UNITY_ANDROID;
    /// on every other build target they are replaced by stubs that return default values.
    /// </summary>
    private static class Android
    {
        // Opens the library with flags value 1.
        public static IntPtr dlopen(string path)
        {
            return dlopen(path, 1);
        }

#if UNITY_ANDROID
        // LoadLibrary for Android
        [DllImport("__Internal")]
        public static extern IntPtr dlopen(string filename, int flags);

        // GetSymbol for Android
        [DllImport("__Internal")]
        public static extern IntPtr dlsym(IntPtr handle, string symbol);

        // FreeLibrary for Android
        [DllImport("__Internal")]
        public static extern int dlclose(IntPtr handle);
#else
        public static IntPtr dlopen(string filename, int flags) => default;

        public static IntPtr dlsym(IntPtr handle, string symbol) => default;

        public static int dlclose(IntPtr handle) => default;
#endif
    }
}
276
/// <summary>
/// Binds the exported functions of one native library build (selected by architecture name)
/// to delegate fields, and probes the CPU's AVX capabilities once per process through a
/// separate "archchecker" library on Windows and Linux.
/// </summary>
public class LLMLib
{
    IntPtr libraryHandle = IntPtr.Zero;
    static readonly object staticLock = new object();
    static bool has_avx = false;
    static bool has_avx2 = false;
    static bool has_avx512 = false;
    static bool has_avx_set = false;

    /// <summary>
    /// Probes AVX / AVX2 / AVX512 support through the architecture checker library.
    /// Every failure is logged and leaves the remaining flags at false; nothing is allowed to
    /// escape, because an exception thrown by a static constructor makes the whole type unusable
    /// (TypeInitializationException on every later access).
    /// </summary>
    static LLMLib()
    {
        lock (staticLock)
        {
            if (has_avx_set) return;

            IntPtr archCheckerHandle = IntPtr.Zero;
            try
            {
                // null on platforms without an architecture checker
                string archCheckerPath = GetArchitectureCheckerPath();
                if (archCheckerPath != null)
                {
                    archCheckerHandle = LibraryLoader.LoadLibrary(archCheckerPath);
                    if (archCheckerHandle == IntPtr.Zero)
                    {
                        LLMUnitySetup.LogError($"Failed to load library {archCheckerPath}.");
                    }
                    else
                    {
                        has_avx = LibraryLoader.GetSymbolDelegate<HasArchDelegate>(archCheckerHandle, "has_avx")();
                        has_avx2 = LibraryLoader.GetSymbolDelegate<HasArchDelegate>(archCheckerHandle, "has_avx2")();
                        has_avx512 = LibraryLoader.GetSymbolDelegate<HasArchDelegate>(archCheckerHandle, "has_avx512")();
                    }
                }
            }
            catch (Exception e)
            {
                LLMUnitySetup.LogError($"{e.GetType()}: {e.Message}");
            }
            finally
            {
                // Release the checker even when a symbol lookup threw.
                if (archCheckerHandle != IntPtr.Zero)
                {
                    try
                    {
                        LibraryLoader.FreeLibrary(archCheckerHandle);
                    }
                    catch (Exception e)
                    {
                        LLMUnitySetup.LogError($"{e.GetType()}: {e.Message}");
                    }
                }
                has_avx_set = true;
            }
        }
    }

    /// <summary>
    /// Loads the native library for the given architecture and binds all exported functions.
    /// </summary>
    /// <param name="arch">architecture name, as produced by <see cref="PossibleArchitectures"/></param>
    /// <exception cref="Exception">the library could not be loaded</exception>
    /// <exception cref="EntryPointNotFoundException">the library lacks one of the expected symbols</exception>
    public LLMLib(string arch)
    {
        libraryHandle = LibraryLoader.LoadLibrary(GetArchitecturePath(arch));
        if (libraryHandle == IntPtr.Zero)
        {
            throw new Exception($"Failed to load library {arch}.");
        }

        try
        {
            LLM_Construct = LibraryLoader.GetSymbolDelegate<LLM_ConstructDelegate>(libraryHandle, "LLM_Construct");
            LLM_Delete = LibraryLoader.GetSymbolDelegate<LLM_DeleteDelegate>(libraryHandle, "LLM_Delete");
            LLM_StartServer = LibraryLoader.GetSymbolDelegate<LLM_StartServerDelegate>(libraryHandle, "LLM_StartServer");
            LLM_StopServer = LibraryLoader.GetSymbolDelegate<LLM_StopServerDelegate>(libraryHandle, "LLM_StopServer");
            LLM_Start = LibraryLoader.GetSymbolDelegate<LLM_StartDelegate>(libraryHandle, "LLM_Start");
            LLM_Started = LibraryLoader.GetSymbolDelegate<LLM_StartedDelegate>(libraryHandle, "LLM_Started");
            LLM_Stop = LibraryLoader.GetSymbolDelegate<LLM_StopDelegate>(libraryHandle, "LLM_Stop");
            LLM_SetTemplate = LibraryLoader.GetSymbolDelegate<LLM_SetTemplateDelegate>(libraryHandle, "LLM_SetTemplate");
            LLM_SetSSL = LibraryLoader.GetSymbolDelegate<LLM_SetSSLDelegate>(libraryHandle, "LLM_SetSSL");
            LLM_Tokenize = LibraryLoader.GetSymbolDelegate<LLM_TokenizeDelegate>(libraryHandle, "LLM_Tokenize");
            LLM_Detokenize = LibraryLoader.GetSymbolDelegate<LLM_DetokenizeDelegate>(libraryHandle, "LLM_Detokenize");
            LLM_Embeddings = LibraryLoader.GetSymbolDelegate<LLM_EmbeddingsDelegate>(libraryHandle, "LLM_Embeddings");
            LLM_Lora_Weight = LibraryLoader.GetSymbolDelegate<LLM_LoraWeightDelegate>(libraryHandle, "LLM_Lora_Weight");
            LLM_LoraList = LibraryLoader.GetSymbolDelegate<LLM_LoraListDelegate>(libraryHandle, "LLM_Lora_List");
            LLM_Completion = LibraryLoader.GetSymbolDelegate<LLM_CompletionDelegate>(libraryHandle, "LLM_Completion");
            LLM_Slot = LibraryLoader.GetSymbolDelegate<LLM_SlotDelegate>(libraryHandle, "LLM_Slot");
            LLM_Cancel = LibraryLoader.GetSymbolDelegate<LLM_CancelDelegate>(libraryHandle, "LLM_Cancel");
            LLM_Status = LibraryLoader.GetSymbolDelegate<LLM_StatusDelegate>(libraryHandle, "LLM_Status");
            StringWrapper_Construct = LibraryLoader.GetSymbolDelegate<StringWrapper_ConstructDelegate>(libraryHandle, "StringWrapper_Construct");
            StringWrapper_Delete = LibraryLoader.GetSymbolDelegate<StringWrapper_DeleteDelegate>(libraryHandle, "StringWrapper_Delete");
            StringWrapper_GetStringSize = LibraryLoader.GetSymbolDelegate<StringWrapper_GetStringSizeDelegate>(libraryHandle, "StringWrapper_GetStringSize");
            StringWrapper_GetString = LibraryLoader.GetSymbolDelegate<StringWrapper_GetStringDelegate>(libraryHandle, "StringWrapper_GetString");
            Logging = LibraryLoader.GetSymbolDelegate<LoggingDelegate>(libraryHandle, "Logging");
            StopLogging = LibraryLoader.GetSymbolDelegate<StopLoggingDelegate>(libraryHandle, "StopLogging");
        }
        catch
        {
            // The caller never receives an instance when the constructor throws, so it cannot call
            // Destroy(); unload the library here or it stays loaded for the rest of the process.
            LibraryLoader.FreeLibrary(libraryHandle);
            libraryHandle = IntPtr.Zero;
            throw;
        }
    }

    /// <summary>
    /// Unloads the native library. Safe to call more than once.
    /// The delegate fields must not be invoked afterwards.
    /// </summary>
    public void Destroy()
    {
        if (libraryHandle == IntPtr.Zero) return;
        LibraryLoader.FreeLibrary(libraryHandle);
        // Reset the handle so that a repeated Destroy() does not unload the library twice.
        libraryHandle = IntPtr.Zero;
    }

    /// <summary>
    /// Lists the architecture names that can be tried on the current platform, in the order
    /// in which they are added below (presumably the preferred order of attempts - confirm with the caller).
    /// </summary>
    /// <param name="gpu">when true, GPU builds are listed before the CPU builds on Windows and Linux</param>
    /// <exception cref="Exception">the current platform is not handled</exception>
    public static List<string> PossibleArchitectures(bool gpu = false)
    {
        List<string> architectures = new List<string>();
        if (Application.platform == RuntimePlatform.WindowsEditor || Application.platform == RuntimePlatform.WindowsPlayer ||
            Application.platform == RuntimePlatform.LinuxEditor || Application.platform == RuntimePlatform.LinuxPlayer)
        {
            if (gpu)
            {
                if (LLMUnitySetup.FullLlamaLib)
                {
                    architectures.Add("cuda-cu12.2.0-full");
                    architectures.Add("cuda-cu11.7.1-full");
                    architectures.Add("hip-full");
                }
                else
                {
                    architectures.Add("cuda-cu12.2.0");
                    architectures.Add("cuda-cu11.7.1");
                    architectures.Add("hip");
                }
                architectures.Add("vulkan");
            }
            // CPU builds, gated by the flags probed in the static constructor.
            if (has_avx512) architectures.Add("avx512");
            if (has_avx2) architectures.Add("avx2");
            if (has_avx) architectures.Add("avx");
            architectures.Add("noavx");
        }
        else if (Application.platform == RuntimePlatform.OSXEditor || Application.platform == RuntimePlatform.OSXPlayer)
        {
            // Invariant lower-casing: the enum name is an identifier, not user-facing text.
            string arch = RuntimeInformation.ProcessArchitecture.ToString().ToLowerInvariant();
            if (arch.Contains("arm"))
            {
                architectures.Add("arm64-acc");
                architectures.Add("arm64-no_acc");
            }
            else
            {
                if (arch != "x86" && arch != "x64") LLMUnitySetup.LogWarning($"Unknown architecture of processor {arch}! Falling back to x86_64");
                architectures.Add("x64-acc");
                architectures.Add("x64-no_acc");
            }
        }
        else if (Application.platform == RuntimePlatform.Android)
        {
            architectures.Add("android");
        }
        else
        {
            string error = "Unknown OS";
            LLMUnitySetup.LogError(error);
            throw new Exception(error);
        }
        return architectures;
    }

    /// <summary>
    /// Returns the path of the architecture checker library inside <c>LLMUnitySetup.libraryPath</c>,
    /// or null on platforms other than Windows and Linux.
    /// </summary>
    public static string GetArchitectureCheckerPath()
    {
        string filename;
        if (Application.platform == RuntimePlatform.WindowsEditor || Application.platform == RuntimePlatform.WindowsPlayer)
        {
            filename = "windows-archchecker/archchecker.dll";
        }
        else if (Application.platform == RuntimePlatform.LinuxEditor || Application.platform == RuntimePlatform.LinuxPlayer)
        {
            filename = "linux-archchecker/libarchchecker.so";
        }
        else
        {
            return null;
        }
        return Path.Combine(LLMUnitySetup.libraryPath, filename);
    }

    /// <summary>
    /// Returns the path of the native library for the given architecture on the current platform.
    /// On Android only the bare file name "libundreamai_android.so" is returned and
    /// <paramref name="arch"/> is not used.
    /// </summary>
    /// <exception cref="Exception">the current platform is not handled</exception>
    public static string GetArchitecturePath(string arch)
    {
        string filename;
        if (Application.platform == RuntimePlatform.WindowsEditor || Application.platform == RuntimePlatform.WindowsPlayer)
        {
            filename = $"windows-{arch}/undreamai_windows-{arch}.dll";
        }
        else if (Application.platform == RuntimePlatform.LinuxEditor || Application.platform == RuntimePlatform.LinuxPlayer)
        {
            filename = $"linux-{arch}/libundreamai_linux-{arch}.so";
        }
        else if (Application.platform == RuntimePlatform.OSXEditor || Application.platform == RuntimePlatform.OSXPlayer)
        {
            filename = $"macos-{arch}/libundreamai_macos-{arch}.dylib";
        }
        else if (Application.platform == RuntimePlatform.Android)
        {
            return "libundreamai_android.so";
        }
        else
        {
            string error = "Unknown OS";
            LLMUnitySetup.LogError(error);
            throw new Exception(error);
        }
        return Path.Combine(LLMUnitySetup.libraryPath, filename);
    }

    /// <summary>
    /// Copies the content of a native string wrapper into a managed string (non-clearing read).
    /// </summary>
    /// <param name="stringWrapper">handle of the native string wrapper</param>
    /// <returns>the content, or an empty string when the reported size is 1 or less</returns>
    public string GetStringWrapperResult(IntPtr stringWrapper)
    {
        string result = "";
        int bufferSize = StringWrapper_GetStringSize(stringWrapper);
        if (bufferSize > 1)
        {
            IntPtr buffer = Marshal.AllocHGlobal(bufferSize);
            try
            {
                StringWrapper_GetString(stringWrapper, buffer, bufferSize);
                result = Marshal.PtrToStringAnsi(buffer);
            }
            finally
            {
                Marshal.FreeHGlobal(buffer);
            }
        }
        return result;
    }

    // Signatures of the native exports bound in the constructor.
    //
    // NOTE(review): boolean return values are marshalled as a single byte (UnmanagedType.I1).
    // Without the attribute the runtime reads a 4-byte Win32 BOOL, and for a native function that
    // returns a 1-byte C/C++ bool the upper bytes of the return register are unspecified, which can
    // turn false into true. This assumes the native functions return a 1-byte bool - confirm against
    // the native headers.
    [return: MarshalAs(UnmanagedType.I1)]
    public delegate bool HasArchDelegate();
    public delegate void LoggingDelegate(IntPtr stringWrapper);
    public delegate void StopLoggingDelegate();
    public delegate IntPtr LLM_ConstructDelegate(string command);
    public delegate void LLM_DeleteDelegate(IntPtr LLMObject);
    public delegate void LLM_StartServerDelegate(IntPtr LLMObject);
    public delegate void LLM_StopServerDelegate(IntPtr LLMObject);
    public delegate void LLM_StartDelegate(IntPtr LLMObject);
    [return: MarshalAs(UnmanagedType.I1)]
    public delegate bool LLM_StartedDelegate(IntPtr LLMObject);
    public delegate void LLM_StopDelegate(IntPtr LLMObject);
    public delegate void LLM_SetTemplateDelegate(IntPtr LLMObject, string chatTemplate);
    public delegate void LLM_SetSSLDelegate(IntPtr LLMObject, string SSLCert, string SSLKey);
    public delegate void LLM_TokenizeDelegate(IntPtr LLMObject, string jsonData, IntPtr stringWrapper);
    public delegate void LLM_DetokenizeDelegate(IntPtr LLMObject, string jsonData, IntPtr stringWrapper);
    public delegate void LLM_EmbeddingsDelegate(IntPtr LLMObject, string jsonData, IntPtr stringWrapper);
    public delegate void LLM_LoraWeightDelegate(IntPtr LLMObject, string jsonData, IntPtr stringWrapper);
    public delegate void LLM_LoraListDelegate(IntPtr LLMObject, IntPtr stringWrapper);
    public delegate void LLM_CompletionDelegate(IntPtr LLMObject, string jsonData, IntPtr stringWrapper);
    public delegate void LLM_SlotDelegate(IntPtr LLMObject, string jsonData, IntPtr stringWrapper);
    public delegate void LLM_CancelDelegate(IntPtr LLMObject, int idSlot);
    public delegate int LLM_StatusDelegate(IntPtr LLMObject, IntPtr stringWrapper);
    public delegate IntPtr StringWrapper_ConstructDelegate();
    public delegate void StringWrapper_DeleteDelegate(IntPtr instance);
    public delegate int StringWrapper_GetStringSizeDelegate(IntPtr instance);
    public delegate void StringWrapper_GetStringDelegate(IntPtr instance, IntPtr buffer, int bufferSize, bool clear = false);

    // Bound native functions; assigned once in the constructor.
    public LoggingDelegate Logging;
    public StopLoggingDelegate StopLogging;
    public LLM_ConstructDelegate LLM_Construct;
    public LLM_DeleteDelegate LLM_Delete;
    public LLM_StartServerDelegate LLM_StartServer;
    public LLM_StopServerDelegate LLM_StopServer;
    public LLM_StartDelegate LLM_Start;
    public LLM_StartedDelegate LLM_Started;
    public LLM_StopDelegate LLM_Stop;
    public LLM_SetTemplateDelegate LLM_SetTemplate;
    public LLM_SetSSLDelegate LLM_SetSSL;
    public LLM_TokenizeDelegate LLM_Tokenize;
    public LLM_DetokenizeDelegate LLM_Detokenize;
    public LLM_CompletionDelegate LLM_Completion;
    public LLM_EmbeddingsDelegate LLM_Embeddings;
    public LLM_LoraWeightDelegate LLM_Lora_Weight;
    public LLM_LoraListDelegate LLM_LoraList;
    public LLM_SlotDelegate LLM_Slot;
    public LLM_CancelDelegate LLM_Cancel;
    public LLM_StatusDelegate LLM_Status;
    public StringWrapper_ConstructDelegate StringWrapper_Construct;
    public StringWrapper_DeleteDelegate StringWrapper_Delete;
    public StringWrapper_GetStringSizeDelegate StringWrapper_GetStringSize;
    public StringWrapper_GetStringDelegate StringWrapper_GetString;
}
529}