5using System.Collections.Generic;
7using System.Runtime.InteropServices;
/// <summary>
/// Bridges a native string buffer (created through <c>LLMLib.StringWrapper_Construct</c>) to a managed
/// callback: the buffer content is read on demand and forwarded to the callback when it changes.
/// </summary>
16 public class StreamWrapper
// Receiver of the stream content.
19 Callback<string> callback;
// Last string read from the native buffer; reused while the reported buffer size is unchanged.
21 string previousString =
"";
// Last string handed to the callback; used to suppress repeated notifications.
22 string previousCalledString =
"";
// Buffer size reported by the native side on the previous GetString call.
23 int previousBufferSize = 0;
/// <summary>
/// Stores the callback and the clear-on-update flag and asks the library for a native string wrapper.
/// When <paramref name="llmlib"/> is null the handle is left at its default value (IntPtr.Zero).
/// </summary>
/// <param name="llmlib">library bindings used for the native calls</param>
/// <param name="callback">invoked with the stream content when it changes</param>
/// <param name="clearOnUpdate">value passed as <c>clear</c> to GetString when the stream is polled</param>
26 public StreamWrapper(LLMLib llmlib, Callback<string> callback,
bool clearOnUpdate =
false)
29 this.callback = callback;
30 this.clearOnUpdate = clearOnUpdate;
31 stringWrapper = (llmlib?.StringWrapper_Construct()).GetValueOrDefault();
/// <summary>
/// Reads the current content of the native string buffer.
/// </summary>
/// <param name="clear">forwarded to <c>StringWrapper_GetString</c>; presumably asks the native side to empty the buffer after the copy (confirm against the native implementation)</param>
/// <returns>the buffer content as a managed string</returns>
39 public string GetString(
bool clear =
false)
// Size reported by the native side; 0 when no library is attached.
42 int bufferSize = (llmlib?.StringWrapper_GetStringSize(stringWrapper)).GetValueOrDefault();
// NOTE(review): the branch preceding this "else if" is not visible in this view.
// The native buffer is copied only when its reported size differs from the previous call.
47 else if (previousBufferSize != bufferSize)
// Unmanaged scratch buffer sized to the reported length.
49 IntPtr buffer = Marshal.AllocHGlobal(bufferSize);
52 llmlib?.StringWrapper_GetString(stringWrapper, buffer, bufferSize, clear);
// Converts the copied bytes with the ANSI string marshaller.
53 result = Marshal.PtrToStringAnsi(buffer);
// Release the scratch buffer.
57 Marshal.FreeHGlobal(buffer);
// Remember the value for calls where the size has not changed.
59 previousString = result;
// Size unchanged (presumably the trailing else branch): reuse the remembered string.
// NOTE(review): a different string of identical size would not be re-read; confirm this is intended.
63 result = previousString;
65 previousBufferSize = bufferSize;
// NOTE(review): the header of the method containing the following statements is not visible in this view.
// Nothing to poll when no native wrapper was created.
74 if (stringWrapper == IntPtr.Zero)
return;
75 string result = GetString(clearOnUpdate);
// Notify only for non-empty content that differs from what was last delivered.
76 if (result !=
"" && previousCalledString != result)
78 callback?.Invoke(result);
79 previousCalledString = result;
/// <summary>
/// Accessor for the native string wrapper handle (body not visible in this view; presumably returns <c>stringWrapper</c>).
/// </summary>
87 public IntPtr GetStringWrapper()
// Deletes the native string wrapper if one was created (enclosing method header not visible in this view).
97 if (stringWrapper != IntPtr.Zero) llmlib?.StringWrapper_Delete(stringWrapper);
/// <summary>
/// Static helpers for loading native libraries, resolving their exported symbols and releasing them.
/// Each operation dispatches on <c>Application.platform</c> to the Win32, Linux, Mac or Android bindings.
/// </summary>
107 static class LibraryLoader
/// <summary>
/// Resolves an exported symbol and wraps it in a managed delegate.
/// </summary>
/// <typeparam name="T">delegate type matching the native function signature</typeparam>
/// <param name="library">handle of the loaded library</param>
/// <param name="name">name of the exported symbol</param>
/// <returns>delegate bound to the native function</returns>
/// <exception cref="EntryPointNotFoundException">the symbol lookup returned IntPtr.Zero</exception>
public static T GetSymbolDelegate<T>(IntPtr library, string name) where T : Delegate
{
    IntPtr functionPointer = GetSymbol(library, name);
    if (functionPointer != IntPtr.Zero)
    {
        return Marshal.GetDelegateForFunctionPointer<T>(functionPointer);
    }

    throw new EntryPointNotFoundException($"Unable to load symbol '{name}'.");
}
/// <summary>
/// Loads a native library with the loader of the current platform.
/// </summary>
/// <param name="libraryName">name or path of the library to load</param>
/// <returns>handle returned by the platform loader (callers in this file treat IntPtr.Zero as failure)</returns>
/// <exception cref="ArgumentNullException"><paramref name="libraryName"/> is null or empty</exception>
/// <exception cref="PlatformNotSupportedException">the current platform is not Windows, Linux, macOS or Android</exception>
public static IntPtr LoadLibrary(string libraryName)
{
    if (string.IsNullOrEmpty(libraryName))
    {
        throw new ArgumentNullException(nameof(libraryName));
    }

    switch (Application.platform)
    {
        case RuntimePlatform.WindowsEditor:
        case RuntimePlatform.WindowsPlayer:
            return Win32.LoadLibrary(libraryName);

        case RuntimePlatform.LinuxEditor:
        case RuntimePlatform.LinuxPlayer:
            return Linux.dlopen(libraryName);

        case RuntimePlatform.OSXEditor:
        case RuntimePlatform.OSXPlayer:
            return Mac.dlopen(libraryName);

        case RuntimePlatform.Android:
            return Android.dlopen(libraryName);

        default:
            throw new PlatformNotSupportedException($"Current platform is unknown, unable to load library '{libraryName}'.");
    }
}
/// <summary>
/// Looks up an exported symbol in a loaded library with the loader of the current platform.
/// </summary>
/// <param name="library">handle of the loaded library</param>
/// <param name="symbolName">name of the exported symbol</param>
/// <returns>address returned by the platform loader (GetSymbolDelegate treats IntPtr.Zero as "not found")</returns>
/// <exception cref="ArgumentNullException"><paramref name="symbolName"/> is null or empty</exception>
/// <exception cref="PlatformNotSupportedException">the current platform is not Windows, Linux, macOS or Android</exception>
public static IntPtr GetSymbol(IntPtr library, string symbolName)
{
    if (string.IsNullOrEmpty(symbolName))
    {
        throw new ArgumentNullException(nameof(symbolName));
    }

    switch (Application.platform)
    {
        case RuntimePlatform.WindowsEditor:
        case RuntimePlatform.WindowsPlayer:
            return Win32.GetProcAddress(library, symbolName);

        case RuntimePlatform.LinuxEditor:
        case RuntimePlatform.LinuxPlayer:
            return Linux.dlsym(library, symbolName);

        case RuntimePlatform.OSXEditor:
        case RuntimePlatform.OSXPlayer:
            return Mac.dlsym(library, symbolName);

        case RuntimePlatform.Android:
            return Android.dlsym(library, symbolName);

        default:
            throw new PlatformNotSupportedException($"Current platform is unknown, unable to load symbol '{symbolName}' from library {library}.");
    }
}
/// <summary>
/// Releases a library handle with the loader of the current platform. A zero handle is ignored.
/// </summary>
/// <param name="library">handle previously returned by LoadLibrary</param>
/// <exception cref="PlatformNotSupportedException">the current platform is not Windows, Linux, macOS or Android</exception>
public static void FreeLibrary(IntPtr library)
{
    if (library == IntPtr.Zero)
    {
        return;
    }

    switch (Application.platform)
    {
        case RuntimePlatform.WindowsEditor:
        case RuntimePlatform.WindowsPlayer:
            Win32.FreeLibrary(library);
            break;

        case RuntimePlatform.LinuxEditor:
        case RuntimePlatform.LinuxPlayer:
            Linux.dlclose(library);
            break;

        case RuntimePlatform.OSXEditor:
        case RuntimePlatform.OSXPlayer:
            Mac.dlclose(library);
            break;

        case RuntimePlatform.Android:
            Android.dlclose(library);
            break;

        default:
            throw new PlatformNotSupportedException($"Current platform is unknown, unable to close library '{library}'.");
    }
}
/// <summary>
/// Bindings to the dynamic loader functions exported by "/usr/lib/libSystem.dylib".
/// </summary>
private static class Mac
{
    private const string SystemLibrary = "/usr/lib/libSystem.dylib";

    // Mode values passed to dlopen.
    private const int RTLD_LAZY = 1;
    private const int RTLD_NOW = 2;

    [DllImport(SystemLibrary)]
    public static extern IntPtr dlopen(string path, int mode);

    [DllImport(SystemLibrary)]
    public static extern IntPtr dlsym(IntPtr handle, string symbol);

    [DllImport(SystemLibrary)]
    public static extern void dlclose(IntPtr handle);

    /// <summary>
    /// Opens a library, selecting RTLD_LAZY or RTLD_NOW from a flag.
    /// </summary>
    /// <param name="path">path of the library</param>
    /// <param name="lazy">true to pass RTLD_LAZY, false to pass RTLD_NOW</param>
    /// <returns>handle returned by the native dlopen</returns>
    public static IntPtr dlopen(string path, bool lazy = true)
    {
        int mode = RTLD_NOW;
        if (lazy)
        {
            mode = RTLD_LAZY;
        }

        return dlopen(path, mode);
    }
}
/// <summary>
/// Bindings to the Linux dynamic loader. "libdl.so.2" is tried first; "libdl.so" is used as a
/// fallback once the former could not be found.
/// </summary>
private static class Linux
{
    private const string SystemLibrary = "libdl.so";
    private const string SystemLibrary2 = "libdl.so.2";

    // Mode values passed to dlopen.
    private const int RTLD_LAZY = 1;
    private const int RTLD_NOW = 2;

    // Selects which of the two libdl names is used; cleared after "libdl.so.2" failed to resolve.
    private static bool UseSystemLibrary2 = true;

    /// <summary>
    /// Opens a library, selecting RTLD_LAZY or RTLD_NOW from a flag.
    /// </summary>
    /// <param name="path">path of the library</param>
    /// <param name="lazy">true to pass RTLD_LAZY, false to pass RTLD_NOW</param>
    /// <returns>handle returned by the native dlopen</returns>
    public static IntPtr dlopen(string path, bool lazy = true)
    {
        int mode = lazy ? RTLD_LAZY : RTLD_NOW;

        // Only probe "libdl.so.2" while it has not been found missing; retrying it on every call
        // would raise and swallow a DllNotFoundException each time.
        if (UseSystemLibrary2)
        {
            try
            {
                return dlopen2(path, mode);
            }
            catch (DllNotFoundException)
            {
                UseSystemLibrary2 = false;
            }
        }

        return dlopen1(path, mode);
    }

    /// <summary>Looks up a symbol through whichever libdl binding is in use.</summary>
    public static IntPtr dlsym(IntPtr handle, string symbol)
    {
        return UseSystemLibrary2 ? dlsym2(handle, symbol) : dlsym1(handle, symbol);
    }

    /// <summary>Closes a handle through whichever libdl binding is in use.</summary>
    public static void dlclose(IntPtr handle)
    {
        if (UseSystemLibrary2)
        {
            dlclose2(handle);
        }
        else
        {
            dlclose1(handle);
        }
    }

    [DllImport(SystemLibrary, EntryPoint = "dlopen")]
    private static extern IntPtr dlopen1(string path, int mode);

    [DllImport(SystemLibrary, EntryPoint = "dlsym")]
    private static extern IntPtr dlsym1(IntPtr handle, string symbol);

    [DllImport(SystemLibrary, EntryPoint = "dlclose")]
    private static extern void dlclose1(IntPtr handle);

    [DllImport(SystemLibrary2, EntryPoint = "dlopen")]
    private static extern IntPtr dlopen2(string path, int mode);

    [DllImport(SystemLibrary2, EntryPoint = "dlsym")]
    private static extern IntPtr dlsym2(IntPtr handle, string symbol);

    [DllImport(SystemLibrary2, EntryPoint = "dlclose")]
    private static extern void dlclose2(IntPtr handle);
}
/// <summary>
/// Bindings to the Windows loader functions exported by Kernel32.dll.
/// </summary>
private static class Win32
{
    private const string SystemLibrary = "Kernel32.dll";

    /// <summary>
    /// Loads a module. Bound to the wide-character entry point so that paths containing characters
    /// outside the ANSI code page (for example non-Latin user or folder names) are passed intact.
    /// </summary>
    [DllImport(SystemLibrary, EntryPoint = "LoadLibraryW", ExactSpelling = true, SetLastError = true, CharSet = CharSet.Unicode)]
    public static extern IntPtr LoadLibrary(string lpFileName);

    /// <summary>
    /// Resolves an exported symbol. GetProcAddress only exists in an ANSI form, so the name stays ANSI.
    /// </summary>
    [DllImport(SystemLibrary, ExactSpelling = true, SetLastError = true, CharSet = CharSet.Ansi)]
    public static extern IntPtr GetProcAddress(IntPtr hModule, string lpProcName);

    /// <summary>Releases a module handle obtained from LoadLibrary.</summary>
    [DllImport(SystemLibrary, SetLastError = true, CharSet = CharSet.Ansi)]
    public static extern void FreeLibrary(IntPtr hModule);
}
/// <summary>
/// Dynamic loader bindings used when running on Android.
/// </summary>
286 private static class Android
// Convenience overload: opens the library with a flags value of 1.
288 public static IntPtr dlopen(
string path) => dlopen(path, 1);
// Extern bindings resolved from "__Internal".
291 [DllImport(
"__Internal")]
292 public static extern IntPtr dlopen(
string filename,
int flags);
294 [DllImport(
"__Internal")]
295 public static extern IntPtr dlsym(IntPtr handle,
string symbol);
297 [DllImport(
"__Internal")]
298 public static extern int dlclose(IntPtr handle);
// NOTE(review): the three members below repeat the signatures of the extern bindings above and their
// bodies are not visible in this view. Presumably the two sets are separated by conditional
// compilation (extern bindings for Android builds, managed stand-ins elsewhere); confirm.
300 public static IntPtr dlopen(
string filename,
int flags)
305 public static IntPtr dlsym(IntPtr handle,
string symbol)
310 public static int dlclose(IntPtr handle)
// Handle of the loaded native library; IntPtr.Zero until a library has been loaded.
325 IntPtr libraryHandle = IntPtr.Zero;
// Lock object; its use is not visible in this view (presumably guards the one-time CPU feature detection; confirm).
326 static readonly
object staticLock =
new object();
// CPU feature flags read from the architecture-checker library; each starts as false.
327 static bool has_avx =
false;
328 static bool has_avx2 =
false;
329 static bool has_avx512 =
false;
// When true, CPU feature detection returns immediately. NOTE(review): the assignment that sets it is not visible in this view.
330 static bool has_avx_set =
false;
// NOTE(review): the header of the member containing these statements is not visible in this view.
// Skip the probe when CPU features were already detected.
336 if (has_avx_set)
return;
337 string archCheckerPath = GetArchitectureCheckerPath();
// A null path skips detection entirely.
338 if (archCheckerPath !=
null)
340 IntPtr archCheckerHandle = LibraryLoader.LoadLibrary(archCheckerPath);
// A zero handle is reported through the error log.
341 if (archCheckerHandle == IntPtr.Zero)
343 LLMUnitySetup.LogError($
"Failed to load library {archCheckerPath}.");
// Each flag is obtained by resolving the exported function of the same name and invoking it immediately.
349 has_avx = LibraryLoader.GetSymbolDelegate<HasArchDelegate>(archCheckerHandle,
"has_avx")();
350 has_avx2 = LibraryLoader.GetSymbolDelegate<HasArchDelegate>(archCheckerHandle,
"has_avx2")();
351 has_avx512 = LibraryLoader.GetSymbolDelegate<HasArchDelegate>(archCheckerHandle,
"has_avx512")();
// NOTE(review): if one of the lookups above throws, this release is presumably skipped (the surrounding
// try/catch is not visible in this view); consider moving it to a finally block.
352 LibraryLoader.FreeLibrary(archCheckerHandle);
// The exception type and message are written to the error log (the catch clause declaring "e" is not visible in this view).
356 LLMUnitySetup.LogError($
"{e.GetType()}: {e.Message}");
/// <summary>
/// Loads the native library built for the given architecture and binds its exported functions
/// to the delegate fields of this instance.
/// </summary>
/// <param name="arch">architecture identifier, resolved to a library path by GetArchitecturePath</param>
/// <exception cref="Exception">the library could not be loaded</exception>
/// <exception cref="EntryPointNotFoundException">an expected symbol is missing; the library is released before the exception propagates</exception>
public LLMLib(string arch)
{
    libraryHandle = LibraryLoader.LoadLibrary(GetArchitecturePath(arch));
    if (libraryHandle == IntPtr.Zero)
    {
        throw new Exception($"Failed to load library {arch}.");
    }

    // Binds one exported function of the loaded library to a delegate of type T.
    T Bind<T>(string symbol) where T : Delegate
    {
        return LibraryLoader.GetSymbolDelegate<T>(libraryHandle, symbol);
    }

    try
    {
        LLM_Construct = Bind<LLM_ConstructDelegate>("LLM_Construct");
        LLM_Delete = Bind<LLM_DeleteDelegate>("LLM_Delete");
        LLM_StartServer = Bind<LLM_StartServerDelegate>("LLM_StartServer");
        LLM_StopServer = Bind<LLM_StopServerDelegate>("LLM_StopServer");
        LLM_Start = Bind<LLM_StartDelegate>("LLM_Start");
        LLM_Started = Bind<LLM_StartedDelegate>("LLM_Started");
        LLM_Stop = Bind<LLM_StopDelegate>("LLM_Stop");
        LLM_SetTemplate = Bind<LLM_SetTemplateDelegate>("LLM_SetTemplate");
        LLM_SetSSL = Bind<LLM_SetSSLDelegate>("LLM_SetSSL");
        LLM_Tokenize = Bind<LLM_TokenizeDelegate>("LLM_Tokenize");
        LLM_Detokenize = Bind<LLM_DetokenizeDelegate>("LLM_Detokenize");
        LLM_Embeddings = Bind<LLM_EmbeddingsDelegate>("LLM_Embeddings");
        LLM_Lora_Weight = Bind<LLM_LoraWeightDelegate>("LLM_Lora_Weight");
        // The field is named LLM_LoraList but the exported symbol is "LLM_Lora_List".
        LLM_LoraList = Bind<LLM_LoraListDelegate>("LLM_Lora_List");
        LLM_Completion = Bind<LLM_CompletionDelegate>("LLM_Completion");
        LLM_Slot = Bind<LLM_SlotDelegate>("LLM_Slot");
        LLM_Cancel = Bind<LLM_CancelDelegate>("LLM_Cancel");
        LLM_Status = Bind<LLM_StatusDelegate>("LLM_Status");
        StringWrapper_Construct = Bind<StringWrapper_ConstructDelegate>("StringWrapper_Construct");
        StringWrapper_Delete = Bind<StringWrapper_DeleteDelegate>("StringWrapper_Delete");
        StringWrapper_GetStringSize = Bind<StringWrapper_GetStringSizeDelegate>("StringWrapper_GetStringSize");
        StringWrapper_GetString = Bind<StringWrapper_GetStringDelegate>("StringWrapper_GetString");
        Logging = Bind<LoggingDelegate>("Logging");
        StopLogging = Bind<StopLoggingDelegate>("StopLogging");
    }
    catch
    {
        // The caller never receives an instance when the constructor throws, so Destroy() cannot be
        // called; release the library here to avoid leaking the handle.
        IntPtr handle = libraryHandle;
        libraryHandle = IntPtr.Zero;
        LibraryLoader.FreeLibrary(handle);
        throw;
    }
}
/// <summary>
/// Releases the native library held by this instance. Safe to call more than once:
/// the handle is cleared before it is freed, so later calls do nothing.
/// </summary>
public void Destroy()
{
    if (libraryHandle == IntPtr.Zero)
    {
        return;
    }

    // Clear the field first so the same handle is never passed to the loader twice.
    IntPtr handle = libraryHandle;
    libraryHandle = IntPtr.Zero;
    LibraryLoader.FreeLibrary(handle);
}
/// <summary>
/// Builds the list of library architecture identifiers applicable to the current platform,
/// in the order in which they are added below.
/// </summary>
/// <param name="gpu">NOTE(review): its use is not visible in this view; presumably gates the CUDA / HIP / Vulkan entries (confirm)</param>
/// <returns>list of architecture identifiers</returns>
/// <exception cref="Exception">thrown with the message "Unknown OS" (presumably when no platform branch matches)</exception>
416 public static List<string> PossibleArchitectures(
bool gpu =
false)
418 List<string> architectures =
new List<string>();
// Windows and Linux share the same set of builds.
419 if (Application.platform == RuntimePlatform.WindowsEditor || Application.platform == RuntimePlatform.WindowsPlayer ||
420 Application.platform == RuntimePlatform.LinuxEditor || Application.platform == RuntimePlatform.LinuxPlayer)
// GPU builds: the "-full" variants are listed when LLMUnitySetup.FullLlamaLib is set.
424 if (LLMUnitySetup.FullLlamaLib)
426 architectures.Add(
"cuda-cu12.2.0-full");
427 architectures.Add(
"cuda-cu11.7.1-full");
428 architectures.Add(
"hip-full");
// Presumably the else branch of the FullLlamaLib check: the regular GPU variants.
432 architectures.Add(
"cuda-cu12.2.0");
433 architectures.Add(
"cuda-cu11.7.1");
434 architectures.Add(
"hip");
436 architectures.Add(
"vulkan");
// CPU builds: one entry per detected instruction set, then "noavx" unconditionally.
438 if (has_avx512) architectures.Add(
"avx512");
439 if (has_avx2) architectures.Add(
"avx2");
440 if (has_avx) architectures.Add(
"avx");
441 architectures.Add(
"noavx");
// macOS: the entries depend on the process architecture.
443 else if (Application.platform == RuntimePlatform.OSXEditor || Application.platform == RuntimePlatform.OSXPlayer)
445 string arch = RuntimeInformation.ProcessArchitecture.ToString().ToLower();
446 if (arch.Contains(
"arm"))
448 architectures.Add(
"arm64-acc");
449 architectures.Add(
"arm64-no_acc");
// Anything that is neither "x86" nor "x64" is logged as a warning and treated like x64.
453 if (arch !=
"x86" && arch !=
"x64") LLMUnitySetup.LogWarning($
"Unknown architecture of processor {arch}! Falling back to x86_64");
454 architectures.Add(
"x64-acc");
455 architectures.Add(
"x64-no_acc");
// Android has a single build.
458 else if (Application.platform == RuntimePlatform.Android)
460 architectures.Add(
"android");
// The error is logged and then thrown.
464 string error =
"Unknown OS";
465 LLMUnitySetup.LogError(error);
466 throw new Exception(error);
468 return architectures;
/// <summary>
/// Gets the path of the architecture-checker library for the current platform.
/// </summary>
/// <returns>
/// The checker file name combined with <c>LLMUnitySetup.libraryPath</c> on Windows and Linux.
/// NOTE(review): the branch for other platforms is not visible in this view; callers compare the
/// result with null, so it presumably returns null there (confirm).
/// </returns>
475 public static string GetArchitectureCheckerPath()
478 if (Application.platform == RuntimePlatform.WindowsEditor || Application.platform == RuntimePlatform.WindowsPlayer)
480 filename = $
"windows-archchecker/archchecker.dll";
482 else if (Application.platform == RuntimePlatform.LinuxEditor || Application.platform == RuntimePlatform.LinuxPlayer)
484 filename = $
"linux-archchecker/libarchchecker.so";
490 return Path.Combine(LLMUnitySetup.libraryPath, filename);
/// <summary>
/// Gets the path of the native library for the given architecture on the current platform.
/// </summary>
/// <param name="arch">architecture identifier inserted into the folder and file name</param>
/// <returns>
/// On Windows, Linux and macOS the platform-specific relative path combined with
/// <c>LLMUnitySetup.libraryPath</c>; on Android the bare file name "libundreamai_android.so".
/// </returns>
/// <exception cref="Exception">the current platform is none of the above (message "Unknown OS", also logged)</exception>
public static string GetArchitecturePath(string arch)
{
    string relativePath;
    switch (Application.platform)
    {
        case RuntimePlatform.WindowsEditor:
        case RuntimePlatform.WindowsPlayer:
            relativePath = $"windows-{arch}/undreamai_windows-{arch}.dll";
            break;

        case RuntimePlatform.LinuxEditor:
        case RuntimePlatform.LinuxPlayer:
            relativePath = $"linux-{arch}/libundreamai_linux-{arch}.so";
            break;

        case RuntimePlatform.OSXEditor:
        case RuntimePlatform.OSXPlayer:
            relativePath = $"macos-{arch}/libundreamai_macos-{arch}.dylib";
            break;

        case RuntimePlatform.Android:
            // Returned as a bare file name, without the library folder.
            return "libundreamai_android.so";

        default:
        {
            string error = "Unknown OS";
            LLMUnitySetup.LogError(error);
            throw new Exception(error);
        }
    }

    return Path.Combine(LLMUnitySetup.libraryPath, relativePath);
}
/// <summary>
/// Copies the content of a native string wrapper into a managed string.
/// </summary>
/// <param name="stringWrapper">handle of the native string wrapper to read</param>
/// <returns>the wrapper content as a managed string</returns>
531 public string GetStringWrapperResult(IntPtr stringWrapper)
// Size reported by the native side for the wrapper content.
534 int bufferSize = StringWrapper_GetStringSize(stringWrapper);
// NOTE(review): any guard on bufferSize before this allocation is not visible in this view.
// Unmanaged scratch buffer sized to the reported length.
537 IntPtr buffer = Marshal.AllocHGlobal(bufferSize);
// The optional "clear" argument is omitted, so its default (false) applies.
540 StringWrapper_GetString(stringWrapper, buffer, bufferSize);
// Converts the copied bytes with the ANSI string marshaller.
541 result = Marshal.PtrToStringAnsi(buffer);
// Release the scratch buffer.
545 Marshal.FreeHGlobal(buffer);
// Delegate types describing the signatures of the native functions that are resolved by name
// and stored in the delegate fields of this class.

// Signature of the CPU feature probes ("has_avx", "has_avx2", "has_avx512").
// NOTE(review): bool return values use the default marshalling; if the native functions return a
// 1-byte C/C++ bool, [return: MarshalAs(UnmanagedType.I1)] may be needed; confirm against the native headers.
551 public delegate
bool HasArchDelegate();
// Logging entry points; Logging receives a string wrapper handle.
552 public delegate
void LoggingDelegate(IntPtr stringWrapper);
553 public delegate
void StopLoggingDelegate();
// Creation and deletion of the native LLM object; the IntPtr returned here is presumably the
// LLMObject handle taken by the other LLM_* functions (confirm).
554 public delegate IntPtr LLM_ConstructDelegate(
string command);
555 public delegate
void LLM_DeleteDelegate(IntPtr LLMObject);
// Server and service lifecycle calls on an LLM object.
556 public delegate
void LLM_StartServerDelegate(IntPtr LLMObject);
557 public delegate
void LLM_StopServerDelegate(IntPtr LLMObject);
558 public delegate
void LLM_StartDelegate(IntPtr LLMObject);
559 public delegate
bool LLM_StartedDelegate(IntPtr LLMObject);
560 public delegate
void LLM_StopDelegate(IntPtr LLMObject);
// Configuration calls taking string arguments.
561 public delegate
void LLM_SetTemplateDelegate(IntPtr LLMObject,
string chatTemplate);
562 public delegate
void LLM_SetSSLDelegate(IntPtr LLMObject,
string SSLCert,
string SSLKey);
// Request-style calls: they take a string wrapper handle (most also take a JSON string) and return
// nothing, so results are presumably written into the string wrapper (confirm).
563 public delegate
void LLM_TokenizeDelegate(IntPtr LLMObject,
string jsonData, IntPtr stringWrapper);
564 public delegate
void LLM_DetokenizeDelegate(IntPtr LLMObject,
string jsonData, IntPtr stringWrapper);
565 public delegate
void LLM_EmbeddingsDelegate(IntPtr LLMObject,
string jsonData, IntPtr stringWrapper);
566 public delegate
void LLM_LoraWeightDelegate(IntPtr LLMObject,
string jsonData, IntPtr stringWrapper);
567 public delegate
void LLM_LoraListDelegate(IntPtr LLMObject, IntPtr stringWrapper);
568 public delegate
void LLM_CompletionDelegate(IntPtr LLMObject,
string jsonData, IntPtr stringWrapper);
569 public delegate
void LLM_SlotDelegate(IntPtr LLMObject,
string jsonData, IntPtr stringWrapper);
// Takes a slot id.
570 public delegate
void LLM_CancelDelegate(IntPtr LLMObject,
int idSlot);
// Returns an int and also takes a string wrapper handle.
571 public delegate
int LLM_StatusDelegate(IntPtr LLMObject, IntPtr stringWrapper);
// Native string wrapper management: construct, delete, query the size and copy the content into a buffer.
572 public delegate IntPtr StringWrapper_ConstructDelegate();
573 public delegate
void StringWrapper_DeleteDelegate(IntPtr instance);
574 public delegate
int StringWrapper_GetStringSizeDelegate(IntPtr instance);
// "clear" defaults to false when the delegate is invoked without it.
575 public delegate
void StringWrapper_GetStringDelegate(IntPtr instance, IntPtr buffer,
int bufferSize,
bool clear =
false);
// Delegate instances bound by the constructor to the exported functions of the loaded library.
// Each field is bound to the symbol of the same name, except LLM_LoraList, which is bound to "LLM_Lora_List".
577 public LoggingDelegate Logging;
578 public StopLoggingDelegate StopLogging;
579 public LLM_ConstructDelegate LLM_Construct;
580 public LLM_DeleteDelegate LLM_Delete;
581 public LLM_StartServerDelegate LLM_StartServer;
582 public LLM_StopServerDelegate LLM_StopServer;
583 public LLM_StartDelegate LLM_Start;
584 public LLM_StartedDelegate LLM_Started;
585 public LLM_StopDelegate LLM_Stop;
586 public LLM_SetTemplateDelegate LLM_SetTemplate;
587 public LLM_SetSSLDelegate LLM_SetSSL;
588 public LLM_TokenizeDelegate LLM_Tokenize;
589 public LLM_DetokenizeDelegate LLM_Detokenize;
590 public LLM_CompletionDelegate LLM_Completion;
591 public LLM_EmbeddingsDelegate LLM_Embeddings;
592 public LLM_LoraWeightDelegate LLM_Lora_Weight;
593 public LLM_LoraListDelegate LLM_LoraList;
594 public LLM_SlotDelegate LLM_Slot;
595 public LLM_CancelDelegate LLM_Cancel;
596 public LLM_StatusDelegate LLM_Status;
// Bindings for the native string wrapper functions.
597 public StringWrapper_ConstructDelegate StringWrapper_Construct;
598 public StringWrapper_DeleteDelegate StringWrapper_Delete;
599 public StringWrapper_GetStringSizeDelegate StringWrapper_GetStringSize;
600 public StringWrapper_GetStringDelegate StringWrapper_GetString;