Skip to content

Commit

Permalink
feat: added working action execution from query
Browse files Browse the repository at this point in the history
  • Loading branch information
wrola committed Feb 2, 2024
1 parent 937cdd5 commit 8a7fffc
Show file tree
Hide file tree
Showing 11 changed files with 70 additions and 23 deletions.
13 changes: 13 additions & 0 deletions src/memory/api/dto/memory-input.dto.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
import { IsArray, IsNotEmpty, IsString } from 'class-validator';

export class MemoryInputDto {
@IsString()
@IsNotEmpty()
content: string;

@IsString()
name: string;

@IsArray()
tags: Array<string>;
}
1 change: 1 addition & 0 deletions src/memory/api/dto/memory-output.dto.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
export class MemoryOutputDto {}
15 changes: 15 additions & 0 deletions src/memory/api/memory.controller.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
import { Body, Controller, Inject, Post } from '@nestjs/common';
import { IMemoryService, MEMORY_SERVICE } from '../memory.service';
import { MemoryOutputDto } from './dto/memory-output.dto';
import { MemoryInputDto } from './dto/memory-input.dto';

@Controller('memory')
export class MemoryController {
constructor(@Inject(MEMORY_SERVICE) private memoryService: IMemoryService) {}

@Post('/save')
async learn(@Body() body: MemoryInputDto): Promise<MemoryOutputDto> {
// TODO create tags and name if there not present
return await this.memoryService.add(body);
}
}
6 changes: 4 additions & 2 deletions src/memory/core/entities/memory.entity.ts
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,13 @@ export class Memory {

static create(memoryInput: MemoryInput): Memory {
const memory = new Memory();
memory.id = v4();
memory.id = memoryInput.id ? memoryInput.id : v4();
memory.content = memoryInput.content;
memory.name = memoryInput.name;
memory.tags = memoryInput.tags;
memory.active = true;
if (memoryInput.reflection) memory.reflection = memoryInput.reflection;
memory.reflection = memoryInput.reflection ? memoryInput.reflection : null;

return memory;
}
}
Expand All @@ -41,4 +42,5 @@ export type MemoryInput = {
name: string;
tags: Array<string>;
reflection?: string;
id?: string;
};
2 changes: 1 addition & 1 deletion src/memory/infrastructure/qdrant.client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ export interface IQdrantClient {
data: Record<string, unknown>,
): Promise<Array<QdrantDocs>>;
createCollection(): Promise<unknown>;
getCollection(): Promise<unknown>;
getCollection(): Promise<Record<string, unknown>>;
upsert(
collectionName: string,
data: Record<string, unknown>,
Expand Down
5 changes: 4 additions & 1 deletion src/memory/init-memory.service.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,24 +6,27 @@ export const INIT_MEMORY = Symbol('INITIAL_MEMORY');

export class InitialMemory {
constructor(@Inject(MEMORY_SERVICE) private service: IMemoryService) {}

async load() {
await Promise.all(
defaulMemories.map(async (memory) => await this.service.add(memory)),
);
Logger.log('Init memories added');
}
} // TODO make it one-timer
}
const defaulMemories: Array<MemoryInput> = [
{
name: 'George',
content:
"I'm George. The person who is very kind and generous. I'm also very smart and funny.",
tags: ['self-perception', 'personality', 'self', 'George'],
id: 'b6e33182-9efa-4ecd-9b39-c74acdc1b853',
},
{
name: 'Wojtek',
content:
'Wojtek is most common user, that is really happy to talk to you, use skill and knowledge to help him with issues that he raise to you and have fun working together on problems to tackle down',
tags: ['Wojtek', 'user', 'wojtek'],
id: 'e66f582c-058b-49bc-95ca-350f7aa951ee',
},
];
31 changes: 20 additions & 11 deletions src/memory/memory.service.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { Inject, Injectable, Logger } from '@nestjs/common';
import { ImATeapotException, Inject, Injectable, Logger } from '@nestjs/common';
import {
IMemoryRepository,
MEMORY_REPOSITORY,
Expand Down Expand Up @@ -38,7 +38,7 @@ export class MemoryService implements IMemoryService {
name: newMemory.name,
},
});
// how create a memory initial one?

const [embedding] = await this.embeddingProducer.embedDocuments([
documentedMemory.pageContent,
]);
Expand All @@ -65,7 +65,7 @@ export class MemoryService implements IMemoryService {
const queryEmbedding = await this.getEmebed(query);
const documentedMemories = await this.qdrantClient.search(MEMORIES, {
vector: queryEmbedding,
limit: 5,
limit: 3,
});
const rerankMemories = await this.rerank(query, documentedMemories);
return rerankMemories;
Expand Down Expand Up @@ -129,24 +129,32 @@ export class MemoryService implements IMemoryService {
temperature: 0,
maxConcurrency: 1,
});
// TODO check if plan working during action execution

if (!actions.length) {
new ImATeapotException('No actions possible');
}

const { content: uuid } = await model.invoke([
new SystemMessage(`As George, you need to pick a single action that is the most relevant to the user's query and context below. Your only job is to return UUID of this action and nothing else.
conversation context###${context
.map((doc) => doc[0].pageContent)
.join('\n\n')}###
conversation context###${
context.length
? context.map((doc) => doc.payload.pageContent).join('\n\n')
: ''
}###
available actions###${actions
.map(
(action) =>
`(${action[0].payload.uuid}) + ${action[0].pageContent}`,
)
.map((action) => `(${action.id}) + ${action.payload.pageContent}`)
.join('\n\n')}###
`),
new HumanMessage(query + '### Pick an action (UUID): '),
]);

return uuid as string;
}
async isMemoryReady(): Promise<boolean> {
const collectionsState = await this.qdrantClient.getCollection();

return collectionsState.status === 'ok';
}
}

export const MEMORY_SERVICE = Symbol('MEMORY_SERVICE');
Expand All @@ -156,4 +164,5 @@ export interface IMemoryService {
restoreMemory(queryEmbedding): Promise<Array<unknown>>;
add(memoryInput: MemoryInput): Promise<Memory>;
plan(query: string, actions: any[], context: unknown): Promise<string>;
isMemoryReady(): Promise<boolean>; // TODO add checker if qdrant is ready
}
1 change: 1 addition & 0 deletions src/skills/core/handlers/add-skill.handler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ export class AddSkillHandler implements SkillHandler {
this.payload?.webhook,
this.payload?.tags,
this.payload?.schema,
this.payload?.id,
);
await this.skillRepository.save(skill);

Expand Down
6 changes: 4 additions & 2 deletions src/skills/core/handlers/perfom-action.handler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -86,9 +86,11 @@ export class PerformAction implements SkillHandler {
}

const content = await chat.invoke(messages);
const result = this._parseFunctionCall(content);
const result = skill.schema
? this._parseFunctionCall(content)
: { args: content.content };

if (skill.webhook && result.args) {
if (skill.webhook) {
try {
const response = await fetch(skill.webhook, {
method: 'POST',
Expand Down
11 changes: 5 additions & 6 deletions src/skills/core/skill.entity.ts
Original file line number Diff line number Diff line change
Expand Up @@ -31,15 +31,14 @@ export class Skill {
@CreateDateColumn()
createdAt: Date;

static create(name, description, webhook?, tags?, schema?) {
static create(name, description, webhook?, tags?, schema?, id?) {
const skill = new Skill();
skill.id = v4();
skill.id = id ? id : v4();
skill.name = name;
skill.description = description;

if (webhook) skill.webhook = webhook;
if (tags) skill.tags = tags;
if (schema) skill.schema = schema;
skill.webhook = webhook ? webhook : null;
skill.tags = tags ? tags : null;
skill.schema = schema ? schema : null;

return skill;
}
Expand Down
2 changes: 2 additions & 0 deletions src/skills/skills.seed.ts
Original file line number Diff line number Diff line change
Expand Up @@ -49,11 +49,13 @@ export class SkillSeedService {
tags: ['memorize', 'memory', 'remember', 'skill'],
webhook: 'http://localhost:3000/learn',
schema: learnSchema,
id: '8df3d811-d459-4880-b651-7c4d4836b029',
});
memorySkill.setPayload({
name: Skills.MEMORY,
description: SkillsDescription.MEMORY,
tags: ['memorize', 'memory', 'remember'],
id: '62075b91-d895-4e61-8a4c-b8b89ff909fb',
});

return Promise.all(
Expand Down

0 comments on commit 8a7fffc

Please sign in to comment.