LLM for Unity  v2.3.0
Create characters in Unity with LLMs!
Loading...
Searching...
No Matches
LLMLib.cs
Go to the documentation of this file.
1
4using System;
5using System.Collections.Generic;
6using System.IO;
7using System.Runtime.InteropServices;
8using UnityEngine;
9
10namespace LLMUnity
11{
/// <summary>
/// Polls a native string wrapper that the native library writes into and forwards new content to a callback.
/// </summary>
public class StreamWrapper
{
    LLMLib llmlib;
    Callback<string> callback;
    // Native handle returned by StringWrapper_Construct; IntPtr.Zero when none was created or after Destroy().
    IntPtr stringWrapper;
    // Last string read from the native side, reused while the native buffer size is unchanged.
    string previousString = "";
    // Last string passed to the callback; used to suppress duplicate notifications in non-clearing mode.
    string previousCalledString = "";
    // Buffer size seen on the previous non-clearing read; reset to 0 after a clearing read,
    // because the cached string no longer describes what the native side holds.
    int previousBufferSize = 0;
    bool clearOnUpdate;

    /// <summary>
    /// Creates the wrapper and allocates its native string wrapper through <paramref name="llmlib"/>.
    /// </summary>
    /// <param name="llmlib">library used to construct, read and delete the native string wrapper; if null, no native wrapper is created</param>
    /// <param name="callback">function called by <see cref="Update"/> with the string read from the native side</param>
    /// <param name="clearOnUpdate">if true, <see cref="Update"/> passes clear=true to StringWrapper_GetString on every read</param>
    public StreamWrapper(LLMLib llmlib, Callback<string> callback, bool clearOnUpdate = false)
    {
        this.llmlib = llmlib;
        this.callback = callback;
        this.clearOnUpdate = clearOnUpdate;
        stringWrapper = (llmlib?.StringWrapper_Construct()).GetValueOrDefault();
    }

    /// <summary>
    /// Reads the current content of the native string.
    /// </summary>
    /// <param name="clear">whether to ask the native side to clear the string after reading it</param>
    /// <returns>the content, or "" when the string is empty or no native wrapper exists</returns>
    public string GetString(bool clear = false)
    {
        // Never hand a null/freed handle to native code.
        if (stringWrapper == IntPtr.Zero) return "";

        string result;
        int bufferSize = (llmlib?.StringWrapper_GetStringSize(stringWrapper)).GetValueOrDefault();
        if (bufferSize <= 1)
        {
            // Size 1 is just the terminating null character.
            result = "";
        }
        else if (clear || previousBufferSize != bufferSize)
        {
            // A clearing read must always hit the native side: new content of the same size as the
            // previous chunk is different text, and the clear itself must be performed.
            IntPtr buffer = Marshal.AllocHGlobal(bufferSize);
            try
            {
                llmlib?.StringWrapper_GetString(stringWrapper, buffer, bufferSize, clear);
                // NOTE(review): PtrToStringAnsi decodes with the system ANSI code page on Windows;
                // if the native side emits UTF-8, non-ASCII text may be garbled — confirm.
                result = Marshal.PtrToStringAnsi(buffer);
            }
            finally
            {
                Marshal.FreeHGlobal(buffer);
            }
            previousString = result;
        }
        else
        {
            // Non-clearing mode with unchanged size: reuse the cached string.
            result = previousString;
        }
        previousBufferSize = clear ? 0 : bufferSize;
        return result;
    }

    /// <summary>
    /// Reads the native string and invokes the callback with new content.
    /// In clearing mode every non-empty read is delivered, even if it equals the previous chunk;
    /// otherwise the callback fires only when the string differs from the last delivered one.
    /// </summary>
    public void Update()
    {
        if (stringWrapper == IntPtr.Zero) return;
        string result = GetString(clearOnUpdate);
        if (result == "") return;
        if (!clearOnUpdate && previousCalledString == result) return;
        callback?.Invoke(result);
        previousCalledString = result;
    }

    /// <summary>
    /// Gets the native string wrapper handle, to be passed to native functions that write into it.
    /// </summary>
    /// <returns>the native handle, or IntPtr.Zero if none exists</returns>
    public IntPtr GetStringWrapper()
    {
        return stringWrapper;
    }

    /// <summary>
    /// Deletes the native string wrapper. Safe to call more than once.
    /// </summary>
    public void Destroy()
    {
        if (stringWrapper == IntPtr.Zero) return;
        llmlib?.StringWrapper_Delete(stringWrapper);
        // Prevent double-free and use-after-free through GetString/Update.
        stringWrapper = IntPtr.Zero;
    }
}
100
/// <summary>
/// Loads native libraries and resolves their exported symbols at runtime through the OS loader:
/// LoadLibrary/GetProcAddress on Windows and dlopen/dlsym on Linux, macOS and Android.
/// </summary>
static class LibraryLoader
{
    static bool IsWindows => Application.platform == RuntimePlatform.WindowsEditor || Application.platform == RuntimePlatform.WindowsPlayer;
    static bool IsLinux => Application.platform == RuntimePlatform.LinuxEditor || Application.platform == RuntimePlatform.LinuxPlayer;
    static bool IsMac => Application.platform == RuntimePlatform.OSXEditor || Application.platform == RuntimePlatform.OSXPlayer;
    static bool IsAndroid => Application.platform == RuntimePlatform.Android;

    /// <summary>
    /// Resolves an exported symbol and wraps it in a managed delegate.
    /// </summary>
    /// <typeparam name="T">delegate type matching the native function signature</typeparam>
    /// <param name="library">handle returned by <see cref="LoadLibrary"/></param>
    /// <param name="name">exported symbol name</param>
    /// <returns>a delegate that calls the native function</returns>
    /// <exception cref="EntryPointNotFoundException">the symbol is not exported by the library</exception>
    public static T GetSymbolDelegate<T>(IntPtr library, string name) where T : Delegate
    {
        IntPtr symbol = GetSymbol(library, name);
        if (symbol == IntPtr.Zero)
            throw new EntryPointNotFoundException($"Unable to load symbol '{name}'.");

        return Marshal.GetDelegateForFunctionPointer<T>(symbol);
    }

    /// <summary>
    /// Loads a native library with the loader of the current platform.
    /// </summary>
    /// <param name="libraryName">path or name of the library</param>
    /// <returns>the library handle, or IntPtr.Zero if the OS loader failed</returns>
    /// <exception cref="ArgumentNullException"><paramref name="libraryName"/> is null or empty</exception>
    /// <exception cref="PlatformNotSupportedException">the current platform is not supported</exception>
    public static IntPtr LoadLibrary(string libraryName)
    {
        if (string.IsNullOrEmpty(libraryName))
            throw new ArgumentNullException(nameof(libraryName));

        if (IsWindows) return Win32.LoadLibrary(libraryName);
        if (IsLinux) return Linux.dlopen(libraryName);
        if (IsMac) return Mac.dlopen(libraryName);
        if (IsAndroid) return Android.dlopen(libraryName);
        throw new PlatformNotSupportedException($"Current platform is unknown, unable to load library '{libraryName}'.");
    }

    /// <summary>
    /// Looks up the address of an exported symbol.
    /// </summary>
    /// <param name="library">handle returned by <see cref="LoadLibrary"/></param>
    /// <param name="symbolName">exported symbol name</param>
    /// <returns>the symbol address, or IntPtr.Zero if it was not found</returns>
    /// <exception cref="ArgumentNullException"><paramref name="symbolName"/> is null or empty</exception>
    /// <exception cref="PlatformNotSupportedException">the current platform is not supported</exception>
    public static IntPtr GetSymbol(IntPtr library, string symbolName)
    {
        if (string.IsNullOrEmpty(symbolName))
            throw new ArgumentNullException(nameof(symbolName));

        if (IsWindows) return Win32.GetProcAddress(library, symbolName);
        if (IsLinux) return Linux.dlsym(library, symbolName);
        if (IsMac) return Mac.dlsym(library, symbolName);
        if (IsAndroid) return Android.dlsym(library, symbolName);
        throw new PlatformNotSupportedException($"Current platform is unknown, unable to load symbol '{symbolName}' from library {library}.");
    }

    /// <summary>
    /// Unloads a library previously loaded with <see cref="LoadLibrary"/>. Does nothing for IntPtr.Zero.
    /// </summary>
    /// <param name="library">library handle</param>
    /// <exception cref="PlatformNotSupportedException">the current platform is not supported</exception>
    public static void FreeLibrary(IntPtr library)
    {
        if (library == IntPtr.Zero)
            return;

        if (IsWindows)
            Win32.FreeLibrary(library);
        else if (IsLinux)
            Linux.dlclose(library);
        else if (IsMac)
            Mac.dlclose(library);
        else if (IsAndroid)
            Android.dlclose(library);
        else
            throw new PlatformNotSupportedException($"Current platform is unknown, unable to close library '{library}'.");
    }

    private static class Mac
    {
        private const string SystemLibrary = "/usr/lib/libSystem.dylib";

        private const int RTLD_LAZY = 1;
        private const int RTLD_NOW = 2;

        public static IntPtr dlopen(string path, bool lazy = true) =>
            dlopen(path, lazy ? RTLD_LAZY : RTLD_NOW);

        [DllImport(SystemLibrary)]
        public static extern IntPtr dlopen(string path, int mode);

        [DllImport(SystemLibrary)]
        public static extern IntPtr dlsym(IntPtr handle, string symbol);

        // POSIX dlclose returns int (0 on success).
        [DllImport(SystemLibrary)]
        public static extern int dlclose(IntPtr handle);
    }

    private static class Linux
    {
        private const string SystemLibrary = "libdl.so";
        private const string SystemLibrary2 = "libdl.so.2"; // newer Linux distros use this

        private const int RTLD_LAZY = 1;
        private const int RTLD_NOW = 2;

        // Switched to false the first time libdl.so.2 cannot be found, so dlsym/dlclose use the same library as dlopen.
        private static bool UseSystemLibrary2 = true;

        public static IntPtr dlopen(string path, bool lazy = true)
        {
            try
            {
                return dlopen2(path, lazy ? RTLD_LAZY : RTLD_NOW);
            }
            catch (DllNotFoundException)
            {
                UseSystemLibrary2 = false;
                return dlopen1(path, lazy ? RTLD_LAZY : RTLD_NOW);
            }
        }

        public static IntPtr dlsym(IntPtr handle, string symbol)
        {
            return UseSystemLibrary2 ? dlsym2(handle, symbol) : dlsym1(handle, symbol);
        }

        public static void dlclose(IntPtr handle)
        {
            if (UseSystemLibrary2)
                dlclose2(handle);
            else
                dlclose1(handle);
        }

        [DllImport(SystemLibrary, EntryPoint = "dlopen")]
        private static extern IntPtr dlopen1(string path, int mode);

        [DllImport(SystemLibrary, EntryPoint = "dlsym")]
        private static extern IntPtr dlsym1(IntPtr handle, string symbol);

        [DllImport(SystemLibrary, EntryPoint = "dlclose")]
        private static extern int dlclose1(IntPtr handle);

        [DllImport(SystemLibrary2, EntryPoint = "dlopen")]
        private static extern IntPtr dlopen2(string path, int mode);

        [DllImport(SystemLibrary2, EntryPoint = "dlsym")]
        private static extern IntPtr dlsym2(IntPtr handle, string symbol);

        [DllImport(SystemLibrary2, EntryPoint = "dlclose")]
        private static extern int dlclose2(IntPtr handle);
    }

    private static class Win32
    {
        private const string SystemLibrary = "Kernel32.dll";

        // Bind to the wide-character entry point so library paths containing characters
        // outside the system ANSI code page (e.g. non-ASCII user or project folders) load correctly.
        [DllImport(SystemLibrary, EntryPoint = "LoadLibraryW", SetLastError = true, CharSet = CharSet.Unicode, ExactSpelling = true)]
        public static extern IntPtr LoadLibrary(string lpFileName);

        // GetProcAddress only exists in an ANSI form; symbol names are ASCII.
        [DllImport(SystemLibrary, SetLastError = true, CharSet = CharSet.Ansi)]
        public static extern IntPtr GetProcAddress(IntPtr hModule, string lpProcName);

        [DllImport(SystemLibrary, SetLastError = true)]
        [return: MarshalAs(UnmanagedType.Bool)]
        public static extern bool FreeLibrary(IntPtr hModule);
    }

    private static class Android
    {
        public static IntPtr dlopen(string path) => dlopen(path, 1);

#if UNITY_ANDROID
        [DllImport("__Internal")]
        public static extern IntPtr dlopen(string filename, int flags);

        [DllImport("__Internal")]
        public static extern IntPtr dlsym(IntPtr handle, string symbol);

        [DllImport("__Internal")]
        public static extern int dlclose(IntPtr handle);
#else
        // Stubs so the code compiles on non-Android targets; they are only reached when the platform reports Android.
        public static IntPtr dlopen(string filename, int flags)
        {
            return default;
        }

        public static IntPtr dlsym(IntPtr handle, string symbol)
        {
            return default;
        }

        public static int dlclose(IntPtr handle)
        {
            return default;
        }

#endif
    }
}
318
/// <summary>
/// Loads the undreamai native library for a given architecture and exposes its exported functions as delegates.
/// </summary>
public class LLMLib
{
    IntPtr libraryHandle = IntPtr.Zero;
    static readonly object staticLock = new object();
    static bool has_avx = false;
    static bool has_avx2 = false;
    static bool has_avx512 = false;
    static bool has_avx_set = false;

    /// <summary>
    /// Detects AVX/AVX2/AVX512 support once, using the archchecker library on Windows and Linux.
    /// Failures are logged and leave the corresponding flags false.
    /// </summary>
    static LLMLib()
    {
        lock (staticLock)
        {
            if (has_avx_set) return;
            string archCheckerPath = GetArchitectureCheckerPath();
            if (archCheckerPath != null)
            {
                IntPtr archCheckerHandle = LibraryLoader.LoadLibrary(archCheckerPath);
                if (archCheckerHandle == IntPtr.Zero)
                {
                    LLMUnitySetup.LogError($"Failed to load library {archCheckerPath}.");
                }
                else
                {
                    try
                    {
                        has_avx = LibraryLoader.GetSymbolDelegate<HasArchDelegate>(archCheckerHandle, "has_avx")();
                        has_avx2 = LibraryLoader.GetSymbolDelegate<HasArchDelegate>(archCheckerHandle, "has_avx2")();
                        has_avx512 = LibraryLoader.GetSymbolDelegate<HasArchDelegate>(archCheckerHandle, "has_avx512")();
                    }
                    catch (Exception e)
                    {
                        LLMUnitySetup.LogError($"{e.GetType()}: {e.Message}");
                    }
                    finally
                    {
                        // Release the checker even when a symbol lookup fails.
                        LibraryLoader.FreeLibrary(archCheckerHandle);
                    }
                }
            }
            has_avx_set = true;
        }
    }

    /// <summary>
    /// Loads the native library for <paramref name="arch"/> and binds all of its exported functions.
    /// </summary>
    /// <param name="arch">architecture name, as returned by <see cref="PossibleArchitectures"/></param>
    /// <exception cref="Exception">the library could not be loaded</exception>
    /// <exception cref="EntryPointNotFoundException">the library lacks one of the expected symbols; the library is unloaded before rethrowing</exception>
    public LLMLib(string arch)
    {
        libraryHandle = LibraryLoader.LoadLibrary(GetArchitecturePath(arch));
        if (libraryHandle == IntPtr.Zero)
        {
            throw new Exception($"Failed to load library {arch}.");
        }

        try
        {
            LLM_Construct = LibraryLoader.GetSymbolDelegate<LLM_ConstructDelegate>(libraryHandle, "LLM_Construct");
            LLM_Delete = LibraryLoader.GetSymbolDelegate<LLM_DeleteDelegate>(libraryHandle, "LLM_Delete");
            LLM_StartServer = LibraryLoader.GetSymbolDelegate<LLM_StartServerDelegate>(libraryHandle, "LLM_StartServer");
            LLM_StopServer = LibraryLoader.GetSymbolDelegate<LLM_StopServerDelegate>(libraryHandle, "LLM_StopServer");
            LLM_Start = LibraryLoader.GetSymbolDelegate<LLM_StartDelegate>(libraryHandle, "LLM_Start");
            LLM_Started = LibraryLoader.GetSymbolDelegate<LLM_StartedDelegate>(libraryHandle, "LLM_Started");
            LLM_Stop = LibraryLoader.GetSymbolDelegate<LLM_StopDelegate>(libraryHandle, "LLM_Stop");
            LLM_SetTemplate = LibraryLoader.GetSymbolDelegate<LLM_SetTemplateDelegate>(libraryHandle, "LLM_SetTemplate");
            LLM_SetSSL = LibraryLoader.GetSymbolDelegate<LLM_SetSSLDelegate>(libraryHandle, "LLM_SetSSL");
            LLM_Tokenize = LibraryLoader.GetSymbolDelegate<LLM_TokenizeDelegate>(libraryHandle, "LLM_Tokenize");
            LLM_Detokenize = LibraryLoader.GetSymbolDelegate<LLM_DetokenizeDelegate>(libraryHandle, "LLM_Detokenize");
            LLM_Embeddings = LibraryLoader.GetSymbolDelegate<LLM_EmbeddingsDelegate>(libraryHandle, "LLM_Embeddings");
            LLM_Lora_Weight = LibraryLoader.GetSymbolDelegate<LLM_LoraWeightDelegate>(libraryHandle, "LLM_Lora_Weight");
            LLM_LoraList = LibraryLoader.GetSymbolDelegate<LLM_LoraListDelegate>(libraryHandle, "LLM_Lora_List");
            LLM_Completion = LibraryLoader.GetSymbolDelegate<LLM_CompletionDelegate>(libraryHandle, "LLM_Completion");
            LLM_Slot = LibraryLoader.GetSymbolDelegate<LLM_SlotDelegate>(libraryHandle, "LLM_Slot");
            LLM_Cancel = LibraryLoader.GetSymbolDelegate<LLM_CancelDelegate>(libraryHandle, "LLM_Cancel");
            LLM_Status = LibraryLoader.GetSymbolDelegate<LLM_StatusDelegate>(libraryHandle, "LLM_Status");
            StringWrapper_Construct = LibraryLoader.GetSymbolDelegate<StringWrapper_ConstructDelegate>(libraryHandle, "StringWrapper_Construct");
            StringWrapper_Delete = LibraryLoader.GetSymbolDelegate<StringWrapper_DeleteDelegate>(libraryHandle, "StringWrapper_Delete");
            StringWrapper_GetStringSize = LibraryLoader.GetSymbolDelegate<StringWrapper_GetStringSizeDelegate>(libraryHandle, "StringWrapper_GetStringSize");
            StringWrapper_GetString = LibraryLoader.GetSymbolDelegate<StringWrapper_GetStringDelegate>(libraryHandle, "StringWrapper_GetString");
            Logging = LibraryLoader.GetSymbolDelegate<LoggingDelegate>(libraryHandle, "Logging");
            StopLogging = LibraryLoader.GetSymbolDelegate<StopLoggingDelegate>(libraryHandle, "StopLogging");
        }
        catch
        {
            // The caller never receives this instance, so unload the library here instead of leaking it.
            LibraryLoader.FreeLibrary(libraryHandle);
            libraryHandle = IntPtr.Zero;
            throw;
        }
    }

    /// <summary>
    /// Unloads the native library. Safe to call more than once.
    /// </summary>
    public void Destroy()
    {
        if (libraryHandle == IntPtr.Zero) return;
        LibraryLoader.FreeLibrary(libraryHandle);
        libraryHandle = IntPtr.Zero;
    }

    /// <summary>
    /// Lists the library architectures that may be used on the current platform, in order of preference.
    /// </summary>
    /// <param name="gpu">whether GPU architectures should be included (Windows/Linux only)</param>
    /// <returns>architecture names</returns>
    /// <exception cref="Exception">the current OS is not supported</exception>
    public static List<string> PossibleArchitectures(bool gpu = false)
    {
        List<string> architectures = new List<string>();
        if (Application.platform == RuntimePlatform.WindowsEditor || Application.platform == RuntimePlatform.WindowsPlayer ||
            Application.platform == RuntimePlatform.LinuxEditor || Application.platform == RuntimePlatform.LinuxPlayer)
        {
            if (gpu)
            {
                if (LLMUnitySetup.FullLlamaLib)
                {
                    architectures.Add("cuda-cu12.2.0-full");
                    architectures.Add("cuda-cu11.7.1-full");
                    architectures.Add("hip-full");
                }
                else
                {
                    architectures.Add("cuda-cu12.2.0");
                    architectures.Add("cuda-cu11.7.1");
                    architectures.Add("hip");
                }
                architectures.Add("vulkan");
            }
            // CPU fallbacks, most capable instruction set first; flags come from the static constructor.
            if (has_avx512) architectures.Add("avx512");
            if (has_avx2) architectures.Add("avx2");
            if (has_avx) architectures.Add("avx");
            architectures.Add("noavx");
        }
        else if (Application.platform == RuntimePlatform.OSXEditor || Application.platform == RuntimePlatform.OSXPlayer)
        {
            // Culture-invariant lowering so the comparison does not depend on the user's locale.
            string arch = RuntimeInformation.ProcessArchitecture.ToString().ToLowerInvariant();
            if (arch.Contains("arm"))
            {
                architectures.Add("arm64-acc");
                architectures.Add("arm64-no_acc");
            }
            else
            {
                if (arch != "x86" && arch != "x64") LLMUnitySetup.LogWarning($"Unknown architecture of processor {arch}! Falling back to x86_64");
                architectures.Add("x64-acc");
                architectures.Add("x64-no_acc");
            }
        }
        else if (Application.platform == RuntimePlatform.Android)
        {
            architectures.Add("android");
        }
        else
        {
            string error = "Unknown OS";
            LLMUnitySetup.LogError(error);
            throw new Exception(error);
        }
        return architectures;
    }

    /// <summary>
    /// Gets the path of the architecture checker library.
    /// </summary>
    /// <returns>the path on Windows and Linux, otherwise null</returns>
    public static string GetArchitectureCheckerPath()
    {
        string filename;
        if (Application.platform == RuntimePlatform.WindowsEditor || Application.platform == RuntimePlatform.WindowsPlayer)
        {
            filename = "windows-archchecker/archchecker.dll";
        }
        else if (Application.platform == RuntimePlatform.LinuxEditor || Application.platform == RuntimePlatform.LinuxPlayer)
        {
            filename = "linux-archchecker/libarchchecker.so";
        }
        else
        {
            return null;
        }
        return Path.Combine(LLMUnitySetup.libraryPath, filename);
    }

    /// <summary>
    /// Gets the path of the native library for an architecture on the current platform.
    /// </summary>
    /// <param name="arch">architecture name</param>
    /// <returns>the library path; on Android, the bare library name</returns>
    /// <exception cref="Exception">the current OS is not supported</exception>
    public static string GetArchitecturePath(string arch)
    {
        string filename;
        if (Application.platform == RuntimePlatform.WindowsEditor || Application.platform == RuntimePlatform.WindowsPlayer)
        {
            filename = $"windows-{arch}/undreamai_windows-{arch}.dll";
        }
        else if (Application.platform == RuntimePlatform.LinuxEditor || Application.platform == RuntimePlatform.LinuxPlayer)
        {
            filename = $"linux-{arch}/libundreamai_linux-{arch}.so";
        }
        else if (Application.platform == RuntimePlatform.OSXEditor || Application.platform == RuntimePlatform.OSXPlayer)
        {
            filename = $"macos-{arch}/libundreamai_macos-{arch}.dylib";
        }
        else if (Application.platform == RuntimePlatform.Android)
        {
            return "libundreamai_android.so";
        }
        else
        {
            string error = "Unknown OS";
            LLMUnitySetup.LogError(error);
            throw new Exception(error);
        }
        return Path.Combine(LLMUnitySetup.libraryPath, filename);
    }

    /// <summary>
    /// Reads the full content of a native string wrapper without clearing it.
    /// </summary>
    /// <param name="stringWrapper">native string wrapper handle</param>
    /// <returns>the content, or "" if it is empty</returns>
    public string GetStringWrapperResult(IntPtr stringWrapper)
    {
        string result = "";
        int bufferSize = StringWrapper_GetStringSize(stringWrapper);
        if (bufferSize > 1)
        {
            IntPtr buffer = Marshal.AllocHGlobal(bufferSize);
            try
            {
                StringWrapper_GetString(stringWrapper, buffer, bufferSize);
                result = Marshal.PtrToStringAnsi(buffer);
            }
            finally
            {
                Marshal.FreeHGlobal(buffer);
            }
        }
        return result;
    }

    // Boolean returns are marshalled as a single byte: a native 1-byte bool only defines the low
    // byte of the return register, so the default 4-byte BOOL marshalling could read garbage as true.
    // A native int returning 0/1 is also read correctly this way.
    [return: MarshalAs(UnmanagedType.I1)]
    public delegate bool HasArchDelegate();
    public delegate void LoggingDelegate(IntPtr stringWrapper);
    public delegate void StopLoggingDelegate();
    public delegate IntPtr LLM_ConstructDelegate(string command);
    public delegate void LLM_DeleteDelegate(IntPtr LLMObject);
    public delegate void LLM_StartServerDelegate(IntPtr LLMObject);
    public delegate void LLM_StopServerDelegate(IntPtr LLMObject);
    public delegate void LLM_StartDelegate(IntPtr LLMObject);
    [return: MarshalAs(UnmanagedType.I1)]
    public delegate bool LLM_StartedDelegate(IntPtr LLMObject);
    public delegate void LLM_StopDelegate(IntPtr LLMObject);
    public delegate void LLM_SetTemplateDelegate(IntPtr LLMObject, string chatTemplate);
    public delegate void LLM_SetSSLDelegate(IntPtr LLMObject, string SSLCert, string SSLKey);
    public delegate void LLM_TokenizeDelegate(IntPtr LLMObject, string jsonData, IntPtr stringWrapper);
    public delegate void LLM_DetokenizeDelegate(IntPtr LLMObject, string jsonData, IntPtr stringWrapper);
    public delegate void LLM_EmbeddingsDelegate(IntPtr LLMObject, string jsonData, IntPtr stringWrapper);
    public delegate void LLM_LoraWeightDelegate(IntPtr LLMObject, string jsonData, IntPtr stringWrapper);
    public delegate void LLM_LoraListDelegate(IntPtr LLMObject, IntPtr stringWrapper);
    public delegate void LLM_CompletionDelegate(IntPtr LLMObject, string jsonData, IntPtr stringWrapper);
    public delegate void LLM_SlotDelegate(IntPtr LLMObject, string jsonData, IntPtr stringWrapper);
    public delegate void LLM_CancelDelegate(IntPtr LLMObject, int idSlot);
    public delegate int LLM_StatusDelegate(IntPtr LLMObject, IntPtr stringWrapper);
    public delegate IntPtr StringWrapper_ConstructDelegate();
    public delegate void StringWrapper_DeleteDelegate(IntPtr instance);
    public delegate int StringWrapper_GetStringSizeDelegate(IntPtr instance);
    public delegate void StringWrapper_GetStringDelegate(IntPtr instance, IntPtr buffer, int bufferSize, bool clear = false);

    public LoggingDelegate Logging;
    public StopLoggingDelegate StopLogging;
    public LLM_ConstructDelegate LLM_Construct;
    public LLM_DeleteDelegate LLM_Delete;
    public LLM_StartServerDelegate LLM_StartServer;
    public LLM_StopServerDelegate LLM_StopServer;
    public LLM_StartDelegate LLM_Start;
    public LLM_StartedDelegate LLM_Started;
    public LLM_StopDelegate LLM_Stop;
    public LLM_SetTemplateDelegate LLM_SetTemplate;
    public LLM_SetSSLDelegate LLM_SetSSL;
    public LLM_TokenizeDelegate LLM_Tokenize;
    public LLM_DetokenizeDelegate LLM_Detokenize;
    public LLM_CompletionDelegate LLM_Completion;
    public LLM_EmbeddingsDelegate LLM_Embeddings;
    public LLM_LoraWeightDelegate LLM_Lora_Weight;
    public LLM_LoraListDelegate LLM_LoraList;
    public LLM_SlotDelegate LLM_Slot;
    public LLM_CancelDelegate LLM_Cancel;
    public LLM_StatusDelegate LLM_Status;
    public StringWrapper_ConstructDelegate StringWrapper_Construct;
    public StringWrapper_DeleteDelegate StringWrapper_Delete;
    public StringWrapper_GetStringSizeDelegate StringWrapper_GetStringSize;
    public StringWrapper_GetStringDelegate StringWrapper_GetString;
}
602}