Skip to content

Commit

Permalink
feat: mesh topology, distributed all layers. (#136)
Browse files Browse the repository at this point in the history
  • Loading branch information
b4rtaz authored Nov 21, 2024
1 parent 0d1121e commit 8b1cf89
Show file tree
Hide file tree
Showing 24 changed files with 792 additions and 938 deletions.
6 changes: 0 additions & 6 deletions .github/workflows/main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@ jobs:
make tokenizer-test
make commands-test
make llama2-tasks-test
make grok1-tasks-test
- name: funcs-test
run: ./funcs-test
- name: quants-test
Expand All @@ -44,8 +43,6 @@ jobs:
run: ./commands-test
- name: llama2-tasks-test
run: ./llama2-tasks-test
- name: grok1-tasks-test
run: ./grok1-tasks-test

build-windows:
name: Windows
Expand All @@ -66,7 +63,6 @@ jobs:
make tokenizer-test
make commands-test
make llama2-tasks-test
make grok1-tasks-test
- name: funcs-test
run: ./funcs-test
- name: quants-test
Expand All @@ -77,5 +73,3 @@ jobs:
run: ./commands-test
- name: llama2-tasks-test
run: ./llama2-tasks-test
- name: grok1-tasks-test
run: ./grok1-tasks-test
12 changes: 4 additions & 8 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -26,19 +26,17 @@ tasks: src/tasks.cpp
$(CXX) $(CXXFLAGS) -c src/tasks.cpp -o tasks.o
llama2-tasks: src/llama2-tasks.cpp
$(CXX) $(CXXFLAGS) -c src/llama2-tasks.cpp -o llama2-tasks.o
grok1-tasks: src/grok1-tasks.cpp
$(CXX) $(CXXFLAGS) -c src/grok1-tasks.cpp -o grok1-tasks.o
mixtral-tasks: src/mixtral-tasks.cpp
$(CXX) $(CXXFLAGS) -c src/mixtral-tasks.cpp -o mixtral-tasks.o
tokenizer: src/tokenizer.cpp
$(CXX) $(CXXFLAGS) -c src/tokenizer.cpp -o tokenizer.o
app: src/app.cpp
$(CXX) $(CXXFLAGS) -c src/app.cpp -o app.o

dllama: src/apps/dllama/dllama.cpp utils quants funcs commands socket transformer tasks llama2-tasks grok1-tasks mixtral-tasks tokenizer app
$(CXX) $(CXXFLAGS) src/apps/dllama/dllama.cpp -o dllama utils.o quants.o funcs.o commands.o socket.o transformer.o tasks.o llama2-tasks.o grok1-tasks.o mixtral-tasks.o tokenizer.o app.o $(LIBS)
dllama-api: src/apps/dllama-api/dllama-api.cpp utils quants funcs commands socket transformer tasks llama2-tasks grok1-tasks mixtral-tasks tokenizer app
$(CXX) $(CXXFLAGS) src/apps/dllama-api/dllama-api.cpp -o dllama-api utils.o quants.o funcs.o commands.o socket.o transformer.o tasks.o llama2-tasks.o grok1-tasks.o mixtral-tasks.o tokenizer.o app.o $(LIBS)
dllama: src/apps/dllama/dllama.cpp utils quants funcs commands socket transformer tasks llama2-tasks mixtral-tasks tokenizer app
$(CXX) $(CXXFLAGS) src/apps/dllama/dllama.cpp -o dllama utils.o quants.o funcs.o commands.o socket.o transformer.o tasks.o llama2-tasks.o mixtral-tasks.o tokenizer.o app.o $(LIBS)
dllama-api: src/apps/dllama-api/dllama-api.cpp utils quants funcs commands socket transformer tasks llama2-tasks mixtral-tasks tokenizer app
$(CXX) $(CXXFLAGS) src/apps/dllama-api/dllama-api.cpp -o dllama-api utils.o quants.o funcs.o commands.o socket.o transformer.o tasks.o llama2-tasks.o mixtral-tasks.o tokenizer.o app.o $(LIBS)
socket-benchmark: src/apps/socket-benchmark/socket-benchmark.cpp socket
$(CXX) $(CXXFLAGS) src/apps/socket-benchmark/socket-benchmark.cpp -o socket-benchmark socket.o $(LIBS)

Expand All @@ -52,5 +50,3 @@ commands-test: src/commands-test.cpp funcs commands utils quants transformer soc
$(CXX) $(CXXFLAGS) src/commands-test.cpp -o commands-test funcs.o commands.o utils.o quants.o transformer.o socket.o $(LIBS)
llama2-tasks-test: src/llama2-tasks-test.cpp utils quants funcs commands socket transformer tasks llama2-tasks tokenizer
$(CXX) $(CXXFLAGS) src/llama2-tasks-test.cpp -o llama2-tasks-test utils.o quants.o funcs.o commands.o socket.o transformer.o tasks.o llama2-tasks.o tokenizer.o $(LIBS)
grok1-tasks-test: src/grok1-tasks-test.cpp utils quants funcs commands socket transformer tasks llama2-tasks grok1-tasks tokenizer
$(CXX) $(CXXFLAGS) src/grok1-tasks-test.cpp -o grok1-tasks-test utils.o quants.o funcs.o commands.o socket.o transformer.o tasks.o llama2-tasks.o grok1-tasks.o tokenizer.o $(LIBS)
5 changes: 3 additions & 2 deletions converter/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
numpy==1.23.5
torch==2.0.1
safetensors==0.4.2
pytorch==2.0.1
safetensors==0.4.2
sentencepiece==0.1.99
12 changes: 4 additions & 8 deletions src/app.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,7 @@ AppArgs AppArgs::parse(int argc, char** argv, bool hasMode) {
args.seed = (unsigned long long)time(NULL);
args.chatTemplateType = TEMPLATE_UNKNOWN;
args.maxSeqLen = 0;
args.useDiscForKvCache = false;

args.packetAlignment = 0;
int i = 1;
if (hasMode && argc > 1) {
args.mode = argv[1];
Expand Down Expand Up @@ -102,8 +101,8 @@ AppArgs AppArgs::parse(int argc, char** argv, bool hasMode) {
args.chatTemplateType = parseChatTemplateType(value);
} else if (strcmp(name, "--max-seq-len") == 0) {
args.maxSeqLen = (unsigned int)atoi(value);
} else if (strcmp(name, "--kv-cache-storage") == 0) {
args.useDiscForKvCache = strcmp(value, "disc") == 0;
} else if (strcmp(name, "--packet-alignment") == 0) {
args.packetAlignment = (size_t)atoi(value);
} else {
printf("Unknown option %s\n", name);
exit(EXIT_FAILURE);
Expand All @@ -114,8 +113,6 @@ AppArgs AppArgs::parse(int argc, char** argv, bool hasMode) {

TransformerArch TransformerArchFactory::create(TransformerSpec* spec) {
if (spec->archType == LLAMA) return buildLlamaArch(spec);
if (spec->archType == GROK1) return buildGrok1Arch(spec);
if (spec->archType == MIXTRAL) return buildMixtralArch(spec);
printf("Unsupported arch type: %d\n", spec->archType);
exit(EXIT_FAILURE);
}
Expand All @@ -128,7 +125,7 @@ void App::run(AppArgs* args, void (*program)(Inference* inference, SocketPool* s
throw std::runtime_error("Tokenizer is required");
}

SocketPool* socketPool = SocketPool::connect(args->nWorkers, args->workerHosts, args->workerPorts);
SocketPool* socketPool = SocketPool::connect(args->nWorkers, args->workerHosts, args->workerPorts, args->packetAlignment);
unsigned int nSlices = args->nWorkers + 1;

TransformerSpec spec = Transformer::loadSpecFromFile(args->modelPath, nSlices, args->maxSeqLen, args->weightsFloatType, args->bufferFloatType);
Expand All @@ -140,7 +137,6 @@ void App::run(AppArgs* args, void (*program)(Inference* inference, SocketPool* s
}

TransformerConfig config;
config.useDiscForKvCache = args->useDiscForKvCache;

Transformer transformer = Transformer::loadRootFromFile(args->modelPath, &spec, &config, socketPool);
socketPool->setTurbo(true);
Expand Down
3 changes: 1 addition & 2 deletions src/app.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,14 @@
#include "transformer.hpp"
#include "tasks.hpp"
#include "llama2-tasks.hpp"
#include "grok1-tasks.hpp"
#include "mixtral-tasks.hpp"
#include "tokenizer.hpp"

class AppArgs {
public:
char* mode;
int nThreads;
bool useDiscForKvCache;
size_t packetAlignment;

// inference
char* modelPath;
Expand Down
67 changes: 49 additions & 18 deletions src/apps/dllama-api/dllama-api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,10 @@ enum class HttpMethod {

class HttpRequest {
public:
static HttpRequest read(Socket& socket) {
HttpRequest req(&socket);
static HttpRequest read(int serverSocket) {
HttpRequest req(serverSocket);

std::vector<char> httpRequest = socket.readHttpRequest();
std::vector<char> httpRequest = req.readHttpRequest();
// Parse the HTTP request
std::string data = std::string(httpRequest.begin(), httpRequest.end());

Expand Down Expand Up @@ -89,16 +89,48 @@ class HttpRequest {
}

private:
Socket* socket;
int serverSocket;
public:
std::string path;
std::unordered_map<std::string, std::string> headers;
std::string body;
json parsedJson;
HttpMethod method;

HttpRequest(Socket* socket) {
this->socket = socket;
HttpRequest(int serverSocket) {
this->serverSocket = serverSocket;
}

/// Reads the data currently available on `serverSocket` with a single recv().
///
/// At most 1 MiB is returned per call. The receive buffer lives on the heap:
/// a 1 MiB automatic array can exhaust the thread's stack by itself.
/// The previous MSG_PEEK + second recv() pair is collapsed into one call,
/// which yields the same bytes with half the syscalls and copies.
///
/// NOTE(review): only one recv() is issued, so a request that arrives in
/// several TCP segments may be returned truncated — callers that need the
/// complete request must keep reading until headers/body are complete.
///
/// @return the received bytes; empty when the peer closed the connection.
/// @throws std::runtime_error when recv() reports an error.
std::vector<char> readHttpRequest() {
    const size_t maxRequestSize = 1024 * 1024;

    std::vector<char> httpRequest(maxRequestSize);
    const ssize_t bytesRead = recv(serverSocket, httpRequest.data(), httpRequest.size(), 0);
    if (bytesRead < 0) {
        throw std::runtime_error("Error while reading from socket");
    }

    // bytesRead == 0 means orderly shutdown: resize(0) yields the empty result.
    httpRequest.resize(static_cast<size_t>(bytesRead));
    return httpRequest;
}

std::string getMethod() {
Expand All @@ -111,7 +143,7 @@ class HttpRequest {

void writeNotFound() {
const char* data = "HTTP/1.1 404 Not Found\r\n";
socket->write(data, strlen(data));
writeSocket(serverSocket, data, strlen(data));
}

void writeJson(std::string json) {
Expand All @@ -120,7 +152,7 @@ class HttpRequest {
<< "Content-Type: application/json; charset=utf-8\r\n"
<< "Content-Length: " << json.length() << "\r\n\r\n" << json;
std::string data = buffer.str();
socket->write(data.c_str(), data.size());
writeSocket(serverSocket, data.c_str(), data.size());
}

void writeStreamStartChunk() {
Expand All @@ -130,19 +162,19 @@ class HttpRequest {
<< "Connection: close\r\n"
<< "Transfer-Encoding: chunked\r\n\r\n";
std::string data = buffer.str();
socket->write(data.c_str(), data.size());
writeSocket(serverSocket, data.c_str(), data.size());
}

void writeStreamChunk(const std::string data) {
std::ostringstream buffer;
buffer << std::hex << data.size() << "\r\n" << data << "\r\n";
std::string d = buffer.str();
socket->write(d.c_str(), d.size());
writeSocket(serverSocket, d.c_str(), d.size());
}

void writeStreamEndChunk() {
const char* endChunk = "0000\r\n\r\n";
socket->write(endChunk, strlen(endChunk));
writeSocket(serverSocket, endChunk, strlen(endChunk));
}
};

Expand Down Expand Up @@ -260,9 +292,6 @@ class ApiServer {
std::vector<ChatMessage> deltaPrompt = params.messages;
naiveCache.resolveDeltaPrompt(deltaPrompt, startPos);

printf("🔸");
fflush(stdout);

size_t nInputItems = deltaPrompt.size();
ChatItem inputItems[nInputItems];
for (size_t i = 0; i < nInputItems; i++) {
Expand All @@ -271,6 +300,8 @@ class ApiServer {
}

std::string inputPrompt = chatTemplate->generate(nInputItems, inputItems, true);
printf("🔹%s🔸", inputPrompt.c_str());

int promptLength = inputPrompt.size();
int nPromptTokens;
int promptTokens[promptLength + 3];
Expand Down Expand Up @@ -393,7 +424,7 @@ void handleModelsRequest(HttpRequest& request) {
}

void server(Inference* inference, SocketPool* socketPool, Tokenizer *tokenizer, Sampler *sampler, AppArgs* args, TransformerSpec* spec) {
SocketServer* server = new SocketServer(args->port);
int serverSocket = createServerSocket(args->port);

TokenizerChatStops stops(tokenizer);
ChatTemplate chatTemplate(args->chatTemplateType, tokenizer->chatTemplate, stops.stops[0]);
Expand All @@ -417,8 +448,8 @@ void server(Inference* inference, SocketPool* socketPool, Tokenizer *tokenizer,

while (true) {
try {
Socket client = server->accept();
HttpRequest request = HttpRequest::read(client);
int clientSocket = acceptSocket(serverSocket);
HttpRequest request = HttpRequest::read(clientSocket);
printf("🔷 %s %s\n", request.getMethod().c_str(), request.path.c_str());
Router::resolve(request, routes);
} catch (ReadSocketException& ex) {
Expand All @@ -428,7 +459,7 @@ void server(Inference* inference, SocketPool* socketPool, Tokenizer *tokenizer,
}
}

delete server;
closeServerSocket(serverSocket);
}

int main(int argc, char *argv[]) {
Expand Down
10 changes: 5 additions & 5 deletions src/apps/dllama/dllama.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -208,16 +208,16 @@ void worker(AppArgs* args) {
}

TransformerConfig config;
config.useDiscForKvCache = args->useDiscForKvCache;

SocketServer server(args->port);
Socket socket = server.accept();
SocketPool* socketPool = SocketPool::serve(args->port);
TransformerSpec spec;
Transformer transformer = Transformer::loadSlice(&spec, &config, &socket);
Transformer transformer = Transformer::loadSlice(&spec, &config, socketPool);
TransformerArch arch = TransformerArchFactory::create(&spec);

Worker worker = Worker(&arch, args->nThreads, &transformer, &socket);
Worker worker = Worker(&arch, args->nThreads, &transformer, socketPool);
worker.work();

delete socketPool;
}

int main(int argc, char *argv[]) {
Expand Down
23 changes: 8 additions & 15 deletions src/apps/socket-benchmark/socket-benchmark.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
#include <sys/socket.h>
#include <arpa/inet.h>
#include <fcntl.h>
#include <stdexcept>
#include <cassert>

using namespace std::chrono;

Expand All @@ -15,14 +17,6 @@ unsigned int nAttempts = 5000;
int port = 7721;
bool testTcp = true;

// Currently a no-op: the fcntl()-based implementation below is commented out,
// so calling this leaves the socket's flags untouched. The `socket` parameter
// is unused while the body stays disabled.
void setNonBlocking(int socket) {
//int flags = fcntl(socket, F_GETFL, 0);
//if (fcntl(socket, F_SETFL, flags |= O_NONBLOCK) < 0)
// throw std::runtime_error("Cannot set socket flags");
}

#define MAX_PACKAGE_SIZE 1280

char pktinfo[4096] = {0};

void readUdpSocket(int socket, char* buffer, unsigned int size, struct sockaddr_in* clientAddr, socklen_t* clientAddrLen) {
Expand Down Expand Up @@ -88,8 +82,9 @@ void server() {
if (testTcp) {
printf("TCP test\n");

SocketServer server(port);
Socket socket = server.accept();
SocketPool* pool = SocketPool::serve(port);
assert(pool->nSockets == 1);

for (long i = 0; i < packageSizesCount; i++) {
unsigned int currentPackageSize = packageSizes[i];

Expand All @@ -98,9 +93,9 @@ void server() {
long long totalTime = 0; // [us]
for (long a = 0; a < nAttempts; a++) {
auto t0 = high_resolution_clock::now();
socket.read(buffer, currentPackageSize);
pool->read(0, buffer, currentPackageSize);
auto t1 = high_resolution_clock::now();
socket.write(buffer, currentPackageSize);
pool->write(0, buffer, currentPackageSize);
auto t2 = high_resolution_clock::now();

totalReadTime += duration_cast<microseconds>(t1 - t0).count();
Expand All @@ -127,7 +122,6 @@ void server() {
serverAddr.sin_family = AF_INET;
serverAddr.sin_addr.s_addr = INADDR_ANY;
serverAddr.sin_port = htons(port);
setNonBlocking(serverSocket);

if (bind(serverSocket, (struct sockaddr *)&serverAddr, sizeof(serverAddr)) < 0)
throw std::runtime_error("Cannot bind socket");
Expand Down Expand Up @@ -176,7 +170,7 @@ void client(char* host) {
int* ports = new int[1];
ports[0] = port;

SocketPool* pool = SocketPool::connect(1, hosts, ports);
SocketPool* pool = SocketPool::connect(1, hosts, ports, 0);
pool->setTurbo(true);

for (long i = 0; i < packageSizesCount; i++) {
Expand Down Expand Up @@ -216,7 +210,6 @@ void client(char* host) {
serverAddr.sin_family = AF_INET;
serverAddr.sin_port = htons(port);
serverAddr.sin_addr.s_addr = inet_addr(host);
setNonBlocking(clientSocket);

for (long i = 0; i < packageSizesCount; i++) {
unsigned int currentPackageSize = packageSizes[i];
Expand Down
2 changes: 1 addition & 1 deletion src/commands.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -82,8 +82,8 @@ class MatmulCommand {
unsigned int n;
unsigned int d;
size_t cpuSize;
void* cpuWeights;
public:
void* cpuWeights;
MatmulCommand(const unsigned int n, const unsigned int d, const FloatType inputFloatType, const FloatType weightsFloatType);
~MatmulCommand();
size_t loadWeights(const void* source);
Expand Down
Loading

0 comments on commit 8b1cf89

Please sign in to comment.