338 lines
9.5 KiB
C++
338 lines
9.5 KiB
C++
// Conversation.cpp
|
|
#include "Conversation.h"
|
|
|
|
#include <Alert.h>
|
|
#include <File.h>
|
|
#include <FindDirectory.h>
|
|
#include <Path.h>
|
|
#include <String.h>
|
|
|
|
#include <MimeType.h>
|
|
#include <Url.h>
|
|
|
|
#include <Application.h>
|
|
|
|
// Creates a conversation whose asynchronous results (model list, chat
// replies, errors) are posted back to `replyTo` via sendReply().
// The OpenAI API key is read once, here, by ReadOpenAIKey().
Conversation::Conversation(BHandler *replyTo) {
  replyTarget = replyTo;

  // ReadOpenAIKey() sets validKey to true on success; start from false so the
  // flag is never left unset on one of its early-return error paths.
  validKey = false;
  _apiKey = ReadOpenAIKey();

  // Never echo the secret itself to stdout/logs; only report its status.
  printf("API key %s\n", validKey ? "loaded" : "not available");
}
|
|
// The original destructor body is empty: no explicit cleanup is performed,
// members are destroyed implicitly.
Conversation::~Conversation() = default;
|
|
|
|
void Conversation::PrintAsJsonArray(const std::vector<std::string> &models) {
|
|
json output = models; // implicit conversion to JSON array
|
|
std::cout << output.dump(2) << std::endl; // pretty-print with 2-space indent
|
|
}
|
|
|
|
// Posts `message` to the reply target handler through the looper it is
// attached to. If there is no target, or the target has no looper, the
// message is dropped and a diagnostic is printed.
void Conversation::sendReply(BMessage message) {
  if (replyTarget == nullptr) {
    printf("No reply target set, dropping message.\n");
    return;
  }

  BLooper *looper = replyTarget->Looper();
  if (looper == nullptr) {
    printf("Handler not attached to a looper.\n");
    return;
  }

  BMessenger messenger(replyTarget, looper);
  if (messenger.SendMessage(&message) != B_OK)
    printf("Failed to deliver reply message.\n");
}
|
|
|
|
std::string Conversation::buildHistoryInfoLine() {
|
|
|
|
std::string info = "" + std::to_string(_messageHistory.size()) + " messages";
|
|
|
|
return info;
|
|
|
|
}
|
|
|
|
// Forgets every message exchanged so far; since ask() sends the whole
// history, the next request starts with an empty context.
void Conversation::ClearHistory() {
  std::puts("Cleared history");  // puts appends the trailing newline
  _messageHistory.clear();
}
|
|
|
|
std::vector<std::string>
|
|
Conversation::FilterTextModels(const json &modelsJson) {
|
|
std::vector<std::string> result;
|
|
std::regex pattern("gpt|text|curie|babbage|ada");
|
|
|
|
for (const auto &model : modelsJson["data"]) {
|
|
std::string id = model["id"];
|
|
if (std::regex_search(id, pattern) &&
|
|
id.find("audio") == std::string::npos &&
|
|
id.find("image") == std::string::npos &&
|
|
id.find("image") == std::string::npos &&
|
|
id.find("tts") == std::string::npos &&
|
|
id.find("embed") == std::string::npos &&
|
|
id.find("-20") == std::string::npos &&
|
|
id.find("preview") == std::string::npos &&
|
|
id.find("transcribe") == std::string::npos &&
|
|
id.find("dall-e") == std::string::npos) {
|
|
result.push_back(id);
|
|
}
|
|
}
|
|
|
|
std::sort(result.begin(), result.end(), std::greater<>()); // inverse alphabetical to get gpt-4 on top
|
|
return result;
|
|
}
|
|
|
|
// Handles the UrlEvent notifications delivered to this handler for requests
// started by loadModels() and ask() (which pass `this` as observer to
// _sharedSession.Execute()). Only the completion of the most recent request
// (_lastResult) is acted upon. Unrecognized messages go to BHandler.
void Conversation::MessageReceived(BMessage *message) {
  switch (message->what) {
    case UrlEvent::HostNameResolved:
      printf("Host name resolved\n");
      message->PrintToStream();
      // TODO: show progress in the UI, e.g. with
      // message->GetString(UrlEventData::HostName).
      break;

    case UrlEvent::ConnectionOpened:
      printf("ConnectionOpened\n");
      break;

    case UrlEvent::ResponseStarted:
      printf("ResponseStarted\n");
      break;

    case UrlEvent::HttpRedirect:
      printf("HttpRedirect\n");
      break;

    case UrlEvent::HttpStatus:
      printf("HttpStatus\n");
      break;

    case UrlEvent::BytesWritten:
    case UrlEvent::DownloadProgress:
      // Progress is not wired to any UI yet. The counters are available as
      // UrlEventData::NumBytes / UrlEventData::TotalBytes when needed.
      break;

    case UrlEvent::RequestCompleted: {
      printf("RequestCompleted\n");
      const auto identifier = message->GetInt32(UrlEventData::Id, -1);

      // Ignore completions of stale/cancelled requests (ask() cancels the
      // previous one) and anything arriving before a request was started.
      if (!_lastResult || _lastResult->Identity() != identifier)
        break;

      try {
        // We were notified that the request is done, so Body() will not
        // block. It can however throw when the request itself failed
        // (e.g. network error or timeout), so it lives inside the try.
        const auto &body = _lastResult->Body();

        if (!body.text.has_value()) {
          BMessage reply(kSendReply);
          reply.AddString("text", "EMPTY BODY");
          sendReply(reply);
          break;
        }

        const char *fullBody = body.text.value().String();
        json parsed = json::parse(fullBody);
        printf("Parsed..\n");

        // Throws (caught below) if "object" is missing or not a string.
        std::string objType = parsed["object"];

        if (objType == "list") {
          std::vector<std::string> validModels = FilterTextModels(parsed);
          PrintAsJsonArray(validModels);

          BMessage reply(kModelsReceived);
          for (const auto &model : validModels)
            reply.AddString("model", model.c_str());
          sendReply(reply);
        } else if (objType == "chat.completion") {
          std::string content = parsed["choices"][0]["message"]["content"];

          // Keep the assistant's answer so the next ask() sends the full
          // conversation context.
          _messageHistory.push_back({
            {"role", "assistant"},
            {"content", content}
          });

          BMessage reply(kSendReply);
          reply.AddString("text", BString(content.c_str()));
          reply.AddString("json", BString(fullBody));
          sendReply(reply);
        }
      } catch (const std::exception &e) {
        fprintf(stderr, "Error parsing JSON: %s\n", e.what());
        BMessage reply(kSendReply);
        reply.AddString("text", BString("Error parsing JSON, wrong model ?"));
        sendReply(reply);
      } catch (...) {
        // Request failures may be reported with exception types that do not
        // derive from std::exception; never let them escape the handler.
        fprintf(stderr, "HTTP request failed\n");
        BMessage reply(kSendReply);
        reply.AddString("text", BString("Request failed (network error or timeout)"));
        sendReply(reply);
      }
    } break;

    default:
      BHandler::MessageReceived(message);
      break;
  }
}
|
|
|
|
std::string Conversation::buildBearerKey() {
|
|
|
|
// if the API key file contains a new line bhttpfields will crash with invalid
|
|
// content .end() requires include algorithm
|
|
std::string key = _apiKey.String();
|
|
key.erase(std::remove(key.begin(), key.end(), '\n'), key.end());
|
|
key.erase(std::remove(key.begin(), key.end(), '\r'), key.end());
|
|
|
|
std::string bearer = std::string("Bearer ") + std::string(key);
|
|
return bearer;
|
|
}
|
|
|
|
void Conversation::loadModels() {
|
|
|
|
auto url = BUrl("https://api.openai.com/v1/models");
|
|
BHttpRequest request = BHttpRequest(url);
|
|
request.SetMethod(BHttpMethod::Get);
|
|
|
|
BHttpFields fields = BHttpFields();
|
|
fields.AddField("Authorization", buildBearerKey());
|
|
request.SetFields(fields);
|
|
|
|
BString mime = BString("application/json");
|
|
|
|
printf("Sending Prompt to server: %s\n", url.UrlString().String());
|
|
_lastResult = _sharedSession.Execute(std::move(request), nullptr, this);
|
|
|
|
if (_lastResult) {
|
|
printf("Result has identity: %d\n", _lastResult->Identity());
|
|
}
|
|
}
|
|
|
|
// Selects the model id sent as "model" with every subsequent ask().
void Conversation::setModel(const std::string &model) {
  _activeModel.assign(model);
  printf("Conversation will use model:%s\n", _activeModel.c_str());
}
|
|
|
|
void Conversation::ask(const std::string &prompt) {
|
|
|
|
_messageHistory.push_back({
|
|
{"role", "user"},
|
|
{"content", prompt}
|
|
});
|
|
|
|
// printf("Asking prompt: %s",prompt.c_str());
|
|
|
|
if (_lastResult)
|
|
_sharedSession.Cancel(_lastResult->Identity());
|
|
|
|
auto url = BUrl("https://api.openai.com/v1/chat/completions");
|
|
BHttpRequest request = BHttpRequest(url);
|
|
request.SetMethod(BHttpMethod::Post);
|
|
//Allow up to 2 minute before timeout, it can be long depending on load or complexity of prompt
|
|
request.SetTimeout(120*1000000);
|
|
|
|
BHttpFields fields = BHttpFields();
|
|
fields.AddField("Authorization", buildBearerKey());
|
|
// fields.AddField("Content-Type", "application/json"); //NO, this will
|
|
// crash, we set it in request
|
|
request.SetFields(fields);
|
|
|
|
// WE PASS THE WHOLE HISTORY to keep context, as recommended for this stateless API!
|
|
// json bodyJson = {{"model", _activeModel},
|
|
// {"messages", {{{"role", "user"}, {"content", prompt}}}}};
|
|
|
|
json bodyJson = {
|
|
{"model", _activeModel},
|
|
{"messages", _messageHistory}
|
|
};
|
|
|
|
std::string body = bodyJson.dump();
|
|
|
|
BString mime = BString("application/json");
|
|
|
|
auto memoryIO = std::make_unique<BMemoryIO>(body.c_str(), body.size());
|
|
request.SetRequestBody(std::move(memoryIO), "application/json", body.size());
|
|
|
|
printf("Sending Prompt to server: %s\n", url.UrlString().String());
|
|
_lastResult = _sharedSession.Execute(std::move(request), nullptr, this);
|
|
|
|
if (_lastResult) {
|
|
printf("Result has identity: %d\n", _lastResult->Identity());
|
|
}
|
|
}
|
|
|
|
// Reads the OpenAI API key from the "openai_key" file in the user's settings
// directory (B_USER_SETTINGS_DIRECTORY; on Haiku
// /boot/home/config/settings/openai_key).
//
// Sets validKey to true only when the file was read and is non-empty. On any
// failure validKey is false and, as before, a human-readable "error: ..."
// string is returned in place of the key.
BString Conversation::ReadOpenAIKey() {
  validKey = false;

  BPath configPath;
  if (find_directory(B_USER_SETTINGS_DIRECTORY, &configPath) != B_OK)
    return "error: couldn't find config directory";

  configPath.Append("openai_key");
  printf("full path:%s\n", configPath.Path());

  BFile file(configPath.Path(), B_READ_ONLY);
  if (file.InitCheck() != B_OK)
    return "error: couldn't open key file ";

  off_t size = 0;
  if (file.GetSize(&size) != B_OK || size < 0)
    return "error: couldn't read key file size";

  // An owning buffer instead of new[]/delete[]; only the bytes actually read
  // are kept, so a short read cannot leave garbage in the key.
  std::string buffer(static_cast<size_t>(size), '\0');
  const ssize_t bytesRead = file.Read(&buffer[0], buffer.size());
  if (bytesRead < 0)
    return "error: couldn't read key file";
  buffer.resize(static_cast<size_t>(bytesRead));

  // BString(const char*) stops at the first NUL, as the original
  // null-terminated buffer did.
  BString result(buffer.c_str());
  validKey = result.Length() > 0;
  return result;
}
|