LLM for Unity  v2.3.0
Create characters in Unity with LLMs!
LLMCharacter.cs
using System;
using System.Collections.Generic;
using System.IO;
using System.Threading;
using System.Threading.Tasks;
using UnityEditor;
using UnityEngine;

namespace LLMUnity
{
    [DefaultExecutionOrder(-2)]
    public class LLMCharacter : LLMCaller
    {
        [LLM] public string save = "";
        [LLM] public bool saveCache = false;
        [LLM] public bool debugPrompt = false;
        [Model] public int numPredict = 256;
        [ModelAdvanced] public int slot = -1;
        [ModelAdvanced] public string grammar = null;
        [ModelAdvanced] public bool cachePrompt = true;
        [ModelAdvanced] public int seed = 0;
        [ModelAdvanced, Float(0f, 2f)] public float temperature = 0.2f;
        [ModelAdvanced, Int(-1, 100)] public int topK = 40;
        [ModelAdvanced, Float(0f, 1f)] public float topP = 0.9f;
        [ModelAdvanced, Float(0f, 1f)] public float minP = 0.05f;
        [ModelAdvanced, Float(0f, 2f)] public float repeatPenalty = 1.1f;
        [ModelAdvanced, Float(0f, 1f)] public float presencePenalty = 0f;
        [ModelAdvanced, Float(0f, 1f)] public float frequencyPenalty = 0f;

        [ModelAdvanced, Float(0f, 1f)] public float tfsZ = 1f;
        [ModelAdvanced, Float(0f, 1f)] public float typicalP = 1f;
        [ModelAdvanced, Int(0, 2048)] public int repeatLastN = 64;
        [ModelAdvanced] public bool penalizeNl = true;
        [ModelAdvanced] public string penaltyPrompt;
        [ModelAdvanced, Int(0, 2)] public int mirostat = 0;
        [ModelAdvanced, Float(0f, 10f)] public float mirostatTau = 5f;
        [ModelAdvanced, Float(0f, 1f)] public float mirostatEta = 0.1f;
        [ModelAdvanced, Int(0, 10)] public int nProbs = 0;
        [ModelAdvanced] public bool ignoreEos = false;

        public int nKeep = -1;
        public List<string> stop = new List<string>();
        public Dictionary<int, string> logitBias = null;

        [Chat] public bool stream = true;
        [Chat] public string playerName = "user";
        [Chat] public string AIName = "assistant";
        [TextArea(5, 10), Chat] public string prompt = "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.";
        public bool setNKeepToPrompt = true;
        public List<ChatMessage> chat;
        public string grammarString;

        protected SemaphoreSlim chatLock = new SemaphoreSlim(1, 1);
        protected string chatTemplate;
        protected ChatTemplate template = null;

        public override void Awake()
        {
            if (!enabled) return;
            base.Awake();
            if (!remote)
            {
                int slotFromServer = llm.Register(this);
                if (slot == -1) slot = slotFromServer;
            }
            InitGrammar();
            InitHistory();
        }

        protected override void OnValidate()
        {
            base.OnValidate();
            if (llm != null && llm.parallelPrompts > -1 && (slot < -1 || slot >= llm.parallelPrompts)) LLMUnitySetup.LogError($"The slot needs to be between 0 and {llm.parallelPrompts-1}, or -1 to be automatically set");
        }

        protected override string NotValidLLMError()
        {
            return base.NotValidLLMError() + ", it is an embedding only model";
        }

        public override bool IsValidLLM(LLM llmSet)
        {
            return !llmSet.embeddingsOnly;
        }

        protected virtual void InitHistory()
        {
            InitPrompt();
            _ = LoadHistory();
        }

        protected virtual async Task LoadHistory()
        {
            if (save == "" || !File.Exists(GetJsonSavePath(save))) return;
            await chatLock.WaitAsync(); // acquire the lock
            try
            {
                await Load(save);
            }
            finally
            {
                chatLock.Release(); // release the lock
            }
        }

        protected virtual string GetSavePath(string filename)
        {
            return Path.Combine(Application.persistentDataPath, filename).Replace('\\', '/');
        }

        public virtual string GetJsonSavePath(string filename)
        {
            return GetSavePath(filename + ".json");
        }

        public virtual string GetCacheSavePath(string filename)
        {
            return GetSavePath(filename + ".cache");
        }

        protected virtual void InitPrompt(bool clearChat = true)
        {
            if (chat != null)
            {
                if (clearChat) chat.Clear();
            }
            else
            {
                chat = new List<ChatMessage>();
            }
            ChatMessage promptMessage = new ChatMessage { role = "system", content = prompt };
            if (chat.Count == 0)
            {
                chat.Add(promptMessage);
            }
            else
            {
                chat[0] = promptMessage;
            }
        }

        public virtual void SetPrompt(string newPrompt, bool clearChat = true)
        {
            prompt = newPrompt;
            nKeep = -1;
            InitPrompt(clearChat);
        }

        protected virtual bool CheckTemplate()
        {
            if (template == null)
            {
                LLMUnitySetup.LogError("Template not set!");
                return false;
            }
            return true;
        }

        protected virtual async Task<bool> InitNKeep()
        {
            if (setNKeepToPrompt && nKeep == -1)
            {
                if (!CheckTemplate()) return false;
                string systemPrompt = template.ComputePrompt(new List<ChatMessage>() { chat[0] }, playerName, "", false);
                List<int> tokens = await Tokenize(systemPrompt);
                if (tokens == null) return false;
                SetNKeep(tokens);
            }
            return true;
        }

        protected virtual void InitGrammar()
        {
            if (grammar != null && grammar != "")
            {
                grammarString = File.ReadAllText(LLMUnitySetup.GetAssetPath(grammar));
            }
        }

        protected virtual void SetNKeep(List<int> tokens)
        {
            // set the tokens to keep
            nKeep = tokens.Count;
        }

        public virtual async Task LoadTemplate()
        {
            string llmTemplate;
            if (remote)
            {
                llmTemplate = await AskTemplate();
            }
            else
            {
                llmTemplate = llm.GetTemplate();
            }
            if (llmTemplate != chatTemplate)
            {
                chatTemplate = llmTemplate;
                template = chatTemplate == null ? null : ChatTemplate.GetTemplate(chatTemplate);
                nKeep = -1;
            }
        }

        public virtual async void SetGrammar(string path)
        {
#if UNITY_EDITOR
            if (!EditorApplication.isPlaying) path = LLMUnitySetup.AddAsset(path);
#endif
            await LLMUnitySetup.AndroidExtractAsset(path, true);
            grammar = path;
            InitGrammar();
        }

        protected virtual List<string> GetStopwords()
        {
            if (!CheckTemplate()) return null;
            List<string> stopAll = new List<string>(template.GetStop(playerName, AIName));
            if (stop != null) stopAll.AddRange(stop);
            return stopAll;
        }

        protected virtual ChatRequest GenerateRequest(string prompt)
        {
            // setup the request struct
            ChatRequest chatRequest = new ChatRequest();
            if (debugPrompt) LLMUnitySetup.Log(prompt);
            chatRequest.prompt = prompt;
            // the sampling and generation options (temperature, topK, topP, minP, numPredict, nKeep,
            // repetition penalties, mirostat, grammarString, seed, logitBias, nProbs, cachePrompt, slot)
            // are copied into the corresponding request fields; the individual assignments are omitted in this listing
            chatRequest.stop = GetStopwords();
            return chatRequest;
        }

        public virtual void AddMessage(string role, string content)
        {
            // add the question / answer to the chat list, update prompt
            chat.Add(new ChatMessage { role = role, content = content });
        }

        public virtual void AddPlayerMessage(string content)
        {
            AddMessage(playerName, content);
        }

        public virtual void AddAIMessage(string content)
        {
            AddMessage(AIName, content);
        }

        protected virtual string ChatContent(ChatResult result)
        {
            // get content from a chat result received from the endpoint
            return result.content.Trim();
        }

        protected virtual string MultiChatContent(MultiChatResult result)
        {
            // get content from a streaming chat result received from the endpoint
            string response = "";
            foreach (ChatResult resultPart in result.data)
            {
                response += resultPart.content;
            }
            return response.Trim();
        }

        protected virtual string SlotContent(SlotResult result)
        {
            // get the filename from a slot result received from the endpoint
            return result.filename;
        }

        protected virtual string TemplateContent(TemplateResult result)
        {
            // get the template from a template result received from the endpoint
            return result.template;
        }

        protected virtual async Task<string> CompletionRequest(string json, Callback<string> callback = null)
        {
            string result = "";
            if (stream)
            {
                result = await PostRequest<MultiChatResult, string>(json, "completion", MultiChatContent, callback);
            }
            else
            {
                result = await PostRequest<ChatResult, string>(json, "completion", ChatContent, callback);
            }
            return result;
        }

        public virtual async Task<string> Chat(string query, Callback<string> callback = null, EmptyCallback completionCallback = null, bool addToHistory = true)
        {
            // handle a chat message by the user
            // call the callback function while the answer is received
            // call the completionCallback function when the answer is fully received
            await LoadTemplate();
            if (!CheckTemplate()) return null;
            if (!await InitNKeep()) return null;

            string json;
            await chatLock.WaitAsync();
            try
            {
                AddPlayerMessage(query);
                string prompt = template.ComputePrompt(chat, playerName, AIName);
                json = JsonUtility.ToJson(GenerateRequest(prompt));
                chat.RemoveAt(chat.Count - 1);
            }
            finally
            {
                chatLock.Release();
            }

            string result = await CompletionRequest(json, callback);

            if (addToHistory && result != null)
            {
                await chatLock.WaitAsync();
                try
                {
                    AddPlayerMessage(query);
                    AddAIMessage(result);
                }
                finally
                {
                    chatLock.Release();
                }
                if (save != "") _ = Save(save);
            }

            completionCallback?.Invoke();
            return result;
        }

        public virtual async Task<string> Complete(string prompt, Callback<string> callback = null, EmptyCallback completionCallback = null)
        {
            // handle a completion request by the user
            // call the callback function while the answer is received
            // call the completionCallback function when the answer is fully received

            string json = JsonUtility.ToJson(GenerateRequest(prompt));
            string result = await CompletionRequest(json, callback);
            completionCallback?.Invoke();
            return result;
        }

        public virtual async Task Warmup(EmptyCallback completionCallback = null)
        {
            await LoadTemplate();
            if (!CheckTemplate()) return;
            if (!await InitNKeep()) return;

            string prompt = template.ComputePrompt(chat, playerName, AIName);
            ChatRequest request = GenerateRequest(prompt);
            request.n_predict = 0; // generate no tokens: only process (and cache) the prompt
            string json = JsonUtility.ToJson(request);
            await CompletionRequest(json);
            completionCallback?.Invoke();
        }

        public virtual async Task<string> AskTemplate()
        {
            return await PostRequest<TemplateResult, string>("{}", "template", TemplateContent);
        }

        protected override void CancelRequestsLocal()
        {
            if (slot >= 0) llm.CancelRequest(slot);
        }

        protected virtual async Task<string> Slot(string filepath, string action)
        {
            SlotRequest slotRequest = new SlotRequest();
            slotRequest.id_slot = slot;
            slotRequest.filepath = filepath;
            slotRequest.action = action;
            string json = JsonUtility.ToJson(slotRequest);
            return await PostRequest<SlotResult, string>(json, "slots", SlotContent);
        }

        public virtual async Task<string> Save(string filename)
        {
            string filepath = GetJsonSavePath(filename);
            string dirname = Path.GetDirectoryName(filepath);
            if (!Directory.Exists(dirname)) Directory.CreateDirectory(dirname);
            string json = JsonUtility.ToJson(new ChatListWrapper { chat = chat.GetRange(1, chat.Count - 1) });
            File.WriteAllText(filepath, json);

            string cachepath = GetCacheSavePath(filename);
            if (remote || !saveCache) return null;
            string result = await Slot(cachepath, "save");
            return result;
        }

        public virtual async Task<string> Load(string filename)
        {
            string filepath = GetJsonSavePath(filename);
            if (!File.Exists(filepath))
            {
                LLMUnitySetup.LogError($"File {filepath} does not exist.");
                return null;
            }
            string json = File.ReadAllText(filepath);
            List<ChatMessage> chatHistory = JsonUtility.FromJson<ChatListWrapper>(json).chat;
            InitPrompt(true);
            chat.AddRange(chatHistory);
            LLMUnitySetup.Log($"Loaded {filepath}");

            string cachepath = GetCacheSavePath(filename);
            if (remote || !saveCache || !File.Exists(GetSavePath(cachepath))) return null;
            string result = await Slot(cachepath, "restore");
            return result;
        }

        protected override async Task<Ret> PostRequestLocal<Res, Ret>(string json, string endpoint, ContentCallback<Res, Ret> getContent, Callback<Ret> callback = null)
        {
            if (endpoint != "completion") return await base.PostRequestLocal(json, endpoint, getContent, callback);

            while (!llm.failed && !llm.started) await Task.Yield();

            string callResult = null;
            bool callbackCalled = false;
            if (llm.embeddingsOnly) LLMUnitySetup.LogError("The LLM can't be used for completion, only for embeddings");
            else
            {
                Callback<string> callbackString = null;
                if (stream && callback != null)
                {
                    if (typeof(Ret) == typeof(string))
                    {
                        // stream partial responses to the caller, converting each chunk to the requested type
                        callbackString = (strArg) =>
                        {
                            callback(ConvertContent(strArg, getContent));
                        };
                    }
                    else
                    {
                        LLMUnitySetup.LogError($"wrong callback type, should be string");
                    }
                    callbackCalled = true;
                }
                callResult = await llm.Completion(json, callbackString);
            }

            // convert the raw completion string to the requested return type (conversion helper inherited from LLMCaller)
            Ret result = ConvertContent(callResult, getContent);
            if (!callbackCalled) callback?.Invoke(result);
            return result;
        }
    }

    /// Wrapper class used to (de)serialize the chat history with JsonUtility.
    [Serializable]
    public class ChatListWrapper
    {
        public List<ChatMessage> chat;
    }
}
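
The Chat, Complete and Warmup methods above are the main entry points of LLMCharacter. The following sketch, which is not part of LLMCharacter.cs, shows how a MonoBehaviour could call Chat with a streaming callback; the component and field names (ChatExample, replyText) are illustrative and assume an LLMCharacter with a configured LLM already exists in the scene.

using UnityEngine;
using UnityEngine.UI;
using LLMUnity;

// Usage sketch (illustrative, not part of LLMCharacter.cs)
public class ChatExample : MonoBehaviour
{
    public LLMCharacter llmCharacter; // assign in the Inspector
    public Text replyText;            // assign in the Inspector

    async void Start()
    {
        // the callback receives the partial reply while it is being generated (stream = true);
        // the completion callback fires once the full reply has been received
        string reply = await llmCharacter.Chat(
            "Hello! Who are you?",
            partial => replyText.text = partial,
            () => Debug.Log("Reply completed"));
        Debug.Log(reply);
    }
}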
Referenced symbols (from the generated documentation):

ChatTemplate: class implementing the skeleton of a chat template.
  static ChatTemplate GetTemplate(string template): creates the chat template based on the provided chat template name.
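
ChatTemplate.GetTemplate and ComputePrompt are the calls LLMCharacter uses above to turn the chat history into a model prompt. A minimal sketch of using them directly follows; the template name "chatml" is an illustrative assumption, so pick the template matching the model in use.

using System.Collections.Generic;
using LLMUnity;

// Sketch (not part of this file): building a prompt directly with a ChatTemplate.
public static class TemplateExample
{
    public static string BuildPrompt()
    {
        ChatTemplate template = ChatTemplate.GetTemplate("chatml"); // illustrative template name
        List<ChatMessage> messages = new List<ChatMessage>
        {
            new ChatMessage { role = "system", content = "You are a helpful assistant." },
            new ChatMessage { role = "user", content = "Hello!" }
        };
        // same call pattern as LLMCharacter.Chat: format the history for the model,
        // ending with the AI header so the model answers as "assistant"
        return template.ComputePrompt(messages, "user", "assistant");
    }
}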
LLMCaller: class implementing calling of LLM functions (local and remote). (LLMCaller.cs:17)
  virtual async Task<List<int>> Tokenize(string query, Callback<List<int>> callback = null): tokenises the provided query. (LLMCaller.cs:337)
  bool remote: toggle to use a remote LLM server or a local LLM. (LLMCaller.cs:21)
LLMCharacter: class implementing the LLM characters.
  bool cachePrompt: option to cache the prompt as it is being created by the chat to avoid reprocessing the entire prompt...
  int slot: specify which slot of the server to use for computation (affects caching).
  virtual async Task<string> Chat(string query, Callback<string> callback = null, EmptyCallback completionCallback = null, bool addToHistory = true): chat functionality of the LLM. It calls the LLM completion based on the provided query including the ...
  List<string> stop: stopwords to stop the LLM in addition to the default stopwords from the chat template.
  float topP: top-p sampling (1.0 = disabled). The top p value controls the cumulative probability of generated tok...
  virtual async Task LoadTemplate(): loads the chat template of the LLMCharacter.
  string AIName: the name of the AI.
  int nProbs: if greater than 0, the response also contains the probabilities of top N tokens for each generated to...
  bool ignoreEos: ignore end of stream token and continue generating.
  float mirostatTau: set the Mirostat target entropy, parameter tau.
  int numPredict: number of tokens to predict (-1 = infinity, -2 = until context filled). This is the amount of tokens ...
  string prompt: a description of the AI role. This defines the LLMCharacter system prompt.
  float temperature: LLM temperature, lower values give more deterministic answers. The temperature setting adjusts how ra...
  override void Awake(): the Unity Awake function that initializes the state before the application starts....
  float presencePenalty: repeated token presence penalty (0.0 = disabled). Positive values penalize new tokens based on whethe...
  virtual async Task<string> AskTemplate(): asks the LLM for the chat template to use.
  string playerName: the name of the player.
  virtual string GetCacheSavePath(string filename): allows to get the save path of the LLM cache based on the provided filename or relative path.
  float mirostatEta: set the Mirostat learning rate, parameter eta.
  float minP: minimum probability for a token to be used. The probability is defined relative to the probability of...
  float typicalP: enable locally typical sampling with parameter p (1.0 = disabled).
  virtual async Task<string> Load(string filename): loads the chat history and cache from the provided filename / relative path.
  int nKeep: number of tokens to retain from the prompt when the model runs out of context (-1 = LLMCharacter prom...
  virtual async void SetGrammar(string path): sets the grammar file of the LLMCharacter.
  bool penalizeNl: penalize newline tokens when applying the repeat penalty.
  bool debugPrompt: select to log the constructed prompt in the Unity Editor.
  bool saveCache: toggle to save the LLM cache. This speeds up the prompt calculation but also requires ~100MB of space...
  int topK: top-k sampling (0 = disabled). The top k value controls the top k most probable tokens at each step o...
  string grammar: grammar file used for the LLM in .gbnf format (relative to the Assets/StreamingAssets folder).
  virtual async Task<string> Save(string filename): saves the chat history and cache to the provided filename / relative path.
  virtual void AddAIMessage(string content): allows to add an AI message to the chat history.
  string penaltyPrompt: prompt for the purpose of the penalty evaluation. Can be either null, a string or an array of numbers...
  override bool IsValidLLM(LLM llmSet): checks if an LLM is valid for the LLMCaller.
  virtual void AddPlayerMessage(string content): allows to add a player message to the chat history.
  float repeatPenalty: control the repetition of token sequences in the generated text. The penalty is applied to repeated t...
  virtual string GetJsonSavePath(string filename): allows to get the save path of the chat history based on the provided filename or relative path.
  bool stream: option to receive the reply from the model as it is produced (recommended!). If it is not selected,...
  float frequencyPenalty: repeated token frequency penalty (0.0 = disabled). Positive values penalize new tokens based on their...
  virtual async Task Warmup(EmptyCallback completionCallback = null): allows to warm up a model by processing the prompt. The prompt processing will be cached (if cachePrompt...
  int mirostat: enable Mirostat sampling, controlling perplexity during text generation (0 = disabled,...
  bool setNKeepToPrompt: option to set the number of tokens to retain from the prompt (nKeep) based on the LLMCharacter system...
  int repeatLastN: last n tokens to consider for penalizing repetition (0 = disabled, -1 = ctx-size).
  int seed: seed for reproducibility. For random results every time set to -1.
  float tfsZ: enable tail free sampling with parameter z (1.0 = disabled).
  string grammarString: the grammar to use.
  Dictionary<int, string> logitBias: the logit bias option allows to manually adjust the likelihood of specific tokens appearing in the ge...
  virtual async Task<string> Complete(string prompt, Callback<string> callback = null, EmptyCallback completionCallback = null): pure completion functionality of the LLM. It calls the LLM completion based solely on the provided pr...
  virtual void AddMessage(string role, string content): allows to add a message to the chat history.
  List<ChatMessage> chat: the chat history as a list of chat messages.
  string save: file to save the chat history. The file is saved only for Chat calls with addToHistory set to true....
  virtual void SetPrompt(string newPrompt, bool clearChat = true): sets the system prompt for the LLMCharacter.
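
A short sketch (not part of this file) tying together the configuration and persistence members listed above; the prompt text and the save path "blacksmith/session1" are illustrative, and relative paths resolve under Application.persistentDataPath as implemented in GetSavePath.

using UnityEngine;
using LLMUnity;

// Sketch (illustrative): configuring and persisting an LLMCharacter session.
public class PersistenceExample : MonoBehaviour
{
    public LLMCharacter llmCharacter; // assign in the Inspector

    void Start()
    {
        llmCharacter.SetPrompt("You are a gruff blacksmith in a medieval town."); // replaces the system prompt (clears the history by default)
        llmCharacter.saveCache = true; // also persist the prompt cache when saving, so reloading skips prompt re-processing
    }

    public async void SaveSession()
    {
        await llmCharacter.Save("blacksmith/session1"); // writes blacksmith/session1.json (and .cache) under persistentDataPath
    }

    public async void LoadSession()
    {
        await llmCharacter.Load("blacksmith/session1"); // restores the chat history (and the cache, if it was saved)
    }
}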
LLMUnitySetup: class implementing helper functions for setup and process management.
LLM: class implementing the LLM server. (LLM.cs:19)
  string GetTemplate(): returns the chat template of the LLM. (LLM.cs:387)
  void CancelRequest(int id_slot): allows to cancel the requests in a specific slot of the LLM. (LLM.cs:791)
  int parallelPrompts: number of prompts that can happen in parallel (-1 = number of LLMCaller objects). (LLM.cs:35)
  bool started: true if the server has started and is ready to receive requests, false otherwise. (LLM.cs:46)
  async Task<string> Completion(string json, Callback<string> streamCallback = null): allows to use the chat and completion functionality of the LLM. (LLM.cs:766)
  int Register(LLMCaller llmCaller): registers a local LLMCaller object. This allows to bind the LLMCaller "client" to a specific slot of ... (LLM.cs:557)
  bool failed: true if the server has failed to start. (LLM.cs:48)
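
To illustrate how the LLM server members above interact with LLMCharacter, the following sketch (not part of this file) runs two characters against one shared LLM, each pinned to its own slot; the component wiring and slot values are illustrative and require llm.parallelPrompts to cover the slots in use.

using System.Threading.Tasks;
using UnityEngine;
using LLMUnity;

// Sketch (illustrative): two characters sharing one LLM server, each with its own slot.
public class TwoCharactersExample : MonoBehaviour
{
    public LLM llm;               // shared server with parallelPrompts >= 2
    public LLMCharacter knight;   // slot = 0, set in the Inspector
    public LLMCharacter merchant; // slot = 1, set in the Inspector

    public async void AskBoth()
    {
        // each character keeps its own prompt cache because it uses its own slot
        Task<string> a = knight.Chat("Guard the gate.");
        Task<string> b = merchant.Chat("What do you sell today?");
        await Task.WhenAll(a, b);
        Debug.Log($"knight: {a.Result}\nmerchant: {b.Result}");
    }

    public void StopKnight()
    {
        llm.CancelRequest(knight.slot); // cancel the request running in the knight's slot
    }
}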