5using System.Collections.Generic;
7using System.Runtime.InteropServices;
12 public class StreamWrapper
15 Callback<string> callback;
17 string previousString =
"";
18 string previousCalledString =
"";
19 int previousBufferSize = 0;
/// <summary>
/// Stores the library, the callback and the clear-on-update flag, then asks the
/// library to construct a native string wrapper.
/// </summary>
/// <param name="llmlib">Library used for all StringWrapper_* calls; may be null.</param>
/// <param name="callback">Callback stored for later invocation with the streamed text.</param>
/// <param name="clearOnUpdate">Flag stored and later forwarded to GetString.</param>
public StreamWrapper(LLMLib llmlib, Callback<string> callback, bool clearOnUpdate = false)
{
    this.llmlib = llmlib;
    this.callback = callback;
    this.clearOnUpdate = clearOnUpdate;

    // Without a library nothing can be constructed, so the handle stays IntPtr.Zero.
    stringWrapper = llmlib == null ? IntPtr.Zero : llmlib.StringWrapper_Construct();
}
30 public string GetString(
bool clear =
false)
33 int bufferSize = (llmlib?.StringWrapper_GetStringSize(stringWrapper)).GetValueOrDefault();
38 else if (previousBufferSize != bufferSize)
40 IntPtr buffer = Marshal.AllocHGlobal(bufferSize);
43 llmlib?.StringWrapper_GetString(stringWrapper, buffer, bufferSize, clear);
44 result = Marshal.PtrToStringAnsi(buffer);
48 Marshal.FreeHGlobal(buffer);
50 previousString = result;
54 result = previousString;
56 previousBufferSize = bufferSize;
62 if (stringWrapper == IntPtr.Zero)
return;
63 string result = GetString(clearOnUpdate);
64 if (result !=
"" && previousCalledString != result)
66 callback?.Invoke(result);
67 previousCalledString = result;
71 public IntPtr GetStringWrapper()
78 if (stringWrapper != IntPtr.Zero) llmlib?.StringWrapper_Delete(stringWrapper);
82 static class LibraryLoader
/// <summary>
/// Looks up an exported symbol in a loaded native library and wraps it in a managed delegate.
/// </summary>
/// <typeparam name="T">Delegate type the function pointer is converted to.</typeparam>
/// <param name="library">Handle of the loaded library.</param>
/// <param name="name">Name of the exported symbol.</param>
/// <returns>A delegate of type <typeparamref name="T"/> bound to the symbol.</returns>
/// <exception cref="EntryPointNotFoundException">The symbol lookup returned a null pointer.</exception>
public static T GetSymbolDelegate<T>(IntPtr library, string name) where T : Delegate
{
    IntPtr address = GetSymbol(library, name);
    if (address != IntPtr.Zero)
    {
        return Marshal.GetDelegateForFunctionPointer<T>(address);
    }

    throw new EntryPointNotFoundException($"Unable to load symbol '{name}'.");
}
/// <summary>
/// Loads a native library through the loader API of the current Unity platform.
/// </summary>
/// <param name="libraryName">Name or path of the library to load.</param>
/// <returns>The handle returned by the platform loader.</returns>
/// <exception cref="ArgumentNullException"><paramref name="libraryName"/> is null or empty.</exception>
/// <exception cref="PlatformNotSupportedException">The platform is not Windows, Linux, macOS or Android.</exception>
public static IntPtr LoadLibrary(string libraryName)
{
    if (string.IsNullOrEmpty(libraryName))
    {
        throw new ArgumentNullException(nameof(libraryName));
    }

    switch (Application.platform)
    {
        case RuntimePlatform.WindowsEditor:
        case RuntimePlatform.WindowsPlayer:
            return Win32.LoadLibrary(libraryName);

        case RuntimePlatform.LinuxEditor:
        case RuntimePlatform.LinuxPlayer:
            return Linux.dlopen(libraryName);

        case RuntimePlatform.OSXEditor:
        case RuntimePlatform.OSXPlayer:
            return Mac.dlopen(libraryName);

        case RuntimePlatform.Android:
            return Android.dlopen(libraryName);

        default:
            throw new PlatformNotSupportedException($"Current platform is unknown, unable to load library '{libraryName}'.");
    }
}
/// <summary>
/// Resolves the address of an exported symbol through the loader API of the current Unity platform.
/// </summary>
/// <param name="library">Handle of the loaded library.</param>
/// <param name="symbolName">Name of the exported symbol.</param>
/// <returns>The address returned by the platform loader.</returns>
/// <exception cref="ArgumentNullException"><paramref name="symbolName"/> is null or empty.</exception>
/// <exception cref="PlatformNotSupportedException">The platform is not Windows, Linux, macOS or Android.</exception>
public static IntPtr GetSymbol(IntPtr library, string symbolName)
{
    if (string.IsNullOrEmpty(symbolName))
    {
        throw new ArgumentNullException(nameof(symbolName));
    }

    switch (Application.platform)
    {
        case RuntimePlatform.WindowsEditor:
        case RuntimePlatform.WindowsPlayer:
            return Win32.GetProcAddress(library, symbolName);

        case RuntimePlatform.LinuxEditor:
        case RuntimePlatform.LinuxPlayer:
            return Linux.dlsym(library, symbolName);

        case RuntimePlatform.OSXEditor:
        case RuntimePlatform.OSXPlayer:
            return Mac.dlsym(library, symbolName);

        case RuntimePlatform.Android:
            return Android.dlsym(library, symbolName);

        default:
            throw new PlatformNotSupportedException($"Current platform is unknown, unable to load symbol '{symbolName}' from library {library}.");
    }
}
/// <summary>
/// Releases a native library through the loader API of the current Unity platform.
/// A zero handle is ignored.
/// </summary>
/// <param name="library">Handle of the library to release.</param>
/// <exception cref="PlatformNotSupportedException">The platform is not Windows, Linux, macOS or Android.</exception>
public static void FreeLibrary(IntPtr library)
{
    if (library != IntPtr.Zero)
    {
        switch (Application.platform)
        {
            case RuntimePlatform.WindowsEditor:
            case RuntimePlatform.WindowsPlayer:
                Win32.FreeLibrary(library);
                break;

            case RuntimePlatform.LinuxEditor:
            case RuntimePlatform.LinuxPlayer:
                Linux.dlclose(library);
                break;

            case RuntimePlatform.OSXEditor:
            case RuntimePlatform.OSXPlayer:
                Mac.dlclose(library);
                break;

            case RuntimePlatform.Android:
                Android.dlclose(library);
                break;

            default:
                throw new PlatformNotSupportedException($"Current platform is unknown, unable to close library '{library}'.");
        }
    }
}
/// <summary>
/// P/Invoke bindings to dlopen/dlsym/dlclose imported from libSystem.dylib.
/// </summary>
private static class Mac
{
    private const string SystemLibrary = "/usr/lib/libSystem.dylib";

    // Values passed as the `mode` argument of dlopen.
    private const int RTLD_LAZY = 1;
    private const int RTLD_NOW = 2;

    [DllImport(SystemLibrary)]
    public static extern IntPtr dlopen(string path, int mode);

    [DllImport(SystemLibrary)]
    public static extern IntPtr dlsym(IntPtr handle, string symbol);

    [DllImport(SystemLibrary)]
    public static extern void dlclose(IntPtr handle);

    /// <summary>
    /// Opens the library at <paramref name="path"/> with RTLD_LAZY when
    /// <paramref name="lazy"/> is true and RTLD_NOW otherwise.
    /// </summary>
    public static IntPtr dlopen(string path, bool lazy = true)
    {
        int mode = lazy ? RTLD_LAZY : RTLD_NOW;
        return dlopen(path, mode);
    }
}
/// <summary>
/// P/Invoke bindings to dlopen/dlsym/dlclose on Linux. The functions are imported
/// from both "libdl.so.2" and "libdl.so"; the former is tried first.
/// </summary>
private static class Linux
{
    private const string SystemLibrary = "libdl.so";
    private const string SystemLibrary2 = "libdl.so.2";

    // Values passed as the `mode` argument of dlopen.
    private const int RTLD_LAZY = 1;
    private const int RTLD_NOW = 2;

    // Cleared the first time a dlopen through libdl.so.2 fails with DllNotFoundException;
    // dlsym and dlclose then route to libdl.so.
    private static bool UseSystemLibrary2 = true;

    /// <summary>
    /// Opens the library at <paramref name="path"/> with RTLD_LAZY when
    /// <paramref name="lazy"/> is true and RTLD_NOW otherwise, falling back to
    /// libdl.so when libdl.so.2 cannot be found.
    /// </summary>
    public static IntPtr dlopen(string path, bool lazy = true)
    {
        int mode = lazy ? RTLD_LAZY : RTLD_NOW;
        try
        {
            return dlopen2(path, mode);
        }
        catch (DllNotFoundException)
        {
            UseSystemLibrary2 = false;
            return dlopen1(path, mode);
        }
    }

    /// <summary>Resolves a symbol through whichever libdl variant is currently selected.</summary>
    public static IntPtr dlsym(IntPtr handle, string symbol)
    {
        if (UseSystemLibrary2)
        {
            return dlsym2(handle, symbol);
        }

        return dlsym1(handle, symbol);
    }

    /// <summary>Closes a handle through whichever libdl variant is currently selected.</summary>
    public static void dlclose(IntPtr handle)
    {
        if (UseSystemLibrary2)
        {
            dlclose2(handle);
        }
        else
        {
            dlclose1(handle);
        }
    }

    // Imports from "libdl.so".
    [DllImport(SystemLibrary, EntryPoint = "dlopen")]
    private static extern IntPtr dlopen1(string path, int mode);

    [DllImport(SystemLibrary, EntryPoint = "dlsym")]
    private static extern IntPtr dlsym1(IntPtr handle, string symbol);

    [DllImport(SystemLibrary, EntryPoint = "dlclose")]
    private static extern void dlclose1(IntPtr handle);

    // Imports from "libdl.so.2".
    [DllImport(SystemLibrary2, EntryPoint = "dlopen")]
    private static extern IntPtr dlopen2(string path, int mode);

    [DllImport(SystemLibrary2, EntryPoint = "dlsym")]
    private static extern IntPtr dlsym2(IntPtr handle, string symbol);

    [DllImport(SystemLibrary2, EntryPoint = "dlclose")]
    private static extern void dlclose2(IntPtr handle);
}
/// <summary>
/// P/Invoke bindings to the loader functions imported from Kernel32.dll.
/// </summary>
private static class Win32
{
    private const string SystemLibrary = "Kernel32.dll";

    // Bound explicitly to the wide-character export and marshalled as UTF-16.
    // With ANSI marshalling the path is converted to a narrow string first, so a
    // library located under a directory whose name is not representable in the
    // active code page cannot be loaded.
    [DllImport(SystemLibrary, EntryPoint = "LoadLibraryW", ExactSpelling = true, SetLastError = true, CharSet = CharSet.Unicode)]
    public static extern IntPtr LoadLibrary(string lpFileName);

    // GetProcAddress has no wide-character variant, so the symbol name stays ANSI-marshalled.
    [DllImport(SystemLibrary, SetLastError = true, CharSet = CharSet.Ansi)]
    public static extern IntPtr GetProcAddress(IntPtr hModule, string lpProcName);

    [DllImport(SystemLibrary, SetLastError = true, CharSet = CharSet.Ansi)]
    public static extern void FreeLibrary(IntPtr hModule);
}
241 private static class Android
243 public static IntPtr dlopen(
string path) => dlopen(path, 1);
247 [DllImport(
"__Internal")]
248 public static extern IntPtr dlopen(
string filename,
int flags);
251 [DllImport(
"__Internal")]
252 public static extern IntPtr dlsym(IntPtr handle,
string symbol);
255 [DllImport(
"__Internal")]
256 public static extern int dlclose(IntPtr handle);
258 public static IntPtr dlopen(
string filename,
int flags)
263 public static IntPtr dlsym(IntPtr handle,
string symbol)
268 public static int dlclose(IntPtr handle)
// Handle of the native library loaded by the constructor; IntPtr.Zero when nothing is loaded.
IntPtr libraryHandle = IntPtr.Zero;
// Lock object shared by all instances.
static readonly object staticLock = new object();
// Results of the has_avx / has_avx2 / has_avx512 functions of the architecture-checker library.
static bool has_avx = false;
static bool has_avx2 = false;
static bool has_avx512 = false;
// When true, the architecture check returns immediately without running again.
static bool has_avx_set = false;
290 if (has_avx_set)
return;
291 string archCheckerPath = GetArchitectureCheckerPath();
292 if (archCheckerPath !=
null)
294 IntPtr archCheckerHandle = LibraryLoader.LoadLibrary(archCheckerPath);
295 if (archCheckerHandle == IntPtr.Zero)
297 LLMUnitySetup.LogError($
"Failed to load library {archCheckerPath}.");
303 has_avx = LibraryLoader.GetSymbolDelegate<HasArchDelegate>(archCheckerHandle,
"has_avx")();
304 has_avx2 = LibraryLoader.GetSymbolDelegate<HasArchDelegate>(archCheckerHandle,
"has_avx2")();
305 has_avx512 = LibraryLoader.GetSymbolDelegate<HasArchDelegate>(archCheckerHandle,
"has_avx512")();
306 LibraryLoader.FreeLibrary(archCheckerHandle);
310 LLMUnitySetup.LogError($
"{e.GetType()}: {e.Message}");
/// <summary>
/// Loads the native library for the given architecture and binds every exported
/// function to its delegate field.
/// </summary>
/// <param name="arch">Architecture name passed to GetArchitecturePath.</param>
/// <exception cref="Exception">The library could not be loaded.</exception>
/// <exception cref="EntryPointNotFoundException">
/// A required symbol is missing; the library is released again before the exception propagates.
/// </exception>
public LLMLib(string arch)
{
    libraryHandle = LibraryLoader.LoadLibrary(GetArchitecturePath(arch));
    if (libraryHandle == IntPtr.Zero)
    {
        throw new Exception($"Failed to load library {arch}.");
    }

    try
    {
        LLM_Construct = Bind<LLM_ConstructDelegate>("LLM_Construct");
        LLM_Delete = Bind<LLM_DeleteDelegate>("LLM_Delete");
        LLM_StartServer = Bind<LLM_StartServerDelegate>("LLM_StartServer");
        LLM_StopServer = Bind<LLM_StopServerDelegate>("LLM_StopServer");
        LLM_Start = Bind<LLM_StartDelegate>("LLM_Start");
        LLM_Started = Bind<LLM_StartedDelegate>("LLM_Started");
        LLM_Stop = Bind<LLM_StopDelegate>("LLM_Stop");
        LLM_SetTemplate = Bind<LLM_SetTemplateDelegate>("LLM_SetTemplate");
        LLM_SetSSL = Bind<LLM_SetSSLDelegate>("LLM_SetSSL");
        LLM_Tokenize = Bind<LLM_TokenizeDelegate>("LLM_Tokenize");
        LLM_Detokenize = Bind<LLM_DetokenizeDelegate>("LLM_Detokenize");
        LLM_Embeddings = Bind<LLM_EmbeddingsDelegate>("LLM_Embeddings");
        LLM_Lora_Weight = Bind<LLM_LoraWeightDelegate>("LLM_Lora_Weight");
        // The field is named LLM_LoraList but the exported symbol is "LLM_Lora_List".
        LLM_LoraList = Bind<LLM_LoraListDelegate>("LLM_Lora_List");
        LLM_Completion = Bind<LLM_CompletionDelegate>("LLM_Completion");
        LLM_Slot = Bind<LLM_SlotDelegate>("LLM_Slot");
        LLM_Cancel = Bind<LLM_CancelDelegate>("LLM_Cancel");
        LLM_Status = Bind<LLM_StatusDelegate>("LLM_Status");
        StringWrapper_Construct = Bind<StringWrapper_ConstructDelegate>("StringWrapper_Construct");
        StringWrapper_Delete = Bind<StringWrapper_DeleteDelegate>("StringWrapper_Delete");
        StringWrapper_GetStringSize = Bind<StringWrapper_GetStringSizeDelegate>("StringWrapper_GetStringSize");
        StringWrapper_GetString = Bind<StringWrapper_GetStringDelegate>("StringWrapper_GetString");
        Logging = Bind<LoggingDelegate>("Logging");
        StopLogging = Bind<StopLoggingDelegate>("StopLogging");
    }
    catch
    {
        // The instance never reaches the caller when binding fails, so nobody could
        // call Destroy on it; release the library here to avoid leaking the handle.
        LibraryLoader.FreeLibrary(libraryHandle);
        libraryHandle = IntPtr.Zero;
        throw;
    }

    // Resolves one export of the freshly loaded library as a delegate of type T.
    T Bind<T>(string symbol) where T : Delegate
    {
        return LibraryLoader.GetSymbolDelegate<T>(libraryHandle, symbol);
    }
}
/// <summary>
/// Releases the native library loaded by the constructor. Calling it again, or on
/// an instance without a loaded library, does nothing.
/// </summary>
public void Destroy()
{
    if (libraryHandle == IntPtr.Zero)
    {
        return;
    }

    LibraryLoader.FreeLibrary(libraryHandle);
    // Forget the handle so a repeated call cannot free the same library twice.
    libraryHandle = IntPtr.Zero;
}
/// <summary>
/// Lists the library architecture names to consider on the current platform,
/// in the order they are added.
/// </summary>
/// <param name="gpu">When true on Windows/Linux, the CUDA, HIP and Vulkan variants are listed first.</param>
/// <returns>The architecture names for the current platform.</returns>
/// <exception cref="Exception">The platform is not Windows, Linux, macOS or Android.</exception>
public static List<string> PossibleArchitectures(bool gpu = false)
{
    List<string> architectures = new List<string>();

    switch (Application.platform)
    {
        case RuntimePlatform.WindowsEditor:
        case RuntimePlatform.WindowsPlayer:
        case RuntimePlatform.LinuxEditor:
        case RuntimePlatform.LinuxPlayer:
        {
            if (gpu)
            {
                architectures.AddRange(LLMUnitySetup.FullLlamaLib
                    ? new[] { "cuda-cu12.2.0-full", "cuda-cu11.7.1-full", "hip-full" }
                    : new[] { "cuda-cu12.2.0", "cuda-cu11.7.1", "hip" });
                architectures.Add("vulkan");
            }

            // CPU variants follow, widest detected instruction set first; "noavx" is always listed.
            if (has_avx512) architectures.Add("avx512");
            if (has_avx2) architectures.Add("avx2");
            if (has_avx) architectures.Add("avx");
            architectures.Add("noavx");
            break;
        }

        case RuntimePlatform.OSXEditor:
        case RuntimePlatform.OSXPlayer:
        {
            string arch = RuntimeInformation.ProcessArchitecture.ToString().ToLower();
            bool isArm = arch.Contains("arm");
            if (!isArm && arch != "x86" && arch != "x64")
            {
                LLMUnitySetup.LogWarning($"Unknown architecture of processor {arch}! Falling back to x86_64");
            }

            architectures.AddRange(isArm
                ? new[] { "arm64-acc", "arm64-no_acc" }
                : new[] { "x64-acc", "x64-no_acc" });
            break;
        }

        case RuntimePlatform.Android:
            architectures.Add("android");
            break;

        default:
        {
            string error = "Unknown OS";
            LLMUnitySetup.LogError(error);
            throw new Exception(error);
        }
    }

    return architectures;
}
/// <summary>
/// Builds the path of the architecture-checker library under LLMUnitySetup.libraryPath.
/// </summary>
/// <returns>The combined path on Windows and Linux; null on every other platform.</returns>
public static string GetArchitectureCheckerPath()
{
    switch (Application.platform)
    {
        case RuntimePlatform.WindowsEditor:
        case RuntimePlatform.WindowsPlayer:
            return Path.Combine(LLMUnitySetup.libraryPath, "windows-archchecker/archchecker.dll");

        case RuntimePlatform.LinuxEditor:
        case RuntimePlatform.LinuxPlayer:
            return Path.Combine(LLMUnitySetup.libraryPath, "linux-archchecker/libarchchecker.so");

        default:
            return null;
    }
}
/// <summary>
/// Builds the path of the undreamai library for the given architecture on the current platform.
/// </summary>
/// <param name="arch">Architecture name embedded in the directory and file name.</param>
/// <returns>
/// A path under LLMUnitySetup.libraryPath on Windows, Linux and macOS;
/// the bare file name "libundreamai_android.so" on Android.
/// </returns>
/// <exception cref="Exception">The platform is none of the above.</exception>
public static string GetArchitecturePath(string arch)
{
    string relativePath;
    switch (Application.platform)
    {
        case RuntimePlatform.WindowsEditor:
        case RuntimePlatform.WindowsPlayer:
            relativePath = $"windows-{arch}/undreamai_windows-{arch}.dll";
            break;

        case RuntimePlatform.LinuxEditor:
        case RuntimePlatform.LinuxPlayer:
            relativePath = $"linux-{arch}/libundreamai_linux-{arch}.so";
            break;

        case RuntimePlatform.OSXEditor:
        case RuntimePlatform.OSXPlayer:
            relativePath = $"macos-{arch}/libundreamai_macos-{arch}.dylib";
            break;

        case RuntimePlatform.Android:
            // Returned as-is, without the libraryPath prefix.
            return "libundreamai_android.so";

        default:
        {
            string error = "Unknown OS";
            LLMUnitySetup.LogError(error);
            throw new Exception(error);
        }
    }

    return Path.Combine(LLMUnitySetup.libraryPath, relativePath);
}
458 public string GetStringWrapperResult(IntPtr stringWrapper)
461 int bufferSize = StringWrapper_GetStringSize(stringWrapper);
464 IntPtr buffer = Marshal.AllocHGlobal(bufferSize);
467 StringWrapper_GetString(stringWrapper, buffer, bufferSize);
468 result = Marshal.PtrToStringAnsi(buffer);
472 Marshal.FreeHGlobal(buffer);
// Delegate types used with LibraryLoader.GetSymbolDelegate to bind the exports of the
// native libraries. Each type carries the managed signature of one exported function.

// Signature shared by the has_avx / has_avx2 / has_avx512 exports of the architecture checker.
public delegate bool HasArchDelegate();

// Logging exports.
public delegate void LoggingDelegate(IntPtr stringWrapper);
public delegate void StopLoggingDelegate();

// LLM_* exports. LLM_Construct returns a pointer; every other LLM_* function takes
// a pointer as its first argument (LLMObject).
public delegate IntPtr LLM_ConstructDelegate(string command);
public delegate void LLM_DeleteDelegate(IntPtr LLMObject);
public delegate void LLM_StartServerDelegate(IntPtr LLMObject);
public delegate void LLM_StopServerDelegate(IntPtr LLMObject);
public delegate void LLM_StartDelegate(IntPtr LLMObject);
public delegate bool LLM_StartedDelegate(IntPtr LLMObject);
public delegate void LLM_StopDelegate(IntPtr LLMObject);
public delegate void LLM_SetTemplateDelegate(IntPtr LLMObject, string chatTemplate);
public delegate void LLM_SetSSLDelegate(IntPtr LLMObject, string SSLCert, string SSLKey);
public delegate void LLM_TokenizeDelegate(IntPtr LLMObject, string jsonData, IntPtr stringWrapper);
public delegate void LLM_DetokenizeDelegate(IntPtr LLMObject, string jsonData, IntPtr stringWrapper);
public delegate void LLM_EmbeddingsDelegate(IntPtr LLMObject, string jsonData, IntPtr stringWrapper);
public delegate void LLM_LoraWeightDelegate(IntPtr LLMObject, string jsonData, IntPtr stringWrapper);
public delegate void LLM_LoraListDelegate(IntPtr LLMObject, IntPtr stringWrapper);
public delegate void LLM_CompletionDelegate(IntPtr LLMObject, string jsonData, IntPtr stringWrapper);
public delegate void LLM_SlotDelegate(IntPtr LLMObject, string jsonData, IntPtr stringWrapper);
public delegate void LLM_CancelDelegate(IntPtr LLMObject, int idSlot);
public delegate int LLM_StatusDelegate(IntPtr LLMObject, IntPtr stringWrapper);

// StringWrapper_* exports. StringWrapper_Construct returns a pointer; the others take
// a pointer as their first argument (instance).
public delegate IntPtr StringWrapper_ConstructDelegate();
public delegate void StringWrapper_DeleteDelegate(IntPtr instance);
public delegate int StringWrapper_GetStringSizeDelegate(IntPtr instance);
public delegate void StringWrapper_GetStringDelegate(IntPtr instance, IntPtr buffer, int bufferSize, bool clear = false);
// Delegates bound to the exports of the loaded native library; assigned by the
// LLMLib(string arch) constructor.
public LoggingDelegate Logging;
public StopLoggingDelegate StopLogging;
public LLM_ConstructDelegate LLM_Construct;
public LLM_DeleteDelegate LLM_Delete;
public LLM_StartServerDelegate LLM_StartServer;
public LLM_StopServerDelegate LLM_StopServer;
public LLM_StartDelegate LLM_Start;
public LLM_StartedDelegate LLM_Started;
public LLM_StopDelegate LLM_Stop;
public LLM_SetTemplateDelegate LLM_SetTemplate;
public LLM_SetSSLDelegate LLM_SetSSL;
public LLM_TokenizeDelegate LLM_Tokenize;
public LLM_DetokenizeDelegate LLM_Detokenize;
public LLM_CompletionDelegate LLM_Completion;
public LLM_EmbeddingsDelegate LLM_Embeddings;
public LLM_LoraWeightDelegate LLM_Lora_Weight;
// Bound to the export named "LLM_Lora_List" (note the extra underscore in the symbol).
public LLM_LoraListDelegate LLM_LoraList;
public LLM_SlotDelegate LLM_Slot;
public LLM_CancelDelegate LLM_Cancel;
public LLM_StatusDelegate LLM_Status;
public StringWrapper_ConstructDelegate StringWrapper_Construct;
public StringWrapper_DeleteDelegate StringWrapper_Delete;
public StringWrapper_GetStringSizeDelegate StringWrapper_GetStringSize;
public StringWrapper_GetStringDelegate StringWrapper_GetString;