diff --git a/.eslintrc.json b/.eslintrc.json index f076ee20..98e7f8aa 100644 --- a/.eslintrc.json +++ b/.eslintrc.json @@ -5,13 +5,41 @@ "ecmaVersion": 6, "sourceType": "module" }, - "extends": [ "prettier" ], - "plugins": [ - "@typescript-eslint", - "prettier" - ], + "extends": ["prettier"], + "plugins": ["@typescript-eslint", "prettier"], "rules": { - "@typescript-eslint/naming-convention": "warn", + /* Most of the rules here are just the default ones, + * the main changes are: + * - required UPPER_CASE for enum members; + * - allowed UPPER_CASE in general. + */ + "@typescript-eslint/naming-convention": [ + "warn", + { "selector": "default", "format": ["camelCase"] }, + { + "selector": "variable", + "format": ["camelCase", "UPPER_CASE"], + "leadingUnderscore": "allow" + }, + { + "selector": "property", + "format": ["camelCase", "UPPER_CASE"], + "leadingUnderscore": "allow" + }, + { + "selector": "parameter", + "format": ["camelCase"], + "leadingUnderscore": "allow" + }, + { + "selector": "typeLike", + "format": ["PascalCase"] + }, + { + "selector": "enumMember", + "format": ["UPPER_CASE"] + } + ], "@typescript-eslint/semi": "warn", "curly": "warn", "eqeqeq": "warn", @@ -19,9 +47,5 @@ "semi": "off", "prettier/prettier": "warn" }, - "ignorePatterns": [ - "out", - "dist", - "**/*.d.ts" - ] + "ignorePatterns": ["out", "dist", "**/*.d.ts"] } diff --git a/.github/workflows/coqpilot.yml b/.github/workflows/coqpilot.yml index bb58c628..47cce251 100644 --- a/.github/workflows/coqpilot.yml +++ b/.github/workflows/coqpilot.yml @@ -2,46 +2,63 @@ name: Build and Test on: push: - branches: [ main ] + branches: + - main pull_request: - branches: [ main ] + branches: + - main + - 'v[0-9]+.[0-9]+.[0-9]+-dev' + workflow_dispatch: + env: - coqlsppath: "coq-lsp" + coqlsp-path: "coq-lsp" + coqlsp-version: "0.1.8+8.19" jobs: build: strategy: matrix: - os: [ubuntu-latest] + os: + - ubuntu-latest ocaml-compiler: - "4.14" + runs-on: ${{ matrix.os }} + outputs: vsixPath: ${{ steps.packageExtension.outputs.vsixPath }} + steps: - name: Checkout tree uses: actions/checkout@v4 + + # For some reason, the most significant thing for caching opam dependencies properly + # is `dune-cache: true` instead of this caching action. 
+      - name: Restore cached opam dependencies
+        id: cache-opam
+        uses: actions/cache@v3
+        with:
+          path: ~/.opam/
+          key: opam-${{ matrix.os }}-${{ matrix.ocaml-compiler }}-${{ env.coqlsp-version }}
+          restore-keys: opam-${{ matrix.os }}-${{ matrix.ocaml-compiler }}-

       - name: Set-up OCaml ${{ matrix.ocaml-compiler }}
         uses: ocaml/setup-ocaml@v2
         with:
           ocaml-compiler: ${{ matrix.ocaml-compiler }}
-
+          dune-cache: true
+
       - name: Install opam dependencies
         env:
           OPAMYES: true
         run: |
           opam install coq-lsp.0.1.8+8.19
           eval $(opam env)
-
-      - name: Opam eval
-        run: eval $(opam env)

       - name: Install Node.js
-        uses: actions/setup-node@v3
+        uses: actions/setup-node@v4
         with:
-          node-version: 16.x
-
+          node-version-file: ".nvmrc"
       - run: npm ci

       - name: Check coq-lsp version
@@ -64,15 +81,23 @@ jobs:
       - name: Test on Linux
         env:
-          COQ_LSP_PATH: ${{ env.coqlsppath }}
-        run: xvfb-run -a npm run test-ci
+          COQ_LSP_PATH: ${{ env.coqlsp-path }}
+        run: |
+          eval $(opam env)
+          xvfb-run -a npm run clean-test
         if: runner.os == 'Linux'

       - name: Test not on Linux
         env:
-          COQ_LSP_PATH: ${{ env.coqlsppath }}
-        run: npm run test-ci
+          COQ_LSP_PATH: ${{ env.coqlsp-path }}
+        run: |
+          eval $(opam env)
+          npm run clean-test
         if: runner.os != 'Linux'
+
+      - name: Setup tmate session
+        if: ${{ failure() }}
+        uses: mxschmitt/action-tmate@v3

       - name: Package Extension
         id: packageExtension
@@ -93,9 +118,9 @@ jobs:
     if: github.ref == 'refs/heads/main'
     steps:
       - uses: actions/checkout@v4
-      - uses: actions/setup-node@v3
+      - uses: actions/setup-node@v4
         with:
-          node-version: 16.x
+          node-version-file: ".nvmrc"
       - name: Install Dependencies
         run: npm ci
       - name: Download Build Artifact
diff --git a/.nvmrc b/.nvmrc
new file mode 100644
index 00000000..f203ab89
--- /dev/null
+++ b/.nvmrc
@@ -0,0 +1 @@
+20.13.1
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 86379362..81fde716 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,95 +1,148 @@
-# Change Log
+# Changelog
+
+## 2.2.0
+
+### Public changes
+
+- Estimate when an LLM service will become available again after a failure by logging proof-generation requests. This information is shown to the user.
+- Set up interaction between `LLMService`-s and the UI to report errors that happen during proof generation.
+- Improve LLM services' parameters: their naming, transparency, and description.
+    - Introduce `modelId` to distinguish a model identifier from the name of an OpenAI / Grazie model.
+    - Rename `newMessageMaxTokens` to `maxTokensToGenerate` for greater clarity.
+    - Update the settings description and make it more user-friendly.
+- Significantly improve settings validation.
+    - Use the parameters resolver framework to resolve parameters (with overrides and defaults) and validate them.
+    - Support messages about parameter resolution: both failures and unexpected overrides.
+    - Clarify existing error messages.
+    - Add some general checks: input models have unique `modelId`-s, and there are models to generate proofs with.
+- Improve interaction with OpenAI.
+    - Notify the user of configuration errors (invalid model name, incorrect API key, maximum context length exceeded) and connection errors.
+    - Support resolution of `tokensLimit` and `maxTokensToGenerate` with recommended defaults for known models.
+- Fix minor bugs and make minor improvements detected by thorough testing.
+
+### Internal changes
+
+- Rework and document the `LLMService` architecture: `LLMServiceInternal`, better facades, powerful typing.
+- Introduce a hierarchy of `LLMService` errors. Support proper logging and error handling inside `LLMService`-s.
+- Rework settings validation.
+    - Refactor `SettingsValidationError`, move all input parameter validation to one place and make it coherent.
+    - Design and implement a powerful and well-documented parameters resolver framework.
+
+### Testing infrastructure changes
+
+- Test the LLM Services module thoroughly.
+- Improve test infrastructure in general by introducing and structuring utils.
+- Fix the issue with building test resources on CI.
+- Set up CI debugging and enable launching CI manually.
+- Double the speed of CI by setting up caches.
+
+## 2.1.0
+
-### 2.1.0
 Major:
 - Create a (still in development and improvement) benchmarking system. A guide on how to use it is in the README.
-- Conduct an experiment on the performance of different LLMs, using the developed infrastructure. Benchmarking report is located in the [docs folder](etc/docs/benchmarking_report01.md).
-- Correctly handle and display settings which occur when the user settings are not correctly set.
+- Conduct an experiment on the performance of different LLMs, using the developed infrastructure. The benchmarking report is located in the [docs folder](etc/docs/benchmarking_report01.md).

 Minor:
-- Set order of contributed settings.
+- Set the order of contributed settings.
 - Add a comprehensive user settings guide to the README.
-- Fix issue with Grazie service not being able to correctly accept coq ligatures.
-- Fix issue that occured when generated proof contained `Proof using {...}.` construct.
+- Fix the issue with the Grazie service not being able to correctly accept Coq ligatures.
+- Fix the issue that occurred when the generated proof contained the `Proof using {...}.` construct.

-### 2.0.0
-- Added multiple strategies for ranking theorems from the working file. As LLM context window is limited, we sometimes should somehow choose a subset of theorems we want to provide as references to the LLM. Thus, we have made a few strategies for ranking theorems. Now there are only 2 of them, but there are more to come. Now we have a strategy that randomly picks theorems, and also the one that ranks them depending on the distance from the hole.
+## 2.0.0
+
+- Added multiple strategies for ranking theorems from the working file. As the LLM context window is limited, we sometimes need to choose a subset of theorems to provide as references to the LLM. Thus, we have implemented a few strategies for ranking theorems. Currently there are only two of them, with more to come: one that picks theorems randomly, and one that ranks them by their distance from the hole.
 - Now different holes are solved in parallel. This is a huge improvement in terms of performance.
-- Implemented multi-round fixing procedure for the proofs from the LLM. It can now be configured in the settings. One can set the amount of attempts for the consequtive proof fixing with compiler feedback.
+- Implemented a multi-round fixing procedure for the proofs from the LLM. It can now be configured in the settings. One can set the number of attempts for the consecutive proof fixing with compiler feedback.
 - Added an opportunity to use LM Studio as a language model provider.
-- More accurate token count. Tiktoken is now used for open-ai models.
-- Different logging levels now supported.
+- More accurate token count. Tiktoken is now used for OpenAI models.
+- Different logging levels are now supported.
 - The LLM iterator now supports adding a sequence of models for each service. This brings freedom to the user to experiment with different model parameters.
 - Now temperature, prompt, and other parameters are configurable in the settings.

-### 1.9.0
-- Huge refactoring done. Project re organized.
+## 1.9.0
+
+- Performed a huge refactoring; the project was reorganized.
+
+## 1.5.3

-### 1.5.3
 - Fix Grazie service request headers and endpoint.

-### 1.5.2
+## 1.5.2
+
 - Fix issue with double document symbol provider registration (@Alizter, [#9](https://github.com/JetBrains-Research/coqpilot/issues/9))

-### 1.5.1
-- Add support of the Grazie platform as an llm provider.
+## 1.5.1
+
+- Add support for the Grazie platform as an LLM provider.

-### 1.5.0
-- Now when the hole can be solved by a single tactic solver, using predefined tactics, gpt will not be called, LLMs are now fetched consequently.
+## 1.5.0
+
+- Now, when the hole can be solved by a single tactic solver using predefined tactics, OpenAI and Grazie will not be called; the LLMs are now queried sequentially.
 - Parallel hole completion is unfortunately postponed due to the implementation complexity. Yet, hopefully, will still be implemented in milestone `2.0.0`.

-### 1.4.6
-- Fix issue with plugin breaking after parsing a file containing theorem without `Proof.` keyword.
+## 1.4.6
+
+- Fix the issue with the plugin breaking after parsing a file containing a theorem without the `Proof.` keyword.
+
+## 1.4.5

-### 1.4.5
 - Fix formatting issues when inserting the proof into the editor.

-### 1.4.4
-- Do not require a theorem to be `Admitted.` for coqpilot to prove holes in it.
-- Correctly parse theorems that are declared with `Definition` keyword.
+## 1.4.4
+
+- Do not require a theorem to be `Admitted.` for CoqPilot to prove holes in it.
+- Correctly parse theorems that are declared with the `Definition` keyword.

-### 1.4.3
-- Tiny patch with shuffling of the hole array.
+## 1.4.3

-### 1.4.2
-- Now no need to add dot in the end of the tactic, when configuring a single tactic solver.
+- Tiny patch with the shuffling of the hole array.
+
+## 1.4.2
+
+- Now there is no need to add a dot at the end of the tactic when configuring a single tactic solver.
 - Automatic reload settings on change in the settings file. Not all settings are reloaded automatically,
-but the most ones are. The ones that are not automatically reloaded: `useGpt`, `coqLspPath`, `parseFileOnInit`.
-- Added command that solves admits in selected region. Also added that command to the context menu (right click in the editor).
+but most of them are. The ones that are not automatically reloaded: `useGpt`, `coqLspPath`, `parseFileOnInit`.
+- Added a command that solves admits in a selected region. Also added that command to the context menu (right-click in the editor).
 - Fix toggle extension.

-### 1.4.1
+## 1.4.1
+
 - Add a possibility to configure a single tactic solver.

-### 1.4.0
-- Add command to solve all admitted holes in the file.
+## 1.4.0
+
+- Add a command to solve all admitted holes in the file.
 - Fixing bugs with coq-lsp behavior.

-### 1.3.1
+## 1.3.1
+
 - Test coverage increased.
-- Refactoring client and ProofView.
+- Refactoring client and ProofView.
 - Set up CI.

-### 1.3.0
+## 1.3.0
+
 - Fix bug while parsing regarding the updated Fleche doc structure in coq-lsp 0.1.7.
 - When GPT generated a response containing three consecutive backtick symbols it tended to break the typecheking of the proof. Now solved.
 - Speed up the file parsing process.

-### 1.2.1
+## 1.2.1
+
 - Add clearing of aux file after parsing.

-### 1.2.0
-- Fix error with llm silently failing. Now everything that comes from llm that is not handled inside plugin is presented to user as a message (i.e. incorrect apiKey exception).
-- Fix toggle button.
-- Fix diagnostics being shown to non coq-lsp plugin coq users.
-- Add output stream for the logs in vscode output panel.
+## 1.2.0
+
+- Fix the error with the LLM silently failing. Now everything that comes from the LLM that is not handled inside the plugin is presented to the user as a message (e.g. an incorrect apiKey exception).
+- Fix the toggle button.
+- Fix diagnostics being shown to non-`coq-lsp` plugin Coq users.
+- Add output stream for the logs in the VS Code output panel.

-### 1.1.0
+## 1.1.0

-Now proof generation could be run in any position inside the theorem. There is no need to retake file snapshot after each significant file change.
-More communication with `coq-lsp` is added. Saperate package `coqlsp-client` no longer used.
+Now proof generation can be run at any position inside the theorem. There is no need to retake a file snapshot after each significant file change.
+More communication with `coq-lsp` is added. The separate package `coqlsp-client` is no longer used.

-### 0.0.1
+## 0.0.1

-Initial release of coqpilot.
\ No newline at end of file
+The initial release of CoqPilot.
diff --git a/README.md b/README.md
index c603d1fa..23137036 100644
--- a/README.md
+++ b/README.md
@@ -74,26 +74,29 @@ With coq-lsp, extension should have everything it needs to run.

 ### Building locally

-To build the extension locally, you will need to have `npm` installed. Then you can clone the repository and run the following commands:
+First, clone the CoqPilot repository and navigate into its directory.
 ```bash
-npm install
-npm run compile
+git clone https://github.com/JetBrains-Research/coqpilot.git
+cd coqpilot
 ```

-To run the extension, you can press `F5` in the vscode. It will open a new window with the extension running.
-
-To run tests you should go to the `src/test/resources/coqProj` directory and run make:
+To build the extension locally, you'll need Node.js installed. The recommended way to manage Node.js versions is by using `nvm`. From the CoqPilot root directory, execute:
 ```bash
-make
+nvm use
 ```

-Some tests depend on the small coq project, that is expected to be built. After that run:
+If you prefer not to use `nvm`, ensure you install the Node.js version specified in the [`.nvmrc`](.nvmrc) file by any other method you prefer.
+
+Once Node.js is installed, the remaining setup will be handled by the `npm` package manager. Run the following commands:
 ```bash
-npm run test
+npm install
+npm run compile
 ```

-Otherwise, if you do not want to build that small project, you can run:
+To run the extension from VS Code, press `F5` or click on `Run extension` in the `Run and Debug` section. It will open a new window with the extension running.
+
+To run all tests properly (i.e. with rebuilding the resources and the code first), execute the following task:
 ```bash
-npm run test-ci
+npm run clean-test
 ```

 To run specific tests, you can use `npm run test -- -g="grep pattern"`.
@@ -122,9 +125,9 @@ This extension contributes the following settings:

 * `coqpilot.contextTheoremsRankerType` : The type of theorems ranker that will be used to select theorems for proof generation (when context is smaller than taking all of them). Either randomly, by Jacard index (similarity metric) or by distance from the theorem, with the currently observed admit.

 * `coqpilot.loggingVerbosity` : Verbosity of the logs. Could be `info`, `debug`.
-* `coqpilot.openAiModelsParameters`, `coqpilot.predefinedProofsModelsParameters`, `coqpilot.grazieModelsParameters` and `coqpilot.lmStudioModelsParameters`:
+* `coqpilot.predefinedProofsModelsParameters`, `coqpilot.openAiModelsParameters`, `coqpilot.grazieModelsParameters`, and `coqpilot.lmStudioModelsParameters`:

-Each of these settings are modified in `settings.json` and contain an array of models from this service. Each model will be used for generation independantly. Multiple models for a single service could be defined. For example, you can define parameters for two open-ai gpt models. One would be using `gpt-3.5` and the other one `gpt-4`. CoqPilot will first try to generate proofs using the first model, and if it fails, it will try the second one. This way coqpilot iterates over all services (currently 4 of them) and for each service it iterates over all models.
+Each of these settings is modified in `settings.json` and contains an array of models for the corresponding service. Each model will be used for generation independently. Multiple models for a single service can be defined. For example, you can define parameters for two OpenAI GPT models: one using `gpt-3.5` and the other `gpt-4`. CoqPilot will first try to generate proofs using the first model, and if it doesn't succeed, it will try the second one. This way, CoqPilot iterates over all services (currently 4 of them), and for each service it iterates over all models.

 ## Guide to Model Configuration
@@ -151,7 +154,7 @@ The simplest service to configure is `predefinedProofs`:
 {
     "coqpilot.predefinedProofsModelsParameters": [
         {
-            "modelName": "Any name",
+            "modelId": "predefined proofs",
             "tactics": [
                 "reflexivity.",
                 "simpl. reflexivity.",
@@ -161,19 +164,20 @@ The simplest service to configure is `predefinedProofs`:
             ]
         }
     ]
 }
 ```
-Model name here is only used for convinience inside code, so may be any string.
+The `modelId` property may be any string you like, but it should be unique for each model. This way, CoqPilot will be able to correctly tell you which model might have configuration issues.

 The most commonly used service is `open-ai` (`grazie` and `lmStudio` are configured very similarly).
 ```json
 {
     "coqpilot.openAiModelsParameters": [
         {
+            "modelId": "openai-gpt-3.5",
+            "modelName": "gpt-3.5-turbo-0301",
             "temperature": 1,
             "apiKey": "***your-api-key***",
-            "modelName": "gpt-3.5-turbo-0301",
             "choices": 10,
             "systemPrompt": "Generate proof...",
-            "newMessageMaxTokens": 2000,
+            "maxTokensToGenerate": 2000,
             "tokensLimit": 4096,
             "multiroundProfile": {
                 "maxRoundsNumber": 1,
@@ -203,7 +207,7 @@ To run benchmarks on some project, apart from installing and building CoqPilot m
   git submodule init
   git submodule update
 ```
-After that, you need to build the projects. And be careful, the actively maintained way to build this projects is `nix`. Moreover, when adding your own projects, make sure that they are built using `coq-8.19.0`.
+After that, you need to build the projects. Be careful: the actively maintained way to build these projects is `nix`. Moreover, when adding your own projects, make sure that they are built using `coq-8.19.0`.

 First things first, the process of running the benchmark is not perfectly automated yet. We are working on it. For now, one project (one unit containing nix environment) shall be ran at a time. Let's say you are going to run the benchmark on the `imm` project.
You will have to do the following: @@ -228,7 +232,7 @@ First things first, the process of running the benchmark is not perfectly automa nix-build nix-shell ``` -4. Make sure the `_CoqProject` was successfully generated in the root of your mroject. Return to the project root not exiting the nix-shell. Run the benchmark: +4. Make sure the `_CoqProject` was successfully generated in the root of your project. Return to the project root not exiting the nix-shell. Run the benchmark: ```bash cd ../../ npm run benchmark diff --git a/package-lock.json b/package-lock.json index f90ba9cf..bf66ec23 100644 --- a/package-lock.json +++ b/package-lock.json @@ -1,12 +1,12 @@ { "name": "coqpilot", - "version": "2.1.0", + "version": "2.2.0", "lockfileVersion": 2, "requires": true, "packages": { "": { "name": "coqpilot", - "version": "2.1.0", + "version": "2.2.0", "dependencies": { "@codemirror/autocomplete": "^6.9.1", "ajv": "^8.12.0", @@ -35,6 +35,7 @@ "@types/glob": "^8.1.0", "@types/mocha": "^10.0.1", "@types/node": "20.2.5", + "@types/tmp": "^0.2.6", "@types/vscode": "^1.82.0", "@types/yargs": "^17.0.24", "@typescript-eslint/eslint-plugin": "^5.62.0", @@ -47,6 +48,7 @@ "glob": "^8.1.0", "mocha": "^10.2.0", "prettier": "^3.2.5", + "tmp": "^0.2.3", "typescript": "^5.3.3" }, "engines": { @@ -836,6 +838,12 @@ "integrity": "sha512-cJRQXpObxfNKkFAZbJl2yjWtJCqELQIdShsogr1d2MilP8dKD9TE/nEKHkJgUNHdGKCQaf9HbIynuV2csLGVLg==", "dev": true }, + "node_modules/@types/tmp": { + "version": "0.2.6", + "resolved": "https://registry.npmjs.org/@types/tmp/-/tmp-0.2.6.tgz", + "integrity": "sha512-chhaNf2oKHlRkDGt+tiKE2Z5aJ6qalm7Z9rlLdBwmOiAAf09YQvvoLXjWK4HWPF1xU/fqvMgfNfpVoBscA/tKA==", + "dev": true + }, "node_modules/@types/vscode": { "version": "1.82.0", "resolved": "https://registry.npmjs.org/@types/vscode/-/vscode-1.82.0.tgz", @@ -6226,6 +6234,15 @@ "resolved": "https://registry.npmjs.org/tiktoken/-/tiktoken-1.0.13.tgz", "integrity": "sha512-JaL9ZnvTbGFMDIBeGdVkLt4qWTeCPw+n7Ock+wceAGRenuHA6nOOvMJFliNDyXsjg2osGKJWsXtO2xc74VxyDw==" }, + "node_modules/tmp": { + "version": "0.2.3", + "resolved": "https://registry.npmjs.org/tmp/-/tmp-0.2.3.tgz", + "integrity": "sha512-nZD7m9iCPC5g0pYmcaxogYKggSfLsdxl8of3Q/oIbqCqLLIO9IAF0GWjX1z9NZRHPiXv8Wex4yDCaZsgEw0Y8w==", + "dev": true, + "engines": { + "node": ">=14.14" + } + }, "node_modules/to-fast-properties": { "version": "2.0.0", "resolved": "https://registry.npmjs.org/to-fast-properties/-/to-fast-properties-2.0.0.tgz", @@ -7169,6 +7186,12 @@ "integrity": "sha512-cJRQXpObxfNKkFAZbJl2yjWtJCqELQIdShsogr1d2MilP8dKD9TE/nEKHkJgUNHdGKCQaf9HbIynuV2csLGVLg==", "dev": true }, + "@types/tmp": { + "version": "0.2.6", + "resolved": "https://registry.npmjs.org/@types/tmp/-/tmp-0.2.6.tgz", + "integrity": "sha512-chhaNf2oKHlRkDGt+tiKE2Z5aJ6qalm7Z9rlLdBwmOiAAf09YQvvoLXjWK4HWPF1xU/fqvMgfNfpVoBscA/tKA==", + "dev": true + }, "@types/vscode": { "version": "1.82.0", "resolved": "https://registry.npmjs.org/@types/vscode/-/vscode-1.82.0.tgz", @@ -10866,6 +10889,12 @@ "resolved": "https://registry.npmjs.org/tiktoken/-/tiktoken-1.0.13.tgz", "integrity": "sha512-JaL9ZnvTbGFMDIBeGdVkLt4qWTeCPw+n7Ock+wceAGRenuHA6nOOvMJFliNDyXsjg2osGKJWsXtO2xc74VxyDw==" }, + "tmp": { + "version": "0.2.3", + "resolved": "https://registry.npmjs.org/tmp/-/tmp-0.2.3.tgz", + "integrity": "sha512-nZD7m9iCPC5g0pYmcaxogYKggSfLsdxl8of3Q/oIbqCqLLIO9IAF0GWjX1z9NZRHPiXv8Wex4yDCaZsgEw0Y8w==", + "dev": true + }, "to-fast-properties": { "version": "2.0.0", "resolved": "https://registry.npmjs.org/to-fast-properties/-/to-fast-properties-2.0.0.tgz", 
diff --git a/package.json b/package.json index 68146e0c..a7da5517 100644 --- a/package.json +++ b/package.json @@ -8,7 +8,7 @@ "url": "https://github.com/K-dizzled/coqpilot" }, "publisher": "JetBrains-Research", - "version": "2.1.0", + "version": "2.2.0", "engines": { "vscode": "^1.82.0" }, @@ -37,15 +37,15 @@ "commands": [ { "command": "coqpilot.perform_completion_under_cursor", - "title": "Coqpilot: Try to generate proof for the goal under the cursor." + "title": "Coqpilot: Try to generate proof for the goal under the cursor" }, { "command": "coqpilot.perform_completion_for_all_admits", - "title": "Coqpilot: Try to prove all holes (admitted goals) in the current file." + "title": "Coqpilot: Try to prove all holes (admitted goals) in the current file" }, { "command": "coqpilot.perform_completion_in_selection", - "title": "Coqpilot: Try to prove holes (admitted goals) in selection." + "title": "Coqpilot: Try to prove holes (admitted goals) in the selection" } ], "menus": { @@ -62,44 +62,82 @@ "type": "object", "title": "CoqPilot", "properties": { + "coqpilot.predefinedProofsModelsParameters": { + "type": "array", + "items": { + "type": "object", + "properties": { + "modelId": { + "type": "string", + "markdownDescription": "Unique identifier of this model to distinguish it from others. Could be any string.", + "default": "predefined-auto" + }, + "tactics": { + "type": "array", + "items": { + "type": "string" + }, + "description": "List of tactics to try to complete a hole with. Commands in the list must be valid Coq commands available in your environment.", + "default": [ + "auto." + ] + } + } + }, + "default": [ + { + "modelId": "predefined-auto", + "tactics": [ + "auto." + ] + } + ], + "markdownDescription": "List of configurations with sets of predefined proofs. CoqPilot will check these proofs when searching for completion.", + "order": 0 + }, "coqpilot.openAiModelsParameters": { "type": "array", "items": { "type": "object", "properties": { + "modelId": { + "type": "string", + "markdownDescription": "Unique identifier of this model to distinguish it from others. Could be any string.", + "default": "openai-gpt-3.5" + }, "modelName": { "type": "string", - "markdownDescription": "The model to use from the open-ai platform: \n * gpt-4 \n * gpt-4-0314 \n * gpt-4-0613 \n * gpt-4-32k \n * gpt-4-32k-0314 \n * gpt-4-32k-0613 \n * gpt-3.5-turbo \n * gpt-3.5-turbo-16k \n * gpt-3.5-turbo-0301 \n * gpt-3.5-turbo-0613 \n * gpt-3.5-turbo-16k-0613", + "markdownDescription": "Model to use from the OpenAI platform. List of models known to Coqpilot: \n * gpt-4o \n * gpt-4o-2024-05-13 \n * gpt-4-turbo \n * gpt-4-turbo-2024-04-09 \n * gpt-4-turbo-preview \n * gpt-4-0125-preview \n * gpt-4-1106-preview \n * gpt-4-vision-preview \n * gpt-4-1106-vision-preview \n * gpt-4 \n * gpt-4-0314 \n * gpt-4-0613 \n * gpt-4-32k \n * gpt-4-32k-0314 \n * gpt-4-32k-0613 \n * gpt-3.5-turbo-0125 \n * gpt-3.5-turbo \n * gpt-3.5-turbo-1106 \n * gpt-3.5-turbo-instruct \n * gpt-3.5-turbo-16k \n * gpt-3.5-turbo-16k-0613 \n * gpt-3.5-turbo-0613 \n * gpt-3.5-turbo-0301", "default": "gpt-3.5-turbo-0301" }, "temperature": { "type": "number", - "description": "The temperature of the open-ai model.", + "description": "Temperature of the OpenAI model.", "default": 1 }, "apiKey": { "type": "string", - "description": "An `open-ai` api key. Is used to communicate with the open-ai api. You can get one [here](https://platform.openai.com/account/api-keys).", + "description": "Api key to communicate with the OpenAi api. 
You can get one [here](https://platform.openai.com/account/api-keys).", "default": "None" }, "choices": { "type": "number", - "description": "How many proof attempts should be generated for one theorem.", + "description": "Number of attempts to generate proof for one hole with this model. All attempts are made as a single request, so this parameter should not have a significant impact on performance. However, more choices mean more tokens spent on generation.", "default": 15 }, "systemPrompt": { "type": "string", - "description": "A prompt for the open-ai model.", + "description": "Prompt for the OpenAI model to begin a chat with. It is sent as a system message, which means it has more impact than other messages.", "default": "Generate proof of the theorem from user input in Coq. You should only generate proofs in Coq. Never add special comments to the proof. Your answer should be a valid Coq proof. It should start with 'Proof.' and end with 'Qed.'." }, - "newMessageMaxTokens": { + "maxTokensToGenerate": { "type": "number", - "description": "How many tokens is allowed to be generated by the model.", - "default": 2000 + "description": "Number of tokens that the model is allowed to generate as a response message (i.e. message with proof). For known models, Coqpilot provides a recommended default value, but it can be customized for more advanced proof generation. The default value is the maximum allowed value for the model if it takes no more than half of `tokensLimit`, otherwise the minimum of half of `tokensLimit` and 4096.", + "default": 2048 }, "tokensLimit": { "type": "number", - "description": "The total length of input tokens and generated tokens. Is determined by the model. For open-ai models could be found [here](https://platform.openai.com/docs/models/).", + "description": "Total length of input and generated tokens, it is determined by the model. For known models, Coqpilot provides a recommended default value (the maximum model context length), but it can be customized for more advanced proof generation.", "default": 4096 }, "multiroundProfile": { @@ -107,36 +145,37 @@ "properties": { "maxRoundsNumber": { "type": "number", - "description": "The maximum number of rounds for the multiround completion.", + "description": "Maximum number of rounds to generate and further fix the proof. Default value is 1, which means each proof will be only generated, but not fixed.", "default": 1 }, "proofFixChoices": { "type": "number", - "description": "How many proof fixes should be generated for one proof.", + "description": "Number of attempts to generate a proof fix for each proof in one round. Warning: increasing `proofFixChoices` can lead to exponential growth in generation requests if `maxRoundsNumber` is relatively large.", "default": 1 }, "proofFixPrompt": { "type": "string", - "description": "A prompt for the fix request message.", - "default": "Unfortunately, the last proof is not correct. Here is the compiler's feedback: '${diagnostic}'. Please, fix the proof." + "description": "Prompt for the proof-fix request that will be sent as a user chat message in response to an incorrect proof. It may include the `${diagnostic}` substring, which will be replaced by the actual compiler diagnostic.", + "default": "Unfortunately, the last proof is not correct. Here is the compiler's feedback: `${diagnostic}`. Please, fix the proof." } }, "default": { "maxRoundsNumber": 1, "proofFixChoices": 1, - "proofFixPrompt": "Unfortunately, the last proof is not correct. 
Here is the compiler's feedback: '${diagnostic}'. Please, fix the proof." + "proofFixPrompt": "Unfortunately, the last proof is not correct. Here is the compiler's feedback: `${diagnostic}`. Please, fix the proof." } } } }, "default": [ { + "modelId": "openai-gpt-3.5", + "modelName": "gpt-3.5-turbo-0301", "temperature": 1, "apiKey": "None", - "modelName": "gpt-3.5-turbo-0301", "choices": 15, "systemPrompt": "Generate proof of the theorem from user input in Coq. You should only generate proofs in Coq. Never add special comments to the proof. Your answer should be a valid Coq proof. It should start with 'Proof.' and end with 'Qed.'.", - "newMessageMaxTokens": 2000, + "maxTokensToGenerate": 2048, "tokensLimit": 4096, "multiroundProfile": { "maxRoundsNumber": 1, @@ -145,42 +184,47 @@ } } ], - "markdownDescription": "A list of parameters for open-ai models. Each object represents a single model's configuration. Each model will be fetched for completion independently in the order they are listed.", - "order": 0 + "markdownDescription": "List of configurations for OpenAI models. Each configuration will be fetched for completions independently in the order they are listed.", + "order": 1 }, "coqpilot.grazieModelsParameters": { "type": "array", "items": { "type": "object", "properties": { + "modelId": { + "type": "string", + "markdownDescription": "Unique identifier of this model to distinguish it from others. Could be any string.", + "default": "openai-gpt-3.5" + }, "modelName": { "type": "string", - "markdownDescription": "The model to use from the grazie platform: \n * openai-gpt-4 \n * openai-chat-gpt \n * grazie-chat-llama-v2-7b \n * grazie-chat-llama-v2-13b \n * grazie-chat-zephyr-7b \n * qwen-turbo \n * qwen-plus", + "markdownDescription": "Model to use from the Grazie platform: \n * openai-gpt-4 \n * openai-chat-gpt \n * grazie-chat-llama-v2-7b \n * grazie-chat-llama-v2-13b \n * grazie-chat-zephyr-7b \n * qwen-turbo \n * qwen-plus", "default": "openai-gpt-4" }, "apiKey": { "type": "string", - "description": "`Grazie` api key. Now available for JetBrains employees only.", + "description": "Api key to communicate with the Grazie api. Now available for JetBrains employees only.", "default": "None" }, "choices": { "type": "number", - "description": "How many proof attempts should be generated for one theorem.", + "description": "Number of attempts to generate proof for one hole with this model.", "default": 15 }, "systemPrompt": { "type": "string", - "description": "A prompt for the grazie model.", + "description": "Prompt for the Grazie model to begin chat with. It is sent as a system message, which means it has more impact than other messages.", "default": "Generate proof of the theorem from user input in Coq. You should only generate proofs in Coq. Never add special comments to the proof. Your answer should be a valid Coq proof. It should start with 'Proof.' and end with 'Qed.'." }, - "newMessageMaxTokens": { + "maxTokensToGenerate": { "type": "number", - "description": "How many tokens is allowed to be generated by the model.", - "default": 2000 + "description": "Number of tokens that the model is allowed to generate as a response message (i.e. message with proof).", + "default": 1024 }, "tokensLimit": { "type": "number", - "description": "The total length of input tokens and generated tokens. Is determined by the model. 
For open-ai models could be found [here](https://platform.openai.com/docs/models/).", + "description": "Total length of input and generated tokens, it is determined by the model. For OpenAI models, tokens limits could be found [here](https://platform.openai.com/docs/models/).", "default": 4096 }, "multiroundProfile": { @@ -188,103 +232,70 @@ "properties": { "maxRoundsNumber": { "type": "number", - "description": "The maximum number of rounds for the multiround completion.", + "description": "Maximum number of rounds to generate and further fix the proof. Default value is 1, which means each proof will be only generated, but not fixed.", "default": 1 }, "proofFixChoices": { "type": "number", - "description": "How many proof fixes should be generated for one proof.", + "description": "Number of attempts to generate a proof fix for each proof in one round. Warning: increasing `proofFixChoices` can lead to exponential growth in generation requests if `maxRoundsNumber` is relatively large.", "default": 1 }, "proofFixPrompt": { "type": "string", - "description": "A prompt for the fix request message.", - "default": "Unfortunately, the last proof is not correct. Here is the compiler's feedback: '${diagnostic}'. Please, fix the proof." + "description": "Prompt for the proof-fix request that will be sent as a user chat message in response to an incorrect proof. It may include the `${diagnostic}` substring, which will be replaced by the actual compiler diagnostic.", + "default": "Unfortunately, the last proof is not correct. Here is the compiler's feedback: `${diagnostic}`. Please, fix the proof." } }, "default": { "maxRoundsNumber": 1, "proofFixChoices": 1, - "proofFixPrompt": "Unfortunately, the last proof is not correct. Here is the compiler's feedback: '${diagnostic}'. Please, fix the proof." + "proofFixPrompt": "Unfortunately, the last proof is not correct. Here is the compiler's feedback: `${diagnostic}`. Please, fix the proof." } } } }, "default": [], - "markdownDescription": "Now only available in beta for JetBrains employees. A list of parameters for grazie models. Each object represents a single model's configuration. Each model will be fetched for completion independently in the order they are listed.", + "markdownDescription": "Now available in beta for JetBrains employees only. List of configurations for Grazie models. Each configuration will be fetched for completions independently in the order they are listed.", "order": 2 }, - "coqpilot.predefinedProofsModelsParameters": { - "type": "array", - "items": { - "type": "object", - "properties": { - "modelName": { - "type": "string", - "markdownDescription": "Doesn't make any change. Just an identifier for the model.", - "default": "das-auto-proofs" - }, - "tactics": { - "type": "array", - "items": { - "type": "string" - }, - "description": "A list of tactics that would also be used to try generating proofs. Commands in the list must be valid coq commands available in your environment.", - "default": [ - "auto." - ] - } - } - }, - "default": [ - { - "modelName": "das-auto-proofs", - "tactics": [ - "auto." - ] - } - ], - "markdownDescription": "A list where each object represents a single model configuration. 
Here each model is configured with a set of predefined proofs, which coqpilot should try when searching for completion.", - "order": 1 - }, "coqpilot.lmStudioModelsParameters": { "type": "array", "items": { "type": "object", "properties": { - "modelName": { + "modelId": { "type": "string", - "markdownDescription": "Doesn't make any change. Just an identifier for the model.", + "markdownDescription": "Unique identifier of this model to distinguish it from others. Could be any string.", "default": "lm-studio" }, "temperature": { "type": "number", - "description": "The temperature of the model.", + "description": "Temperature of the LM Studio model.", "default": 1 }, "port": { "type": "number", - "description": "A port on which you have launched the LM studio.", + "description": "Port on which LM Studio is launched.", "default": 1234 }, "choices": { "type": "number", - "description": "How many proof attempts should be generated for one theorem.", + "description": "Number of attempts to generate proof for one hole with this model.", "default": 15 }, "systemPrompt": { "type": "string", - "description": "A prompt for the lm-studio model.", + "description": "Prompt for the LM Studio model to begin chat with. It is sent as a system message, which means it has more impact than other messages.", "default": "Generate proof of the theorem from user input in Coq. You should only generate proofs in Coq. Never add special comments to the proof. Your answer should be a valid Coq proof. It should start with 'Proof.' and end with 'Qed.'." }, - "newMessageMaxTokens": { + "maxTokensToGenerate": { "type": "number", - "description": "How many tokens is allowed to be generated by the model.", - "default": 2000 + "description": "Number of tokens that the model is allowed to generate as a response message (i.e. message with proof).", + "default": 1024 }, "tokensLimit": { "type": "number", - "description": "The total length of input tokens and generated tokens. Is determined by the model. For open-ai models could be found [here](https://platform.openai.com/docs/models/).", + "description": "Total length of input and generated tokens, usually it is determined by the model.", "default": 2048 }, "multiroundProfile": { @@ -292,30 +303,30 @@ "properties": { "maxRoundsNumber": { "type": "number", - "description": "The maximum number of rounds for the multiround completion.", + "description": "Maximum number of rounds to generate and further fix the proof. Default value is 1, which means each proof will be only generated, but not fixed.", "default": 1 }, "proofFixChoices": { "type": "number", - "description": "How many proof fixes should be generated for one proof.", + "description": "Number of attempts to generate a proof fix for each proof in one round. Warning: increasing `proofFixChoices` can lead to exponential growth in generation requests if `maxRoundsNumber` is relatively large.", "default": 1 }, "proofFixPrompt": { "type": "string", - "description": "A prompt for the fix request message.", - "default": "Unfortunately, the last proof is not correct. Here is the compiler's feedback: '${diagnostic}'. Please, fix the proof." + "description": "Prompt for the proof-fix request that will be sent as a user chat message in response to an incorrect proof. It may include the `${diagnostic}` substring, which will be replaced by the actual compiler diagnostic.", + "default": "Unfortunately, the last proof is not correct. Here is the compiler's feedback: `${diagnostic}`. Please, fix the proof." 
} }, "default": { "maxRoundsNumber": 1, "proofFixChoices": 1, - "proofFixPrompt": "Unfortunately, the last proof is not correct. Here is the compiler's feedback: '${diagnostic}'. Please, fix the proof." + "proofFixPrompt": "Unfortunately, the last proof is not correct. Here is the compiler's feedback: `${diagnostic}`. Please, fix the proof." } } } }, "default": [], - "markdownDescription": "Configuration of models which fetch completions from locally running LLM inside the [LM studio](https://lmstudio.ai).", + "markdownDescription": "List of configurations that fetch completions from a locally running LLM inside [LM Studio](https://lmstudio.ai).", "order": 3 }, "coqpilot.contextTheoremsRankerType": { @@ -360,10 +371,11 @@ "format": "prettier --write \"src/**/*.{ts,js}\" && eslint \"src/**/*.{ts,js}\" --ext .ts --fix", "pretest": "npm run compile && npm run lint", "test": "node ./out/test/runTest.js", - "test-ci": "npm test -- -g=\"--non-ci\" -i=true", "clean": "rm -rf out", - "prebenchmark": "npm run clean", - "benchmark": "npm run test -- -g='Benchmark'" + "rebuild-test-resources": "cd ./src/test/resources/coqProj && make clean && make", + "preclean-test": "npm run clean && npm run rebuild-test-resources && npm run compile && npm run lint", + "clean-test": "node ./out/test/runTest.js", + "benchmark": "npm run clean-test -- -g='Benchmark'" }, "devDependencies": { "@trivago/prettier-plugin-sort-imports": "^4.3.0", @@ -373,6 +385,7 @@ "@types/glob": "^8.1.0", "@types/mocha": "^10.0.1", "@types/node": "20.2.5", + "@types/tmp": "^0.2.6", "@types/vscode": "^1.82.0", "@types/yargs": "^17.0.24", "@typescript-eslint/eslint-plugin": "^5.62.0", @@ -385,6 +398,7 @@ "glob": "^8.1.0", "mocha": "^10.2.0", "prettier": "^3.2.5", + "tmp": "^0.2.3", "typescript": "^5.3.3" }, "dependencies": { diff --git a/src/coqLsp/coqLspClient.ts b/src/coqLsp/coqLspClient.ts index 04be3830..f2ccfea0 100644 --- a/src/coqLsp/coqLspClient.ts +++ b/src/coqLsp/coqLspClient.ts @@ -175,7 +175,7 @@ export class CoqLspClient implements CoqLspClientInterface { ); const goal = goals?.goals?.goals?.shift() ?? 
undefined; if (!goal) { - return new CoqLspError("No goals at point."); + return new CoqLspError("no goals at point"); } return goal; @@ -239,7 +239,7 @@ export class CoqLspClient implements CoqLspClientInterface { pendingDiagnostic || awaitedDiagnostics === undefined ) { - throw new Error("Coq-lsp did not respond in time"); + throw new CoqLspError("coq-lsp did not respond in time"); } return this.filterDiagnostics( diff --git a/src/coqParser/parseCoqFile.ts b/src/coqParser/parseCoqFile.ts index bdce1a8c..205446a2 100644 --- a/src/coqParser/parseCoqFile.ts +++ b/src/coqParser/parseCoqFile.ts @@ -22,7 +22,7 @@ export async function parseCoqFile( }) .catch((error) => { throw new CoqParsingError( - `Failed to parse file with Error: ${error.message}` + `failed to parse file with Error: ${error.message}` ); }); } @@ -30,7 +30,7 @@ export async function parseCoqFile( export class CoqParsingError extends Error { constructor( public message: string, - public data?: any // eslint-disable-line @typescript-eslint/no-explicit-any + public data?: any ) { super(message); } @@ -41,7 +41,7 @@ function parseFlecheDocument( textLines: string[] ): Theorem[] { if (doc === null) { - throw new Error("Could not parse file"); + throw Error("could not parse file"); } const theorems: Theorem[] = []; @@ -75,7 +75,7 @@ function parseFlecheDocument( ) ); } else if (!nextExprVernac) { - throw new CoqParsingError("Unable to parse proof."); + throw new CoqParsingError("unable to parse proof"); } else if ( ![ Vernacexpr.VernacProof, @@ -123,7 +123,7 @@ function getTheoremName(expr: any): string { try { return expr[2][0][0][0]["v"][1]; } catch (error) { - throw new CoqParsingError("Invalid theorem name"); + throw new CoqParsingError("invalid theorem name"); } } @@ -131,7 +131,7 @@ function getDefinitionName(expr: any): string { try { return expr[2][0]["v"][1][1]; } catch (error) { - throw new CoqParsingError("Invalid definition name"); + throw new CoqParsingError("invalid definition name"); } } @@ -142,7 +142,7 @@ function getName(expr: any): string { case Vernacexpr.VernacStartTheoremProof: return getTheoremName(expr); default: - throw new CoqParsingError("Invalid name"); + throw new CoqParsingError(`invalid name for expression: "${expr}"`); } } @@ -213,7 +213,7 @@ function parseProof( const vernacType = getVernacexpr(getExpr(span)); if (!vernacType) { throw new CoqParsingError( - "Unable to derive the vernac type of the sentance" + "unable to derive the vernac type of the sentence" ); } @@ -254,7 +254,7 @@ function parseProof( } if (!proven || endPos === null) { - throw new CoqParsingError("Invalid or incomplete proof."); + throw new CoqParsingError("invalid or incomplete proof"); } const proofObj = new TheoremProof( diff --git a/src/core/completionGenerator.ts b/src/core/completionGenerator.ts index a321f02c..613b57d7 100644 --- a/src/core/completionGenerator.ts +++ b/src/core/completionGenerator.ts @@ -3,13 +3,14 @@ import { Position } from "vscode-languageclient"; import { LLMSequentialIterator } from "../llm/llmIterator"; import { LLMServices } from "../llm/llmServices"; import { GeneratedProof } from "../llm/llmServices/llmService"; +import { ModelsParams } from "../llm/llmServices/modelParams"; import { ProofGenerationContext } from "../llm/proofGenerationContext"; -import { UserModelsParams } from "../llm/userModelParams"; import { Goal, Hyp, PpString } from "../coqLsp/coqLspTypes"; import { Theorem } from "../coqParser/parsedTypes"; import { EventLogger } from "../logging/eventLogger"; +import { stringifyAnyValue 
} from "../utils/printers"; import { ContextTheoremsRanker } from "./contextTheoremRanker/contextTheoremsRanker"; import { @@ -34,7 +35,7 @@ export interface SourceFileEnvironment { export interface ProcessEnvironment { coqProofChecker: CoqProofChecker; - modelsParams: UserModelsParams; + modelsParams: ModelsParams; services: LLMServices; /** * If `theoremRanker` is not provided, the default one will be used: @@ -46,7 +47,7 @@ export interface ProcessEnvironment { export interface GenerationResult {} export class SuccessGenerationResult implements GenerationResult { - constructor(public data: any) {} + constructor(public data: string) {} } export class FailureGenerationResult implements GenerationResult { @@ -57,9 +58,9 @@ export class FailureGenerationResult implements GenerationResult { } export enum FailureGenerationStatus { - excededTimeout, - exception, - searchFailed, + TIMEOUT_EXCEEDED, + ERROR_OCCURRED, + SEARCH_FAILED, } export async function generateCompletion( @@ -70,8 +71,8 @@ export async function generateCompletion( ): Promise { const context = buildProofGenerationContext( completionContext, - sourceFileEnvironment, - processEnvironment + sourceFileEnvironment.fileTheorems, + processEnvironment.theoremRanker ); eventLogger?.log( "proof-gen-context-create", @@ -136,18 +137,32 @@ export async function generateCompletion( } return new FailureGenerationResult( - FailureGenerationStatus.searchFailed, + FailureGenerationStatus.SEARCH_FAILED, "No valid completions found" ); } catch (e: any) { + const error = e as Error; + if (error === null) { + console.error( + `Object was thrown during completion generation: ${e}` + ); + return new FailureGenerationResult( + FailureGenerationStatus.ERROR_OCCURRED, + `please report this crash by opening an issue in the Coqpilot GitHub repository: object was thrown as error, ${stringifyAnyValue(e)}` + ); + } else { + console.error( + `Error occurred during completion generation:\n${error.stack ?? error}` + ); + } if (e instanceof CoqLspTimeoutError) { return new FailureGenerationResult( - FailureGenerationStatus.excededTimeout, + FailureGenerationStatus.TIMEOUT_EXCEEDED, e.message ); } else { return new FailureGenerationResult( - FailureGenerationStatus.exception, + FailureGenerationStatus.ERROR_OCCURRED, e.message ); } @@ -190,7 +205,7 @@ export async function checkAndFixProofs( "Proofs were fixed", fixedProofs.map( (proof) => - `New proof: ${proof.proof()} with version ${proof.versionNumber()}\n Previous version: ${JSON.stringify(proof.proofVersions.slice(-2))}` + `New proof: "${proof.proof()}" with version ${proof.versionNumber()}\n Previous version: ${stringifyAnyValue(proof.proofVersions.slice(-2))}` ) ); return fixedProofs; // prepare to a new iteration @@ -264,7 +279,7 @@ function getFirstValidProof( return undefined; } -function prepareProofToCheck(proof: string) { +export function prepareProofToCheck(proof: string) { // 1. 
Remove backtiks -- coq-lsp dies from backticks randomly let preparedProof = proof.replace(/`/g, ""); @@ -290,23 +305,21 @@ function goalToTargetLemma(proofGoal: Goal): string { return `Lemma helper_theorem ${theoremIndeces} :\n ${auxTheoremConcl}.`; } -function buildProofGenerationContext( +export function buildProofGenerationContext( completionContext: CompletionContext, - sourceFileEnvironment: SourceFileEnvironment, - processEnvironment: ProcessEnvironment + fileTheorems: Theorem[], + theoremRanker?: ContextTheoremsRanker ): ProofGenerationContext { const rankedTheorems = - processEnvironment.theoremRanker?.rankContextTheorems( - sourceFileEnvironment.fileTheorems, - completionContext - ) ?? sourceFileEnvironment.fileTheorems; + theoremRanker?.rankContextTheorems(fileTheorems, completionContext) ?? + fileTheorems; return { contextTheorems: rankedTheorems, completionTarget: goalToTargetLemma(completionContext.proofGoal), }; } -function getTextBeforePosition( +export function getTextBeforePosition( textLines: string[], position: Position ): string[] { diff --git a/src/core/contextTheoremRanker/JaccardIndexContextTheoremsRanker.ts b/src/core/contextTheoremRanker/JaccardIndexContextTheoremsRanker.ts index a8126881..ab5e7a22 100644 --- a/src/core/contextTheoremRanker/JaccardIndexContextTheoremsRanker.ts +++ b/src/core/contextTheoremRanker/JaccardIndexContextTheoremsRanker.ts @@ -10,7 +10,7 @@ import { ContextTheoremsRanker } from "./contextTheoremsRanker"; * the current goal context. Metric is calculated on the * concatenated hypothesis and conclusion. * - * J(A, B) = |A ∩ B| / |A ∪ B| + * ```J(A, B) = |A ∩ B| / |A ∪ B|``` */ export class JaccardIndexContextTheoremsRanker implements ContextTheoremsRanker diff --git a/src/core/coqProofChecker.ts b/src/core/coqProofChecker.ts index e4647932..91792b70 100644 --- a/src/core/coqProofChecker.ts +++ b/src/core/coqProofChecker.ts @@ -68,7 +68,7 @@ export class CoqProofChecker implements CoqProofCheckerInterface { }); } - private makeAuxFileName( + private buildAuxFileUri( sourceDirPath: string, holePosition: Position, unique: boolean = true @@ -98,7 +98,7 @@ export class CoqProofChecker implements CoqProofCheckerInterface { proofs: Proof[] ): Promise { // 1. 
Write the text to the aux file - const auxFileUri = this.makeAuxFileName( + const auxFileUri = this.buildAuxFileUri( sourceDirPath, prefixEndPosition ); diff --git a/src/core/inspectSourceFile.ts b/src/core/inspectSourceFile.ts index ad202483..d47d7243 100644 --- a/src/core/inspectSourceFile.ts +++ b/src/core/inspectSourceFile.ts @@ -83,7 +83,9 @@ export async function createSourceFileEnvironment( const fileText = readFileSync(fileUri.fsPath); const dirPath = getSourceFolderPath(fileUri); if (!dirPath) { - throw new Error("Unable to get source folder path"); + throw Error( + `unable to get source folder path from \`fileUri\`: ${fileUri}` + ); } return { diff --git a/src/extension/configReaders.ts b/src/extension/configReaders.ts new file mode 100644 index 00000000..be417eef --- /dev/null +++ b/src/extension/configReaders.ts @@ -0,0 +1,300 @@ +import Ajv, { DefinedError, JSONSchemaType } from "ajv"; +import { WorkspaceConfiguration, workspace } from "vscode"; + +import { LLMServices } from "../llm/llmServices"; +import { LLMService } from "../llm/llmServices/llmService"; +import { ModelParams, ModelsParams } from "../llm/llmServices/modelParams"; +import { SingleParamResolutionResult } from "../llm/llmServices/utils/paramsResolvers/abstractResolvers"; +import { + GrazieUserModelParams, + LMStudioUserModelParams, + OpenAiUserModelParams, + PredefinedProofsUserModelParams, + UserModelParams, + grazieUserModelParamsSchema, + lmStudioUserModelParamsSchema, + openAiUserModelParamsSchema, + predefinedProofsUserModelParamsSchema, +} from "../llm/userModelParams"; + +import { JaccardIndexContextTheoremsRanker } from "../core/contextTheoremRanker/JaccardIndexContextTheoremsRanker"; +import { ContextTheoremsRanker } from "../core/contextTheoremRanker/contextTheoremsRanker"; +import { DistanceContextTheoremsRanker } from "../core/contextTheoremRanker/distanceContextTheoremsRanker"; +import { RandomContextTheoremsRanker } from "../core/contextTheoremRanker/randomContextTheoremsRanker"; + +import { AjvMode, buildAjv } from "../utils/ajvErrorsHandling"; +import { stringifyAnyValue, stringifyDefinedValue } from "../utils/printers"; + +import { pluginId } from "./coqPilot"; +import { EditorMessages } from "./editorMessages"; +import { + SettingsValidationError, + showMessageToUserWithSettingsHint, + toSettingName, +} from "./settingsValidationError"; + +export function buildTheoremsRankerFromConfig(): ContextTheoremsRanker { + const workspaceConfig = workspace.getConfiguration(pluginId); + const rankerType = workspaceConfig.contextTheoremsRankerType; + switch (rankerType) { + case "distance": + return new DistanceContextTheoremsRanker(); + case "random": + return new RandomContextTheoremsRanker(); + case "jaccardIndex": + return new JaccardIndexContextTheoremsRanker(); + default: + throw new SettingsValidationError( + `unknown context theorems ranker type: ${rankerType}`, + EditorMessages.unknownContextTheoremsRanker, + "contextTheoremsRankerType" + ); + } +} + +export function readAndValidateUserModelsParams( + config: WorkspaceConfiguration, + llmServices: LLMServices +): ModelsParams { + /* + * Although the messages might become too verbose because of reporting all errors at once + * (unfortuantely, vscode notifications do not currently support formatting); + * we want the user to fix type-validation issues as soon as possible + * to move on to clearer messages and generating completions faster. 
+ */ + const jsonSchemaValidator = buildAjv(AjvMode.COLLECT_ALL_ERRORS); + + const predefinedProofsUserParams: PredefinedProofsUserModelParams[] = + config.predefinedProofsModelsParameters.map((params: any) => + validateAndParseJson( + params, + predefinedProofsUserModelParamsSchema, + jsonSchemaValidator + ) + ); + const openAiUserParams: OpenAiUserModelParams[] = + config.openAiModelsParameters.map((params: any) => + validateAndParseJson( + params, + openAiUserModelParamsSchema, + jsonSchemaValidator + ) + ); + const grazieUserParams: GrazieUserModelParams[] = + config.grazieModelsParameters.map((params: any) => + validateAndParseJson( + params, + grazieUserModelParamsSchema, + jsonSchemaValidator + ) + ); + const lmStudioUserParams: LMStudioUserModelParams[] = + config.lmStudioModelsParameters.map((params: any) => + validateAndParseJson( + params, + lmStudioUserModelParamsSchema, + jsonSchemaValidator + ) + ); + + validateIdsAreUnique([ + ...predefinedProofsUserParams, + ...openAiUserParams, + ...grazieUserParams, + ...lmStudioUserParams, + ]); + validateApiKeysAreProvided(openAiUserParams, grazieUserParams); + + const modelsParams: ModelsParams = { + predefinedProofsModelParams: resolveParamsAndShowResolutionLogs( + llmServices.predefinedProofsService, + predefinedProofsUserParams + ), + openAiParams: resolveParamsAndShowResolutionLogs( + llmServices.openAiService, + openAiUserParams + ), + grazieParams: resolveParamsAndShowResolutionLogs( + llmServices.grazieService, + grazieUserParams + ), + lmStudioParams: resolveParamsAndShowResolutionLogs( + llmServices.lmStudioService, + lmStudioUserParams + ), + }; + + validateModelsArePresent([ + ...modelsParams.predefinedProofsModelParams, + ...modelsParams.openAiParams, + ...modelsParams.grazieParams, + ...modelsParams.lmStudioParams, + ]); + + return modelsParams; +} + +function validateAndParseJson( + json: any, + targetClassSchema: JSONSchemaType, + jsonSchemaValidator: Ajv +): T { + const instance: T = json as T; + const validate = jsonSchemaValidator.compile(targetClassSchema); + if (!validate(instance)) { + const settingsName = targetClassSchema.title; + if (settingsName === undefined) { + throw Error( + `specified \`targetClassSchema\` does not have \`title\`; while resolving json: ${stringifyAnyValue(json)}` + ); + } + const ajvErrors = validate.errors as DefinedError[]; + if (ajvErrors === null || ajvErrors === undefined) { + throw Error( + `validation with Ajv failed, but \`validate.errors\` are not defined; while resolving json: ${stringifyAnyValue(json)}` + ); + } + throw new SettingsValidationError( + `unable to validate json ${stringifyAnyValue(json)}: ${stringifyDefinedValue(validate.errors)}`, + EditorMessages.unableToValidateUserSettings( + settingsName, + ajvErrors, + ["oneOf"] // ignore additional boilerplate "oneOf" error, which appears if something is wrong with nested `multiroundProfile` + ), + settingsName + ); + } + return instance; +} + +function validateIdsAreUnique(allModels: UserModelParams[]) { + const modelIds = allModels.map((params) => params.modelId); + const uniqueModelIds = new Set(); + for (const modelId of modelIds) { + if (uniqueModelIds.has(modelId)) { + throw new SettingsValidationError( + `models' identifiers are not unique: several models have \`modelId: "${modelId}"\``, + EditorMessages.modelsIdsAreNotUnique(modelId) + ); + } else { + uniqueModelIds.add(modelId); + } + } +} + +function validateApiKeysAreProvided( + openAiUserParams: OpenAiUserModelParams[], + grazieUserParams: GrazieUserModelParams[] 
+) { + const buildApiKeyError = ( + serviceName: string, + serviceSettingsName: string + ) => { + return new SettingsValidationError( + `at least one of the ${serviceName} models has \`apiKey: "None"\``, + EditorMessages.apiKeyIsNotSet(serviceName), + `${pluginId}.${serviceSettingsName}ModelsParameters`, + "info" + ); + }; + + if (openAiUserParams.some((params) => params.apiKey === "None")) { + throw buildApiKeyError("Open Ai", "openAi"); + } + if (grazieUserParams.some((params) => params.apiKey === "None")) { + throw buildApiKeyError("Grazie", "grazie"); + } +} + +function validateModelsArePresent(allModels: T[]) { + if (allModels.length === 0) { + throw new SettingsValidationError( + "no models specified for proof generation", + EditorMessages.noValidModelsAreChosen, + pluginId, + "warning" + ); + } +} + +function resolveParamsAndShowResolutionLogs< + InputModelParams extends UserModelParams, + ResolvedModelParams extends ModelParams, +>( + llmService: LLMService, + inputParamsList: InputModelParams[] +): ResolvedModelParams[] { + const settingName = toSettingName(llmService); + const resolvedParamsList: ResolvedModelParams[] = []; + + for (const inputParams of inputParamsList) { + const resolutionResult = llmService.resolveParameters(inputParams); + + // notify user about errors (with full logs for failed parameters) and overrides + for (const paramLog of resolutionResult.resolutionLogs) { + if (paramLog.resultValue === undefined) { + // failed to resolve parameter + const resolutionHistory = buildResolutionHistory(paramLog); + showMessageToUserWithSettingsHint( + EditorMessages.modelConfiguredIncorrectly( + inputParams.modelId, + `${paramLog.isInvalidCause}${resolutionHistory}` + ), + "error", + settingName + ); + } else if ( + paramLog.overriden.wasPerformed && + paramLog.inputReadCorrectly.wasPerformed + ) { + // resolved parameter, but the user value was overriden + showMessageToUserWithSettingsHint( + EditorMessages.userValueWasOverriden( + inputParams.modelId, + paramLog.inputParamName ?? "", + paramLog.overriden.withValue, + paramLog.overriden.message + ), + "info", + settingName + ); + } + } + + if (resolutionResult.resolved !== undefined) { + resolvedParamsList.push(resolutionResult.resolved); + } + } + return resolvedParamsList; +} + +function buildResolutionHistory( + paramLog: SingleParamResolutionResult +): string { + const inputReadPerformed = paramLog.inputReadCorrectly.wasPerformed; + const overridePerformed = paramLog.overriden.wasPerformed; + const withDefaultPerformed = paramLog.resolvedWithDefault.wasPerformed; + + const onlySuccessfulRead = + inputReadPerformed && !overridePerformed && !withDefaultPerformed; + if (onlySuccessfulRead) { + return ""; + } + const inputRead = paramLog.inputReadCorrectly.wasPerformed + ? `read ${stringifyAnyValue(paramLog.inputReadCorrectly.withValue)}` + : "no input value read"; + const withOverride = paramLog.overriden.wasPerformed + ? `, overriden with ${stringifyAnyValue(paramLog.overriden.withValue)}` + : ""; + const withDefault = paramLog.resolvedWithDefault.wasPerformed + ? `, resolved with default ${stringifyAnyValue(paramLog.resolvedWithDefault.withValue)}` + : ""; + + const onlyFailedRead = + !inputReadPerformed && !overridePerformed && !withDefaultPerformed; + const anyResolutionActionPerformed = + overridePerformed || withDefaultPerformed; + return onlyFailedRead || anyResolutionActionPerformed + ? 
`; value's resolution: ${inputRead}${withOverride}${withDefault}` + : ""; +} diff --git a/src/extension/coqPilot.ts b/src/extension/coqPilot.ts index 6d425291..1909b8f9 100644 --- a/src/extension/coqPilot.ts +++ b/src/extension/coqPilot.ts @@ -1,31 +1,12 @@ -import Ajv, { JSONSchemaType } from "ajv"; import { ExtensionContext, ProgressLocation, TextEditor, - WorkspaceConfiguration, commands, window, workspace, } from "vscode"; -import { LLMServices } from "../llm/llmServices"; -import { GrazieService } from "../llm/llmServices/grazie/grazieService"; -import { LMStudioService } from "../llm/llmServices/lmStudio/lmStudioService"; -import { OpenAiService } from "../llm/llmServices/openai/openAiService"; -import { PredefinedProofsService } from "../llm/llmServices/predefinedProofs/predefinedProofsService"; -import { - GrazieUserModelParams, - LMStudioUserModelParams, - OpenAiUserModelParams, - PredefinedProofsUserModelParams, - UserModelsParams, - grazieUserModelParamsSchema, - lmStudioUserModelParamsSchema, - openAiUserModelParamsSchema, - predefinedProofsUserModelParamsSchema, -} from "../llm/userModelParams"; - import { CoqLspClient } from "../coqLsp/coqLspClient"; import { CoqLspConfig } from "../coqLsp/coqLspConfig"; @@ -40,80 +21,39 @@ import { ProcessEnvironment, SourceFileEnvironment, } from "../core/completionGenerator"; -import { JaccardIndexContextTheoremsRanker } from "../core/contextTheoremRanker/JaccardIndexContextTheoremsRanker"; -import { ContextTheoremsRanker } from "../core/contextTheoremRanker/contextTheoremsRanker"; -import { DistanceContextTheoremsRanker } from "../core/contextTheoremRanker/distanceContextTheoremsRanker"; -import { RandomContextTheoremsRanker } from "../core/contextTheoremRanker/randomContextTheoremsRanker"; import { CoqProofChecker } from "../core/coqProofChecker"; import { inspectSourceFile } from "../core/inspectSourceFile"; import { ProofStep } from "../coqParser/parsedTypes"; -import { EventLogger, Severity } from "../logging/eventLogger"; import { Uri } from "../utils/uri"; +import { + buildTheoremsRankerFromConfig, + readAndValidateUserModelsParams, +} from "./configReaders"; import { deleteTextFromRange, highlightTextInEditor, insertCompletion, } from "./documentEditor"; -import { - EditorMessages, - showApiKeyNotProvidedMessage, - showMessageToUser, - suggestAddingAuxFilesToGitignore, -} from "./editorMessages"; +import { suggestAddingAuxFilesToGitignore } from "./editGitignoreCommand"; +import { EditorMessages, showMessageToUser } from "./editorMessages"; +import { GlobalExtensionState } from "./globalExtensionState"; +import { subscribeToHandleLLMServicesEvents } from "./llmServicesEventsHandler"; import { positionInRange, toVSCodePosition, toVSCodeRange, } from "./positionRangeUtils"; +import { SettingsValidationError } from "./settingsValidationError"; import { cleanAuxFiles, hideAuxFiles } from "./tmpFilesCleanup"; -import VSCodeLogWriter from "./vscodeLogWriter"; export const pluginId = "coqpilot"; -export class GlobalExtensionState { - public readonly eventLogger: EventLogger = new EventLogger(); - public readonly logWriter: VSCodeLogWriter = new VSCodeLogWriter( - this.eventLogger, - this.parseLoggingVerbosity(workspace.getConfiguration(pluginId)) - ); - public readonly llmServices: LLMServices = { - openAiService: new OpenAiService(this.eventLogger), - grazieService: new GrazieService(this.eventLogger), - predefinedProofsService: new PredefinedProofsService(), - lmStudioService: new LMStudioService(this.eventLogger), - }; - - 
constructor() {} - - private parseLoggingVerbosity(config: WorkspaceConfiguration): Severity { - const verbosity = config.get("loggingVerbosity"); - switch (verbosity) { - case "info": - return Severity.INFO; - case "debug": - return Severity.DEBUG; - default: - throw new Error(`Unknown logging verbosity: ${verbosity}`); - } - } - - dispose(): void { - this.llmServices.openAiService.dispose(); - this.llmServices.grazieService.dispose(); - this.llmServices.predefinedProofsService.dispose(); - this.llmServices.lmStudioService.dispose(); - this.logWriter.dispose(); - } -} - export class CoqPilot { private readonly globalExtensionState: GlobalExtensionState; private readonly vscodeExtensionContext: ExtensionContext; - private readonly jsonSchemaValidator: Ajv; - constructor(vscodeExtensionContext: ExtensionContext) { hideAuxFiles(); suggestAddingAuxFilesToGitignore(); @@ -134,8 +74,6 @@ export class CoqPilot { this.performCompletionForAllAdmits.bind(this) ); - this.jsonSchemaValidator = new Ajv(); - this.vscodeExtensionContext.subscriptions.push(this); } @@ -159,28 +97,6 @@ export class CoqPilot { this.performSpecificCompletionsWithProgress((_hole) => true, editor); } - private checkUserProvidedApiKeys( - processEnvironment: ProcessEnvironment - ): boolean { - if ( - processEnvironment.modelsParams.openAiParams.some( - (params) => params.apiKey === "None" - ) - ) { - showApiKeyNotProvidedMessage("openai", pluginId); - return false; - } else if ( - processEnvironment.modelsParams.grazieParams.some( - (params) => params.apiKey === "None" - ) - ) { - showApiKeyNotProvidedMessage("grazie", pluginId); - return false; - } - - return true; - } - private async performSpecificCompletionsWithProgress( shouldCompleteHole: (hole: ProofStep) => boolean, editor: TextEditor @@ -197,11 +113,14 @@ export class CoqPilot { editor ); } catch (error) { - if (error instanceof UserSettingsValidationError) { - showMessageToUser(error.toString(), "error"); + if (error instanceof SettingsValidationError) { + error.showAsMessageToUser(); } else if (error instanceof Error) { - showMessageToUser(error.message, "error"); - console.error(error); + showMessageToUser( + EditorMessages.errorOccurred(error.message), + "error" + ); + console.error(`${error.stack ?? 
error}`); } } } @@ -219,20 +138,28 @@ export class CoqPilot { editor.document.uri.fsPath ); - if (!this.checkUserProvidedApiKeys(processEnvironment)) { - return; - } + const unsubscribeFromLLMServicesEventsCallback = + subscribeToHandleLLMServicesEvents( + this.globalExtensionState.llmServices, + this.globalExtensionState.eventLogger + ); - let completionPromises = completionContexts.map((completionContext) => { - return this.performSingleCompletion( - completionContext, - sourceFileEnvironment, - processEnvironment, - editor + try { + let completionPromises = completionContexts.map( + (completionContext) => { + return this.performSingleCompletion( + completionContext, + sourceFileEnvironment, + processEnvironment, + editor + ); + } ); - }); - await Promise.all(completionPromises); + await Promise.all(completionPromises); + } finally { + unsubscribeFromLLMServicesEventsCallback(); + } } private async performSingleCompletion( @@ -273,16 +200,16 @@ export class CoqPilot { highlightTextInEditor(completionRange); } else if (result instanceof FailureGenerationResult) { switch (result.status) { - case FailureGenerationStatus.excededTimeout: - showMessageToUser(EditorMessages.timeoutError, "info"); + case FailureGenerationStatus.TIMEOUT_EXCEEDED: + showMessageToUser(EditorMessages.timeoutExceeded, "info"); break; - case FailureGenerationStatus.exception: + case FailureGenerationStatus.ERROR_OCCURRED: showMessageToUser( - EditorMessages.exceptionError(result.message), + EditorMessages.errorOccurred(result.message), "error" ); break; - case FailureGenerationStatus.searchFailed: + case FailureGenerationStatus.SEARCH_FAILED: const completionLine = completionContext.prefixEndPosition.line + 1; showMessageToUser( @@ -313,7 +240,7 @@ export class CoqPilot { const coqLspServerConfig = CoqLspConfig.createServerConfig(); const coqLspClientConfig = CoqLspConfig.createClientConfig(); const client = new CoqLspClient(coqLspServerConfig, coqLspClientConfig); - const contextTheoremsRanker = this.buildTheoremsRankerFromConfig(); + const contextTheoremsRanker = buildTheoremsRankerFromConfig(); const coqProofChecker = new CoqProofChecker(client); const [completionContexts, sourceFileEnvironment] = @@ -325,8 +252,9 @@ export class CoqPilot { ); const processEnvironment: ProcessEnvironment = { coqProofChecker: coqProofChecker, - modelsParams: this.parseUserModelsParams( - workspace.getConfiguration(pluginId) + modelsParams: readAndValidateUserModelsParams( + workspace.getConfiguration(pluginId), + this.globalExtensionState.llmServices ), services: this.globalExtensionState.llmServices, theoremRanker: contextTheoremsRanker, @@ -335,71 +263,6 @@ export class CoqPilot { return [completionContexts, sourceFileEnvironment, processEnvironment]; } - private buildTheoremsRankerFromConfig(): ContextTheoremsRanker { - const workspaceConfig = workspace.getConfiguration(pluginId); - const rankerType = workspaceConfig.contextTheoremsRankerType; - - switch (rankerType) { - case "distance": - return new DistanceContextTheoremsRanker(); - case "random": - return new RandomContextTheoremsRanker(); - case "jaccardIndex": - return new JaccardIndexContextTheoremsRanker(); - default: - throw new Error( - `Unknown context theorems ranker type: ${rankerType}` - ); - } - } - - private parseUserModelsParams( - config: WorkspaceConfiguration - ): UserModelsParams { - const openAiParams: OpenAiUserModelParams[] = - config.openAiModelsParameters.map((params: any) => - this.validateAndParseJson(params, openAiUserModelParamsSchema) - ); - const 
grazieParams: GrazieUserModelParams[] = - config.grazieModelsParameters.map((params: any) => - this.validateAndParseJson(params, grazieUserModelParamsSchema) - ); - const predefinedProofsParams: PredefinedProofsUserModelParams[] = - config.predefinedProofsModelsParameters.map((params: any) => - this.validateAndParseJson( - params, - predefinedProofsUserModelParamsSchema - ) - ); - const lmStudioParams: LMStudioUserModelParams[] = - config.lmStudioModelsParameters.map((params: any) => - this.validateAndParseJson(params, lmStudioUserModelParamsSchema) - ); - - return { - openAiParams: openAiParams, - grazieParams: grazieParams, - predefinedProofsModelParams: predefinedProofsParams, - lmStudioParams: lmStudioParams, - }; - } - - private validateAndParseJson( - json: any, - targetClassSchema: JSONSchemaType - ): T { - const instance: T = json as T; - const validate = this.jsonSchemaValidator.compile(targetClassSchema); - if (!validate(instance)) { - throw new UserSettingsValidationError( - `Unable to validate json against the class: ${JSON.stringify(validate.errors)}`, - targetClassSchema.title ?? "Unknown" - ); - } - - return instance; - } - private registerEditorCommand( command: string, fn: (editor: TextEditor) => void @@ -416,16 +279,3 @@ export class CoqPilot { this.globalExtensionState.dispose(); } } - -class UserSettingsValidationError extends Error { - constructor( - message: string, - public readonly settingsName: string - ) { - super(message); - } - - toString(): string { - return `Unable to validate user settings for ${this.settingsName}. Please refer to the README for the correct settings format: https://github.com/JetBrains-Research/coqpilot/blob/main/README.md#guide-to-model-configuration.`; - } -} diff --git a/src/extension/editGitignoreCommand.ts b/src/extension/editGitignoreCommand.ts new file mode 100644 index 00000000..8df1b522 --- /dev/null +++ b/src/extension/editGitignoreCommand.ts @@ -0,0 +1,52 @@ +import { appendFile, existsSync, readFileSync } from "fs"; +import * as path from "path"; +import { window, workspace } from "vscode"; + +import { showMessageToUser } from "./editorMessages"; + +export async function suggestAddingAuxFilesToGitignore() { + const workspaceFolders = workspace.workspaceFolders; + if (!workspaceFolders) { + return; + } + + for (const folder of workspaceFolders) { + const gitIgnorePath = path.join(folder.uri.fsPath, ".gitignore"); + if (!existsSync(gitIgnorePath)) { + // .gitignore not found. Exit. + return; + } + + const data = readFileSync(gitIgnorePath, "utf8"); + const auxExt = "*_cp_aux.v"; + if (data.indexOf(auxExt) === -1) { + // Not found. Ask user if we should add it. 
+ await window + .showInformationMessage( + 'Do you want to add "*_cp_aux.v" to .gitignore?', + "Yes", + "No" + ) + .then((choice) => { + if (choice === "Yes") { + const rule = `\n# Coqpilot auxiliary files\n${auxExt}`; + appendFile(gitIgnorePath, rule, (err) => { + if (err) { + showMessageToUser( + `Unexpected error writing to .gitignore: ${err.message}`, + "error" + ); + } else { + showMessageToUser( + 'Successfully added "*_cp_aux.v" to .gitignore', + "info" + ); + } + }); + } + }); + } else { + return; + } + } +} diff --git a/src/extension/editorMessages.ts b/src/extension/editorMessages.ts index 59c8dfe6..d9f6f921 100644 --- a/src/extension/editorMessages.ts +++ b/src/extension/editorMessages.ts @@ -1,100 +1,124 @@ -import { appendFile, existsSync, readFileSync } from "fs"; -import * as path from "path"; -import { commands, window, workspace } from "vscode"; +import { DefinedError } from "ajv"; +import { window } from "vscode"; + +import { Time } from "../llm/llmServices/utils/time"; + +import { ajvErrorsAsString } from "../utils/ajvErrorsHandling"; +import { stringifyAnyValue } from "../utils/printers"; export namespace EditorMessages { - export const timeoutError = - "Coqpilot: The proof checking process timed out. Please try again."; - export const noProofsForAdmit = (admitIdentifier: any) => - `Coqpilot failed to find a proof for the admit at line ${admitIdentifier}.`; - export const exceptionError = (errorMsg: string) => - "Coqpilot: An exception occured: " + errorMsg; + export const timeoutExceeded = + "The proof checking process timed out. Please try again."; + + export const noProofsForAdmit = (lineWithAdmitNumber: number) => + `Coqpilot failed to find a proof for the admit at line ${lineWithAdmitNumber}.`; + + export const errorOccurred = (errorMessage: string) => + `Coqpilot got an error: ${errorMessage}. Please make sure the environment is properly set and the plugin is configured correctly. For more information, see the README: https://github.com/JetBrains-Research/coqpilot/blob/main/README.md. If the error appears to be a bug, please report it by opening an issue in the Coqpilot GitHub repository.`; + + export const serviceBecameUnavailable = ( + serviceName: string, + errorMessage: string, + expectedTimeToBecomeAvailable: Time + ) => { + const formattedExpectedTime = formatTimeToUIString( + expectedTimeToBecomeAvailable + ); + const becameUnavailableMessage = `\`${serviceName}\` became unavailable for this generation.`; + const tryAgainMessage = `If you want to use it, try again in ~ ${formattedExpectedTime}. Caused by error: "${errorMessage}".`; + return `${becameUnavailableMessage} ${tryAgainMessage}`; + }; + + export const failedToReachRemoteService = ( + serviceName: string, + message: string + ) => { + const serviceFailureMessage = `\`${serviceName}\` became unavailable for this generation: ${message}.`; + const tryAgainMessage = `Check your internet connection and try again.`; + return `${serviceFailureMessage} ${tryAgainMessage}`; + }; + + export const serviceIsAvailableAgain = (serviceName: string) => + `\`${serviceName}\` is available again!`; + + export const modelConfiguredIncorrectly = ( + modelId: string, + errorMessage: string + ) => + `Model "${modelId}" is configured incorrectly: ${errorMessage}. Thus, "${modelId}" will be skipped for this run. 
Please fix the model's configuration in the settings.`; + + export const unknownContextTheoremsRanker = `Please select one of the existing theorems-ranker types: "distance" or "random".`; + + export const unableToValidateUserSettings = ( + settingsName: string, + validationErrors: DefinedError[], + ignoreErrorsWithKeywords: string[] + ) => + `Unable to validate settings for \`${settingsName}\`: ${ajvErrorsAsString(validationErrors, ignoreErrorsWithKeywords)}. Please fix the configuration in the settings.`; + + export const modelsIdsAreNotUnique = (modelId: string) => + `Please make identifiers of the models unique ("${modelId}" is not unique).`; + + export const apiKeyIsNotSet = (serviceName: string) => + `Please set your ${serviceName} API key in the settings.`; + + export const noValidModelsAreChosen = + "No valid models are chosen. Please specify at least one in the settings."; + + export const userValueWasOverriden = ( + modelId: string, + paramName: string, + withValue: any, + explanationMessage?: string + ) => { + const explanation = + explanationMessage === undefined ? "" : `: ${explanationMessage}`; + return `The \`${paramName}\` parameter of the "${modelId}" model was overriden with the value ${stringifyAnyValue(withValue)}${explanation}. Please configure it the same way in the settings.`; + }; } -export function showMessageToUser( +export type UIMessageSeverity = "error" | "info" | "warning"; + +export function showMessageToUser( message: string, - severity: "error" | "info" | "warning" = "info" -) { + severity: UIMessageSeverity = "info", + ...items: T[] +): Thenable { switch (severity) { case "error": - window.showErrorMessage(message); - break; + return window.showErrorMessage(message, ...items); case "info": - window.showInformationMessage(message); - break; + return window.showInformationMessage(message, ...items); case "warning": - window.showWarningMessage(message); - break; + return window.showWarningMessage(message, ...items); } } -export function showApiKeyNotProvidedMessage( - service: "openai" | "grazie", - pluginId: string -) { - const serviceParamSettingName = - service === "openai" - ? "openAiModelsParameters" - : "grazieModelsParameters"; - const serviceName = service === "openai" ? "Open Ai" : "Grazie"; - - window - .showInformationMessage( - `Please set your ${serviceName} API key in the settings.`, - "Open settings" - ) - .then((value) => { - if (value === "Open settings") { - commands.executeCommand( - "workbench.action.openSettings", - `${pluginId}.${serviceParamSettingName}` - ); - } - }); -} +function formatTimeToUIString(time: Time): string { + const orderedTimeItems: [number, string][] = [ + [time.days, "day"], + [time.hours, "hour"], + [time.minutes, "minute"], + [time.seconds, "second"], + ].map(([value, name]) => [ + value as number, + formatTimeItem(value as number, name as string), + ]); + const itemsN = orderedTimeItems.length; -export async function suggestAddingAuxFilesToGitignore() { - const workspaceFolders = workspace.workspaceFolders; - if (!workspaceFolders) { - return; - } - - for (const folder of workspaceFolders) { - const gitIgnorePath = path.join(folder.uri.fsPath, ".gitignore"); - if (!existsSync(gitIgnorePath)) { - // .gitignore not found. Exit. - return; - } - - const data = readFileSync(gitIgnorePath, "utf8"); - const auxExt = "*_cp_aux.v"; - if (data.indexOf(auxExt) === -1) { - // Not found. Ask user if we should add it. 
- await window - .showInformationMessage( - 'Do you want to add "*_cp_aux.v" to .gitignore?', - "Yes", - "No" - ) - .then((choice) => { - if (choice === "Yes") { - const rule = `\n# Coqpilot auxiliary files\n${auxExt}`; - appendFile(gitIgnorePath, rule, (err) => { - if (err) { - showMessageToUser( - `Unexpected error writing to .gitignore: ${err.message}`, - "error" - ); - } else { - showMessageToUser( - 'Successfully added "*_cp_aux.v" to .gitignore' - ); - } - }); - } - }); - } else { - return; + for (let i = 0; i < itemsN; i++) { + const [value, formattedItem] = orderedTimeItems[i]; + if (value !== 0) { + const nextFormattedItem = + i === itemsN - 1 ? "" : `, ${orderedTimeItems[i + 1][1]}`; + return `${formattedItem}${nextFormattedItem}`; } } + const zeroSeconds = orderedTimeItems[3][1]; + return `${zeroSeconds}`; +} + +function formatTimeItem(value: number, name: string): string { + const suffix = value === 1 ? "" : "s"; + return `${value} ${name}${suffix}`; } diff --git a/src/extension/globalExtensionState.ts b/src/extension/globalExtensionState.ts new file mode 100644 index 00000000..a41931fd --- /dev/null +++ b/src/extension/globalExtensionState.ts @@ -0,0 +1,69 @@ +import * as fs from "fs"; +import * as path from "path"; +import * as tmp from "tmp"; +import { WorkspaceConfiguration, workspace } from "vscode"; + +import { LLMServices, disposeServices } from "../llm/llmServices"; +import { GrazieService } from "../llm/llmServices/grazie/grazieService"; +import { LMStudioService } from "../llm/llmServices/lmStudio/lmStudioService"; +import { OpenAiService } from "../llm/llmServices/openai/openAiService"; +import { PredefinedProofsService } from "../llm/llmServices/predefinedProofs/predefinedProofsService"; + +import { EventLogger, Severity } from "../logging/eventLogger"; + +import { pluginId } from "./coqPilot"; +import VSCodeLogWriter from "./vscodeLogWriter"; + +export class GlobalExtensionState { + public readonly eventLogger: EventLogger = new EventLogger(); + public readonly logWriter: VSCodeLogWriter = new VSCodeLogWriter( + this.eventLogger, + this.parseLoggingVerbosity(workspace.getConfiguration(pluginId)) + ); + + public readonly llmServicesLogsDir = path.join( + tmp.dirSync().name, + "llm-services-logs" + ); + + public readonly llmServices: LLMServices = { + predefinedProofsService: new PredefinedProofsService( + this.eventLogger, + false, + path.join(this.llmServicesLogsDir, "predefined-proofs-logs.txt") + ), + openAiService: new OpenAiService( + this.eventLogger, + false, + path.join(this.llmServicesLogsDir, "openai-logs.txt") + ), + grazieService: new GrazieService( + this.eventLogger, + false, + path.join(this.llmServicesLogsDir, "grazie-logs.txt") + ), + lmStudioService: new LMStudioService( + this.eventLogger, + false, + path.join(this.llmServicesLogsDir, "lmstudio-logs.txt") + ), + }; + + private parseLoggingVerbosity(config: WorkspaceConfiguration): Severity { + const verbosity = config.get("loggingVerbosity"); + switch (verbosity) { + case "info": + return Severity.INFO; + case "debug": + return Severity.DEBUG; + default: + throw Error(`unknown logging verbosity: ${verbosity}`); + } + } + + dispose(): void { + disposeServices(this.llmServices); + this.logWriter.dispose(); + fs.rmSync(this.llmServicesLogsDir, { recursive: true, force: true }); + } +} diff --git a/src/extension/llmServicesEventsHandler.ts b/src/extension/llmServicesEventsHandler.ts new file mode 100644 index 00000000..04f7134a --- /dev/null +++ b/src/extension/llmServicesEventsHandler.ts @@ -0,0 
+1,216 @@ +import { + ConfigurationError, + GenerationFailedError, + RemoteConnectionError, +} from "../llm/llmServiceErrors"; +import { LLMServices, asLLMServices } from "../llm/llmServices"; +import { + LLMServiceImpl, + LLMServiceRequest, + LLMServiceRequestFailed, + LLMServiceRequestSucceeded, +} from "../llm/llmServices/llmService"; +import { ModelParams } from "../llm/llmServices/modelParams"; + +import { EventLogger } from "../logging/eventLogger"; +import { stringifyAnyValue } from "../utils/printers"; +import { SimpleSet } from "../utils/simpleSet"; + +import { EditorMessages, showMessageToUser } from "./editorMessages"; +import { + showMessageToUserWithSettingsHint, + toSettingName, +} from "./settingsValidationError"; + +enum LLMServiceAvailablityState { + AVAILABLE, + UNAVAILABLE, +} + +enum LLMServiceMessagesShownState { + NO_MESSAGES_SHOWN, + BECOME_UNAVAILABLE_MESSAGE_SHOWN, + AGAIN_AVAILABLE_MESSAGE_SHOWN, +} + +interface LLMServiceUIState { + availabilityState: LLMServiceAvailablityState; + messagesShownState: LLMServiceMessagesShownState; +} + +type LLMServiceToUIState = Map; +type ModelsSet = SimpleSet; + +export type UnsubscribeFromLLMServicesEventsCallback = () => void; + +export function subscribeToHandleLLMServicesEvents( + llmServices: LLMServices, + eventLogger: EventLogger +): UnsubscribeFromLLMServicesEventsCallback { + const llmServiceToUIState = createLLMServiceToUIState(llmServices); + const seenIncorrectlyConfiguredModels: ModelsSet = new SimpleSet( + (model: ModelParams) => model.modelId + ); + + const succeededSubscriptionId = eventLogger.subscribeToLogicEvent( + LLMServiceImpl.requestSucceededEvent, + reactToRequestSucceededEvent(llmServiceToUIState) + ); + const failedSubscriptionId = eventLogger.subscribeToLogicEvent( + LLMServiceImpl.requestFailedEvent, + reactToRequestFailedEvent( + llmServiceToUIState, + seenIncorrectlyConfiguredModels + ) + ); + + return () => { + eventLogger.unsubscribe( + LLMServiceImpl.requestSucceededEvent, + succeededSubscriptionId + ); + eventLogger.unsubscribe( + LLMServiceImpl.requestFailedEvent, + failedSubscriptionId + ); + }; +} + +function createLLMServiceToUIState( + llmServices: LLMServices +): LLMServiceToUIState { + const initialState: LLMServiceUIState = { + availabilityState: LLMServiceAvailablityState.AVAILABLE, + messagesShownState: LLMServiceMessagesShownState.NO_MESSAGES_SHOWN, + }; + return new Map( + asLLMServices(llmServices).map((llmService) => [ + llmService.serviceName, + { + ...initialState, + }, + ]) + ); +} + +function reactToRequestSucceededEvent( + llmServiceToUIState: LLMServiceToUIState +): (data: any) => void { + return (data: any) => { + const [requestSucceeded, uiState] = + parseLLMServiceRequestEvent( + data, + llmServiceToUIState, + `data of the ${LLMServiceImpl.requestSucceededEvent} event should be a \`LLMServiceRequestSucceeded\` object` + ); + if ( + uiState.availabilityState === LLMServiceAvailablityState.UNAVAILABLE + ) { + uiState.availabilityState = LLMServiceAvailablityState.AVAILABLE; + if ( + uiState.messagesShownState === + LLMServiceMessagesShownState.BECOME_UNAVAILABLE_MESSAGE_SHOWN + ) { + showMessageToUser( + EditorMessages.serviceIsAvailableAgain( + requestSucceeded.llmService.serviceName + ), + "info" + ); + uiState.messagesShownState = + LLMServiceMessagesShownState.AGAIN_AVAILABLE_MESSAGE_SHOWN; + } + } + }; +} + +function reactToRequestFailedEvent( + llmServiceToUIState: LLMServiceToUIState, + seenIncorrectlyConfiguredModels: ModelsSet +): (data: any) => void { + return 
(data: any) => { + const [requestFailed, uiState] = + parseLLMServiceRequestEvent( + data, + llmServiceToUIState, + `data of the ${LLMServiceImpl.requestFailedEvent} event should be a \`LLMServiceRequestFailed\` object` + ); + + const llmServiceError = requestFailed.llmServiceError; + const model = requestFailed.params; + if (llmServiceError instanceof ConfigurationError) { + if (seenIncorrectlyConfiguredModels.has(model)) { + return; // don't show configuration error of the same model to the user twice + } + seenIncorrectlyConfiguredModels.add(model); + showMessageToUserWithSettingsHint( + EditorMessages.modelConfiguredIncorrectly( + model.modelId, + llmServiceError.message + ), + "error", + toSettingName(requestFailed.llmService) + ); + return; + } + if ( + !( + llmServiceError instanceof RemoteConnectionError || + llmServiceError instanceof GenerationFailedError + ) + ) { + throw Error( + `\`llmServiceError\` of the received ${LLMServiceImpl.requestFailedEvent} event data is expected to be either a \` ConfigurationError\`, \`RemoteConnectionError\`, or \`GenerationFailedError\`, but got: "${llmServiceError}"` + ); + } + + if ( + uiState.availabilityState === LLMServiceAvailablityState.AVAILABLE + ) { + uiState.availabilityState = LLMServiceAvailablityState.UNAVAILABLE; + if ( + uiState.messagesShownState === + LLMServiceMessagesShownState.NO_MESSAGES_SHOWN + ) { + const serviceName = requestFailed.llmService.serviceName; + if (llmServiceError instanceof GenerationFailedError) { + showMessageToUser( + EditorMessages.serviceBecameUnavailable( + serviceName, + llmServiceError.cause.message, + requestFailed.llmService.estimateTimeToBecomeAvailable() + ), + "warning" + ); + } else { + showMessageToUser( + EditorMessages.failedToReachRemoteService( + serviceName, + llmServiceError.message + ), + "warning" + ); + } + uiState.messagesShownState = + LLMServiceMessagesShownState.BECOME_UNAVAILABLE_MESSAGE_SHOWN; + } + } + }; +} + +function parseLLMServiceRequestEvent( + data: any, + llmServiceToUIState: LLMServiceToUIState, + errorMessage: string +): [T, LLMServiceUIState] { + const request = data as T; + if (request === null) { + throw Error(`${errorMessage}, but data = ${stringifyAnyValue(data)}`); + } + const serviceName = request.llmService.serviceName; + const uiState = llmServiceToUIState.get(serviceName); + if (uiState === undefined) { + throw Error(`no UI state for \`${serviceName}\``); + } + return [request, uiState]; +} diff --git a/src/extension/settingsValidationError.ts b/src/extension/settingsValidationError.ts new file mode 100644 index 00000000..284c76c7 --- /dev/null +++ b/src/extension/settingsValidationError.ts @@ -0,0 +1,54 @@ +import { commands } from "vscode"; + +import { switchByLLMServiceType } from "../llm/llmServices"; +import { LLMService } from "../llm/llmServices/llmService"; + +import { pluginId } from "./coqPilot"; +import { UIMessageSeverity, showMessageToUser } from "./editorMessages"; + +export const openSettingsItem = "Open settings"; + +export class SettingsValidationError extends Error { + constructor( + errorMessage: string, + private readonly messageToShowToUser: string, + private readonly settingToOpenName: string = pluginId, + private readonly severity: UIMessageSeverity = "error" + ) { + super(errorMessage); + } + + showAsMessageToUser() { + showMessageToUserWithSettingsHint( + this.messageToShowToUser, + this.severity, + this.settingToOpenName + ); + } +} + +export function showMessageToUserWithSettingsHint( + message: string, + severity: 
UIMessageSeverity, + settingToOpenName: string = pluginId +) { + showMessageToUser(message, severity, openSettingsItem).then((value) => { + if (value === openSettingsItem) { + commands.executeCommand( + "workbench.action.openSettings", + settingToOpenName + ); + } + }); +} + +export function toSettingName(llmService: LLMService): string { + const serviceNameInSettings = switchByLLMServiceType( + llmService, + () => "predefinedProofs", + () => "openAi", + () => "grazie", + () => "lmStudio" + ); + return `${pluginId}.${serviceNameInSettings}ModelsParameters`; +} diff --git a/src/extension/vscodeLogWriter.ts b/src/extension/vscodeLogWriter.ts index e2893a63..644e0e05 100644 --- a/src/extension/vscodeLogWriter.ts +++ b/src/extension/vscodeLogWriter.ts @@ -1,7 +1,7 @@ import pino, { DestinationStream, LoggerOptions } from "pino"; import { OutputChannel, window } from "vscode"; -import { ALL_EVENTS, EventLogger, Severity } from "../logging/eventLogger"; +import { EventLogger, Severity, anyEventKeyword } from "../logging/eventLogger"; class VSCodeLogWriter { private readonly outputStream = new VSCodeOutputChannelDestinationStream( @@ -25,15 +25,19 @@ class VSCodeLogWriter { ); constructor(eventLogger: EventLogger, logLevel: Severity = Severity.INFO) { - eventLogger.subscribe(ALL_EVENTS, Severity.INFO, (message, data) => { - this.outputStreamWriter.info({ - message, - data, - }); - }); + eventLogger.subscribe( + anyEventKeyword, + Severity.INFO, + (message, data) => { + this.outputStreamWriter.info({ + message, + data, + }); + } + ); if (logLevel === Severity.DEBUG) { eventLogger.subscribe( - ALL_EVENTS, + anyEventKeyword, Severity.DEBUG, (message, data) => { this.outputStreamWriter.info({ diff --git a/src/llm/llmIterator.ts b/src/llm/llmIterator.ts index 858288ee..90bc5c89 100644 --- a/src/llm/llmIterator.ts +++ b/src/llm/llmIterator.ts @@ -2,12 +2,8 @@ import { EventLogger } from "../logging/eventLogger"; import { LLMServices } from "./llmServices"; import { GeneratedProof, LLMService } from "./llmServices/llmService"; +import { ModelParams, ModelsParams } from "./llmServices/modelParams"; import { ProofGenerationContext } from "./proofGenerationContext"; -import { - PredefinedProofsUserModelParams, - UserModelParams, - UserModelsParams, -} from "./userModelParams"; type GeneratedProofsBatch = GeneratedProof[]; type ProofsGenerationHook = () => Promise; @@ -21,11 +17,9 @@ export class LLMSequentialIterator private hooksIndex: number; private insideBatchIndex: number; - private readonly defaultGenerationChoices = 10; - constructor( proofGenerationContext: ProofGenerationContext, - modelsParams: UserModelsParams, + modelsParams: ModelsParams, services: LLMServices, private eventLogger?: EventLogger ) { @@ -43,7 +37,7 @@ export class LLMSequentialIterator private createHooks( proofGenerationContext: ProofGenerationContext, - modelsParams: UserModelsParams, + modelsParams: ModelsParams, services: LLMServices ): ProofsGenerationHook[] { return [ @@ -51,9 +45,7 @@ export class LLMSequentialIterator proofGenerationContext, modelsParams.predefinedProofsModelParams, services.predefinedProofsService, - "predefined-proofs", - (params) => - (params as PredefinedProofsUserModelParams).tactics.length + "predefined-proofs" ), ...this.createLLMServiceHooks( proofGenerationContext, @@ -76,28 +68,23 @@ export class LLMSequentialIterator ]; } - private createLLMServiceHooks( + private createLLMServiceHooks( proofGenerationContext: ProofGenerationContext, - allModelParamsForService: UserModelParams[], - llmService: 
LLMService, - serviceLoggingName: string, - resolveChoices: (userModelParams: UserModelParams) => number = (_) => - this.defaultGenerationChoices + allModelParamsForService: ResolvedModelParams[], + llmService: LLMService, + serviceLoggingName: string ): ProofsGenerationHook[] { const hooks = []; - for (const userModelParams of allModelParamsForService) { + for (const modelParams of allModelParamsForService) { hooks.push(() => { - const resolvedParams = - llmService.resolveParameters(userModelParams); this.eventLogger?.log( `${serviceLoggingName}-fetch-started`, `Completion from ${serviceLoggingName}`, - resolvedParams + modelParams ); return llmService.generateProof( proofGenerationContext, - resolvedParams, - userModelParams.choices ?? resolveChoices(userModelParams) + modelParams ); }); } @@ -129,7 +116,7 @@ export class LLMSequentialIterator return false; } - async next(): Promise> { + async next(): Promise> { const finished = await this.prepareFetched(); if (finished) { return { done: true, value: undefined }; @@ -143,7 +130,7 @@ export class LLMSequentialIterator return { done: false, value: proofs }; } - async nextProof(): Promise> { + async nextProof(): Promise> { const finished = await this.prepareFetched(); if (finished) { return { done: true, value: undefined }; diff --git a/src/llm/llmServiceErrors.ts b/src/llm/llmServiceErrors.ts new file mode 100644 index 00000000..090a3e02 --- /dev/null +++ b/src/llm/llmServiceErrors.ts @@ -0,0 +1,50 @@ +/** + * Base class for the errors thrown by `LLMService`. + */ +export abstract class LLMServiceError extends Error { + constructor( + message: string = "", + readonly cause: Error | undefined = undefined + ) { + let errorMessage = message; + if (cause !== undefined) { + const causeMessage = `cause: [${cause.name}] "${cause.message}"`; + errorMessage = + message === "" ? causeMessage : `${message}, ${causeMessage}`; + } + super(errorMessage); + } +} + +/** + * Represents the failure of the generation request caused by invalid parameters + * configured by the user or the plugin. + */ +export class ConfigurationError extends LLMServiceError { + constructor(message: string) { + super(message); + } +} + +/** + * Represents the failure of the generation request caused by inability + * to reach a remote service or a remote resource. + * + * This error is not of `GenerationFailedError` type, because the actual proof-generation process + * has not yet trully started and the problems are most likely on the user side. + */ +export class RemoteConnectionError extends LLMServiceError { + constructor(message: string) { + super(message); + } +} + +/** + * Represents the failure of the actual proof-generation process, + * i.e. after all parameters validation has been performed. 
+ */ +export class GenerationFailedError extends LLMServiceError { + constructor(readonly cause: Error) { + super("", cause); + } +} diff --git a/src/llm/llmServices.ts b/src/llm/llmServices.ts index 126e98bd..18c9bd9e 100644 --- a/src/llm/llmServices.ts +++ b/src/llm/llmServices.ts @@ -1,11 +1,51 @@ import { GrazieService } from "./llmServices/grazie/grazieService"; +import { LLMService } from "./llmServices/llmService"; import { LMStudioService } from "./llmServices/lmStudio/lmStudioService"; +import { ModelParams } from "./llmServices/modelParams"; import { OpenAiService } from "./llmServices/openai/openAiService"; import { PredefinedProofsService } from "./llmServices/predefinedProofs/predefinedProofsService"; +import { UserModelParams } from "./userModelParams"; export interface LLMServices { + predefinedProofsService: PredefinedProofsService; openAiService: OpenAiService; grazieService: GrazieService; - predefinedProofsService: PredefinedProofsService; lmStudioService: LMStudioService; } + +export function disposeServices(llmServices: LLMServices) { + asLLMServices(llmServices).forEach((service) => service.dispose()); +} + +export function asLLMServices( + llmServices: LLMServices +): LLMService[] { + return [ + llmServices.predefinedProofsService, + llmServices.openAiService, + llmServices.grazieService, + llmServices.lmStudioService, + ]; +} + +export function switchByLLMServiceType( + llmService: LLMService, + onPredefinedProofsService: () => T, + onOpenAiService: () => T, + onGrazieService: () => T, + onLMStudioService: () => T +): T { + if (llmService instanceof PredefinedProofsService) { + return onPredefinedProofsService(); + } else if (llmService instanceof OpenAiService) { + return onOpenAiService(); + } else if (llmService instanceof GrazieService) { + return onGrazieService(); + } else if (llmService instanceof LMStudioService) { + return onLMStudioService(); + } else { + throw Error( + `switch by unknown LLMService: "${llmService.serviceName}"` + ); + } +} diff --git a/src/llm/llmServices/chat.ts b/src/llm/llmServices/chat.ts index 6d96d6d8..0eb78ff0 100644 --- a/src/llm/llmServices/chat.ts +++ b/src/llm/llmServices/chat.ts @@ -3,3 +3,14 @@ export type ChatRole = "system" | "user" | "assistant"; export type ChatMessage = { role: ChatRole; content: string }; export type ChatHistory = ChatMessage[]; + +export interface AnalyzedChatHistory { + chat: ChatHistory; + estimatedTokens?: EstimatedTokens; +} + +export interface EstimatedTokens { + messagesTokens: number; + maxTokensToGenerate: number; + maxTokensInTotal: number; +} diff --git a/src/llm/llmServices/grazie/grazieApi.ts b/src/llm/llmServices/grazie/grazieApi.ts index 4fa586aa..1e69f1cd 100644 --- a/src/llm/llmServices/grazie/grazieApi.ts +++ b/src/llm/llmServices/grazie/grazieApi.ts @@ -1,8 +1,7 @@ -/* eslint-disable @typescript-eslint/naming-convention */ import axios from "axios"; import { ResponseType } from "axios"; -import { EventLogger, Severity } from "../../../logging/eventLogger"; +import { DebugWrappers } from "../llmServiceInternal"; import { GrazieModelParams } from "../modelParams"; export type GrazieChatRole = "User" | "System" | "Assistant"; @@ -14,15 +13,12 @@ interface GrazieConfig { } export class GrazieApi { - private readonly config: GrazieConfig; + private readonly config: GrazieConfig = { + chatUrl: "v5/llm/chat/stream/v3", + gateawayUrl: "https://api.app.stgn.grazie.aws.intellij.net/service/", + }; - constructor(private readonly eventLogger?: EventLogger) { - this.config = { - chatUrl: 
"v5/llm/chat/stream/v3", - gateawayUrl: - "https://api.app.stgn.grazie.aws.intellij.net/service/", - }; - } + constructor(private readonly debug: DebugWrappers) {} async requestChatCompletion( params: GrazieModelParams, @@ -50,16 +46,13 @@ export class GrazieApi { apiToken: string ): Promise { const headers = this.createHeaders(apiToken); - this.eventLogger?.log( - "grazie-fetch-started", - "Completion from Grazie requested", - { - url: url, - body: body, - headers: headers, - }, - Severity.DEBUG - ); + headers["Content-Length"] = body.length; + + this.debug.logEvent("Completion requested", { + url: url, + body: body, + headers: headers, + }); const response = await this.fetchAndProcessEvents( this.config.gateawayUrl + url, @@ -71,12 +64,14 @@ export class GrazieApi { } private createHeaders(token: string): any { + /* eslint-disable @typescript-eslint/naming-convention */ return { Accept: "*/*", "Content-Type": "application/json", "Grazie-Authenticate-Jwt": token, "Grazie-Original-Service-JWT": token, }; + /* eslint-enable @typescript-eslint/naming-convention */ } private chunkToTokens(chunk: any): string[] { @@ -99,7 +94,7 @@ export class GrazieApi { const messageData = JSON.parse(validJSON); messages.push(messageData.current); } else { - throw new Error( + throw Error( "Unexpected chunk: " + tokenWrapped + ". Please report this error." diff --git a/src/llm/llmServices/grazie/grazieModelParamsResolver.ts b/src/llm/llmServices/grazie/grazieModelParamsResolver.ts new file mode 100644 index 00000000..3841ab0f --- /dev/null +++ b/src/llm/llmServices/grazie/grazieModelParamsResolver.ts @@ -0,0 +1,35 @@ +import { GrazieUserModelParams } from "../../userModelParams"; +import { GrazieModelParams, grazieModelParamsSchema } from "../modelParams"; +import { BasicModelParamsResolver } from "../utils/paramsResolvers/basicModelParamsResolvers"; +import { ValidationRules } from "../utils/paramsResolvers/builders"; +import { ValidParamsResolverImpl } from "../utils/paramsResolvers/paramsResolverImpl"; + +import { GrazieService } from "./grazieService"; + +export class GrazieModelParamsResolver + extends BasicModelParamsResolver + implements + ValidParamsResolverImpl +{ + constructor() { + super(grazieModelParamsSchema, "GrazieModelParams"); + } + + readonly modelName = this.resolveParam("modelName") + .requiredToBeConfigured() + .validateAtRuntimeOnly(); + + readonly apiKey = this.resolveParam("apiKey") + .requiredToBeConfigured() + .validateAtRuntimeOnly(); + + readonly maxTokensToGenerate = this.resolveParam( + "maxTokensToGenerate" + ) + .override( + () => GrazieService.maxTokensToGeneratePredefined, + `is always ${GrazieService.maxTokensToGeneratePredefined} for \`GrazieService\`` + ) + .requiredToBeConfigured() + .validate(ValidationRules.bePositiveNumber); +} diff --git a/src/llm/llmServices/grazie/grazieService.ts b/src/llm/llmServices/grazie/grazieService.ts index e4a255b1..d988096c 100644 --- a/src/llm/llmServices/grazie/grazieService.ts +++ b/src/llm/llmServices/grazie/grazieService.ts @@ -1,56 +1,103 @@ import { EventLogger } from "../../../logging/eventLogger"; import { ProofGenerationContext } from "../../proofGenerationContext"; -import { UserModelParams } from "../../userModelParams"; +import { GrazieUserModelParams } from "../../userModelParams"; import { ChatHistory, ChatMessage } from "../chat"; -import { GeneratedProof, Proof, ProofVersion } from "../llmService"; -import { LLMService } from "../llmService"; -import { GrazieModelParams, ModelParams } from "../modelParams"; +import { 
GeneratedProofImpl, ProofVersion } from "../llmService"; +import { LLMServiceImpl } from "../llmService"; +import { LLMServiceInternal } from "../llmServiceInternal"; +import { GrazieModelParams } from "../modelParams"; import { GrazieApi, GrazieChatRole, GrazieFormattedHistory } from "./grazieApi"; +import { GrazieModelParamsResolver } from "./grazieModelParamsResolver"; -export class GrazieService extends LLMService { - private api: GrazieApi; - // Is constant (now) as specified in Grazie REST API - private readonly newMessageMaxTokens = 1024; +export class GrazieService extends LLMServiceImpl< + GrazieUserModelParams, + GrazieModelParams, + GrazieService, + GrazieGeneratedProof, + GrazieServiceInternal +> { + protected readonly internal: GrazieServiceInternal; + protected readonly modelParamsResolver = new GrazieModelParamsResolver(); - constructor(eventLogger?: EventLogger) { - super(eventLogger); - this.api = new GrazieApi(eventLogger); + constructor( + eventLogger?: EventLogger, + debugLogs: boolean = false, + generationsLogsFilePath?: string + ) { + super("GrazieService", eventLogger, debugLogs, generationsLogsFilePath); + this.internal = new GrazieServiceInternal( + this, + this.eventLoggerGetter, + this.generationsLoggerBuilder + ); } + /** + * As specified in Grazie REST API, `maxTokensToGenerate` is a constant currently. + */ + static readonly maxTokensToGeneratePredefined = 1024; +} + +export class GrazieGeneratedProof extends GeneratedProofImpl< + GrazieModelParams, + GrazieService, + GrazieGeneratedProof, + GrazieServiceInternal +> { + constructor( + proof: string, + proofGenerationContext: ProofGenerationContext, + modelParams: GrazieModelParams, + llmServiceInternal: GrazieServiceInternal, + previousProofVersions?: ProofVersion[] + ) { + super( + proof, + proofGenerationContext, + modelParams, + llmServiceInternal, + previousProofVersions + ); + } +} + +class GrazieServiceInternal extends LLMServiceInternal< + GrazieModelParams, + GrazieService, + GrazieGeneratedProof, + GrazieServiceInternal +> { + readonly api = new GrazieApi(this.debug); + constructGeneratedProof( proof: string, proofGenerationContext: ProofGenerationContext, - modelParams: ModelParams, + modelParams: GrazieModelParams, previousProofVersions?: ProofVersion[] | undefined - ): GeneratedProof { + ): GrazieGeneratedProof { return new GrazieGeneratedProof( proof, proofGenerationContext, - modelParams as GrazieModelParams, + modelParams, this, previousProofVersions ); } - async generateFromChat( + async generateFromChatImpl( chat: ChatHistory, - params: ModelParams, + params: GrazieModelParams, choices: number ): Promise { - if (choices <= 0) { - return []; - } + this.validateChoices(choices); let attempts = choices * 2; const completions: Promise[] = []; const formattedChat = this.formatChatHistory(chat); while (completions.length < choices && attempts > 0) { completions.push( - this.api.requestChatCompletion( - params as GrazieModelParams, - formattedChat - ) + this.api.requestChatCompletion(params, formattedChat) ); attempts--; } @@ -68,27 +115,4 @@ export class GrazieService extends LLMService { }; }); } - - resolveParameters(params: UserModelParams): ModelParams { - params.newMessageMaxTokens = this.newMessageMaxTokens; - return this.resolveParametersWithDefaults(params); - } -} - -export class GrazieGeneratedProof extends GeneratedProof { - constructor( - proof: Proof, - proofGenerationContext: ProofGenerationContext, - modelParams: GrazieModelParams, - llmService: GrazieService, - 
previousProofVersions?: ProofVersion[] - ) { - super( - proof, - proofGenerationContext, - modelParams, - llmService, - previousProofVersions - ); - } } diff --git a/src/llm/llmServices/llmService.ts b/src/llm/llmServices/llmService.ts index 59366fff..b5f9c7b3 100644 --- a/src/llm/llmServices/llmService.ts +++ b/src/llm/llmServices/llmService.ts @@ -1,239 +1,473 @@ -import * as assert from "assert"; +import * as tmp from "tmp"; import { EventLogger } from "../../logging/eventLogger"; +import { ConfigurationError, LLMServiceError } from "../llmServiceErrors"; import { ProofGenerationContext } from "../proofGenerationContext"; import { UserModelParams } from "../userModelParams"; -import { ChatHistory } from "./chat"; -import { ModelParams, MultiroundProfile } from "./modelParams"; +import { AnalyzedChatHistory } from "./chat"; +import { LLMServiceInternal } from "./llmServiceInternal"; +import { ModelParams } from "./modelParams"; import { buildProofFixChat, buildProofGenerationChat, } from "./utils/chatFactory"; - -export type Proof = string; +import { estimateTimeToBecomeAvailableDefault } from "./utils/defaultAvailabilityEstimator"; +import { GenerationsLogger } from "./utils/generationsLogger/generationsLogger"; +import { LoggerRecord } from "./utils/generationsLogger/loggerRecord"; +import { + ParamsResolutionResult, + ParamsResolver, +} from "./utils/paramsResolvers/abstractResolvers"; +import { Time } from "./utils/time"; export interface ProofVersion { - proof: Proof; + proof: string; diagnostic?: string; } -export abstract class LLMService { - constructor(protected readonly eventLogger?: EventLogger) {} +export enum ErrorsHandlingMode { + LOG_EVENTS_AND_SWALLOW_ERRORS = "log events & swallow errors", + RETHROW_ERRORS = "rethrow errors", +} - abstract constructGeneratedProof( - proof: Proof, - proofGenerationContext: ProofGenerationContext, - modelParams: ModelParams, - previousProofVersions?: ProofVersion[] - ): GeneratedProof; +/** + * Interface for `LLMServiceImpl` to package all generation request data. + * Then, this data is used for interaction between implementation components. + * In addition, interfaces derived from it can be passed to loggers to record the requests' results. + */ +export interface LLMServiceRequest { + llmService: LLMService; + params: ModelParams; + choices: number; + analyzedChat?: AnalyzedChatHistory; +} + +export interface LLMServiceRequestSucceeded extends LLMServiceRequest { + generatedRawProofs: string[]; +} + +export interface LLMServiceRequestFailed extends LLMServiceRequest { + llmServiceError: LLMServiceError; +} + +/** + * Facade type for the `LLMServiceImpl` type. + * + * The proper typing of self `LLMServiceImpl`, returning `GeneratedProofImpl`-s and `LLMServiceImpl.internal` + * is required inside implementation only. + * Thus, `LLMServiceImpl` should be resolved with `any` for the implementation generic types, when used outside. + */ +export type LLMService< + InputModelParams extends UserModelParams, + ResolvedModelParams extends ModelParams, +> = LLMServiceImpl; + +/** + * `LLMServiceImpl` represents a service for proofs generation. + * Proofs can be generated from both `ProofGenerationContext` and `AnalyzedChatHistory`. + * Generated proofs are represented by `GeneratedProofImpl` class and + * can be further regenerated (fixed / shortened / etc), also keeping their previous versions. + * + * 1. 
All model parameters of the `ResolvedModelParams` type accepted by `LLMService`-related methods + * are expected to be resolved by `resolveParameters` method beforehand. + * This method resolves partially-undefined `InputModelParams` to complete and validated `ResolvedModelParams`. + * See the `resolveParameters` method for more details. + * + * 2. All proofs-generation methods support errors handling and logging. + * - Each successfull generation is logged both by `GenerationsLogger` and `EventLogger`. + * - If error occurs, it is catched and then: + * - is wrapped into `LLMServiceError` and then... + * - in case of `LOG_EVENTS_AND_SWALLOW_ERRORS`, it's only logged by `EventLogger`; + * - in case of `RETHROW_ERRORS`, it's rethrown. + * + * `EventLogger` sends `requestSucceededEvent` and `requestFailedEvent` + * (along with `LLMServiceRequest` as data), which can be handled then, for example, by the UI. + * + * Regardless errors handling modes and `EventLogger` behaviour, + * `GenerationsLogger` maintains the logs of both successful and failed generations + * used for the further estimation of the service availability. See the `estimateTimeToBecomeAvailable` method. + * + * 3. To implement a new `LLMServiceImpl` based on generating proofs from chats, one should: + * - declare the specification of models parameters via custom `UserModelParams` and `ModelParams` interfaces; + * - implement custom `ParamsResolver` class, declaring the algorithm to resolve parameters with; + * - declare custom `GeneratedProofImpl`; + * - implement custom `LLMServiceInternal`; + * - finally, declare custom `LLMServiceImpl`. + * + * I.e. `LLMServiceInternal` is effectively the only class needed to be actually implemented. + * + * If proofs-generation is not supposed to be based on chats, + * the methods of `LLMServiceImpl` should be overriden directly too. + * + * Also, do not be afraid of the complicated generic types in the base classes below. + * Although they look overly complex, they provide great typing support during implementation. + * Just remember to replace all generic types with your specific custom classes whenever possible. + * For example: + * ``` + * class MyLLMService extends LLMServiceImpl< + * MyUserModelParams, + * MyModelParams, + * MyLLMService, + * MyGeneratedProof, + * MyLLMServiceInternal + * > { + * // implementation + * } + * ``` + */ +export abstract class LLMServiceImpl< + InputModelParams extends UserModelParams, + ResolvedModelParams extends ModelParams, + LLMServiceType extends LLMServiceImpl< + UserModelParams, + ResolvedModelParams, + LLMServiceType, + GeneratedProofType, + LLMServiceInternalType + >, + GeneratedProofType extends GeneratedProofImpl< + ResolvedModelParams, + LLMServiceType, + GeneratedProofType, + LLMServiceInternalType + >, + LLMServiceInternalType extends LLMServiceInternal< + ResolvedModelParams, + LLMServiceType, + GeneratedProofType, + LLMServiceInternalType + >, +> { + protected abstract readonly internal: LLMServiceInternalType; + protected abstract readonly modelParamsResolver: ParamsResolver< + InputModelParams, + ResolvedModelParams + >; + protected readonly eventLoggerGetter: () => EventLogger | undefined; + protected readonly generationsLoggerBuilder: () => GenerationsLogger; - abstract generateFromChat( - chat: ChatHistory, - params: ModelParams, - choices: number - ): Promise; + /** + * Creates an instance of `LLMServiceImpl`. 
+ * @param eventLogger if it is not specified, events won't be logged and passing `LOG_EVENTS_AND_SWALLOW_ERRORS` will throw an error. + * @param debugLogs enables debug logs for the internal `GenerationsLogger`. + * @param generationLogsFilePath if it is not specified, a temporary file will be used. + */ + constructor( + readonly serviceName: string, + eventLogger: EventLogger | undefined, + debugLogs: boolean, + generationLogsFilePath: string | undefined + ) { + this.eventLoggerGetter = () => eventLogger; + this.generationsLoggerBuilder = () => + new GenerationsLogger( + generationLogsFilePath ?? tmp.fileSync().name, + { + debug: debugLogs, + paramsPropertiesToCensor: { + apiKey: GenerationsLogger.censorString, + }, + cleanLogsOnStart: true, + } + ); + } + + static readonly requestSucceededEvent = `llmservice-request-succeeded`; + static readonly requestFailedEvent = `llmservice-request-failed`; + + /** + * Generates proofs from chat. + * This method performs errors-handling and logging, check `LLMServiceImpl` docs for more details. + * + * The default implementation is based on the `LLMServiceInternal.generateFromChatImpl`. + * If it is not the desired way, `generateFromChat` should be overriden. + * + * @param choices if specified, overrides `ModelParams.defaultChoices`. + * @returns generated proofs as raw strings. + */ + async generateFromChat( + analyzedChat: AnalyzedChatHistory, + params: ResolvedModelParams, + choices: number = params.defaultChoices, + errorsHandlingMode: ErrorsHandlingMode = ErrorsHandlingMode.LOG_EVENTS_AND_SWALLOW_ERRORS + ): Promise { + return this.internal.generateFromChatWrapped( + params, + choices, + errorsHandlingMode, + () => analyzedChat, + (proof) => proof + ); + } + /** + * Generates proofs from `ProofGenerationContext`, i.e. from `completionTarget` and `contextTheorems`. + * This method performs errors-handling and logging, check `LLMServiceImpl` docs for more details. + * + * The default implementation is based on the generation from chat, namely, + * it calls `LLMServiceInternal.generateFromChatImpl`. + * If it is not the desired way, `generateProof` should be overriden. + * + * @param choices if specified, overrides `ModelParams.defaultChoices`. + * @returns generated proofs as `GeneratedProofImpl`-s. + */ async generateProof( proofGenerationContext: ProofGenerationContext, - params: ModelParams, - choices: number - ): Promise { - if (choices <= 0) { - return []; - } - const chat = buildProofGenerationChat(proofGenerationContext, params); - const proofs = await this.generateFromChat(chat, params, choices); - return proofs.map((proof: string) => - this.constructGeneratedProof(proof, proofGenerationContext, params) + params: ResolvedModelParams, + choices: number = params.defaultChoices, + errorsHandlingMode: ErrorsHandlingMode = ErrorsHandlingMode.LOG_EVENTS_AND_SWALLOW_ERRORS + ): Promise { + return this.internal.generateFromChatWrapped( + params, + choices, + errorsHandlingMode, + () => buildProofGenerationChat(proofGenerationContext, params), + (proof) => + this.internal.constructGeneratedProof( + proof, + proofGenerationContext, + params + ) ); } - dispose(): void {} + /** + * Estimates the expected time for service to become available. + * To do this, analyzes the logs from `this.generationsLogger` and computes the time. 
+ */ + estimateTimeToBecomeAvailable(): Time { + return estimateTimeToBecomeAvailableDefault( + this.internal.generationsLogger.readLogsSinceLastSuccess() + ); + } - resolveParameters(params: UserModelParams): ModelParams { - return this.resolveParametersWithDefaults(params); + /** + * Reads logs provided by `GenerationsLogger` for this service. + */ + readGenerationsLogs(sinceLastSuccess: boolean = false): LoggerRecord[] { + return sinceLastSuccess + ? this.internal.generationsLogger.readLogsSinceLastSuccess() + : this.internal.generationsLogger.readLogs(); } - protected readonly resolveParametersWithDefaults = ( - params: UserModelParams - ): ModelParams => { - const newMessageMaxTokens = - params.newMessageMaxTokens ?? - this.defaultNewMessageMaxTokens[params.modelName]; - const tokensLimits = - params.tokensLimit ?? this.defaultTokensLimits[params.modelName]; - const systemMessageContent = - params.systemPrompt ?? this.defaultSystemMessageContent; - const multiroundProfile: MultiroundProfile = { - maxRoundsNumber: - params.multiroundProfile?.maxRoundsNumber ?? - this.defaultMultiroundProfile.maxRoundsNumber, - proofFixChoices: - params.multiroundProfile?.proofFixChoices ?? - this.defaultMultiroundProfile.proofFixChoices, - proofFixPrompt: - params.multiroundProfile?.proofFixPrompt ?? - this.defaultMultiroundProfile.proofFixPrompt, - }; - if (newMessageMaxTokens === undefined || tokensLimits === undefined) { - throw Error(`user model parameters cannot be resolved: ${params}`); - } + dispose(): void { + this.internal.dispose(); + } - /** NOTE: it's important to pass `...extractedParams` first - * because if so, then the omitted fields of the `params` - * (`systemPromt`, `newMessageMaxTokens`, `tokensLimit`, etc) - * will be overriden - and not in the opposite way! - */ - return { - ...params, - systemPrompt: systemMessageContent, - newMessageMaxTokens: newMessageMaxTokens, - tokensLimit: tokensLimits, - multiroundProfile: multiroundProfile, - }; - }; - - private readonly defaultNewMessageMaxTokens: { - [modelName: string]: number; - } = {}; - - private readonly defaultTokensLimits: { - [modelName: string]: number; - } = { - // eslint-disable-next-line @typescript-eslint/naming-convention - "gpt-3.5-turbo-0301": 2000, - }; - - private readonly defaultSystemMessageContent: string = - "Generate proof of the theorem from user input in Coq. You should only generate proofs in Coq. Never add special comments to the proof. Your answer should be a valid Coq proof. It should start with 'Proof.' and end with 'Qed.'."; - - // its properties can be used separately - private readonly defaultMultiroundProfile: MultiroundProfile = { - maxRoundsNumber: 1, // multiround is disabled by default - proofFixChoices: 1, // 1 fix version per proof by default - proofFixPrompt: - "Unfortunately, the last proof is not correct. Here is the compiler's feedback: '${diagnostic}'. Please, fix the proof.", - }; + /** + * Resolves possibly-incomplete `UserModelParams` to complete `ModelParams`. + * Resolution process includes overrides of input parameters, + * their resolution with default values if needed, and validation of their result values. + * See the `ParamsResolver` class for more details. + * + * This method does not throw. Instead, it always returns resolution logs, which include + * all information about the actions taken on the input parameters and their validation status. + * + * @param params possibly-incomplete parameters configured by user. 
+ * @returns complete and validated parameters for the further generation pipeline.
+ */
+    resolveParameters(
+        params: InputModelParams
+    ): ParamsResolutionResult<ResolvedModelParams> {
+        return this.modelParamsResolver.resolve(params);
+    }
 }
-export abstract class GeneratedProof {
-    readonly llmService: LLMService;
-    readonly modelParams: ModelParams;
+/**
+ * Facade type for the `GeneratedProofImpl` type.
+ *
+ * Most often, the proper typing of `GeneratedProofImpl.modelParams` is not required,
+ * while the proper typing of the parent `LLMServiceImpl` returning `GeneratedProofImpl`-s,
+ * and of `GeneratedProofImpl.llmServiceInternal`, is required only inside the implementation.
+ * Thus, outside of the internal implementation, this type will most likely be parameterized with the base classes and `any`-s.
+ */
+export type GeneratedProof = GeneratedProofImpl<
+    ModelParams,
+    LLMService,
+    GeneratedProof,
+    any
+>;
+
+/**
+ * This class represents a proof generated by `LLMServiceImpl`.
+ * It stores all the meta-information about its generation.
+ *
+ * Moreover, it might support multiround generation: fixing, shortening, etc.
+ * For this, a new version of this proof could be generated via `LLMServiceInternal.generateFromChatWrapped`.
+ *
+ * Multiround-generation parameters are specified in `ModelParams.multiroundProfile`.
+ *
+ * Same as in `LLMServiceImpl`, the multiround-generation methods perform error handling and logging (in the same way).
+ * Likewise, these methods can be overridden to change the behaviour of the multiround generation.
+ *
+ * Finally, `GeneratedProofImpl` keeps the previous proof versions (but not the future ones).
+ */
+export abstract class GeneratedProofImpl<
+    ResolvedModelParams extends ModelParams,
+    LLMServiceType extends LLMServiceImpl<
+        UserModelParams,
+        ResolvedModelParams,
+        LLMServiceType,
+        GeneratedProofType,
+        LLMServiceInternalType
+    >,
+    GeneratedProofType extends GeneratedProofImpl<
+        ResolvedModelParams,
+        LLMServiceType,
+        GeneratedProofType,
+        LLMServiceInternalType
+    >,
+    LLMServiceInternalType extends LLMServiceInternal<
+        ResolvedModelParams,
+        LLMServiceType,
+        GeneratedProofType,
+        LLMServiceInternalType
+    >,
+> {
+    /**
+     * An accessor for `ModelParams.multiroundProfile.maxRoundsNumber`.
+     */
     readonly maxRoundsNumber: number;
-    readonly proofGenerationContext: ProofGenerationContext;
+    /**
+     * Previous proof versions of the current `GeneratedProofImpl` (including the latest one).
+     * Only the last one (i.e. the latest) is allowed to have an incomplete `ProofVersion`.
+     *
+     * When this `GeneratedProofImpl` generates new versions in a new round (for example, when `fixProof` is called),
+     * its `proofVersions` won't track the results (the newer proof versions).
+     * Instead, completely new `GeneratedProofImpl` objects will be returned,
+     * having longer `proofVersions` stored inside.
+     */
     readonly proofVersions: ProofVersion[];
+    /**
+     * Creates an instance of `GeneratedProofImpl`.
+     * Should be called only by `LLMServiceImpl`, `LLMServiceInternal` or `GeneratedProofImpl` itself.
+     *
+     * This constructor is capable of extracting the actual proof (its block of code)
+     * from the input `proof` in case it is contaminated with plain text or any other surrounding symbols.
+     * Namely, it extracts the block between `Proof.` and `Qed.` if they are present;
+     * otherwise, it takes the whole `proof`.
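+     *
+     * For example, roughly (following the `Proof.` / `Qed.` extraction described above):
+     * ```
+     * // input `proof`, contaminated with plain text:
+     * "Sure, here is a proof:\nProof.\n  intros. auto.\nQed.\nHope it helps!"
+     * // proof stored in the latest `ProofVersion`:
+     * "intros. auto."
+     * ```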
+ */ constructor( - proof: Proof, - proofGenerationContext: ProofGenerationContext, - modelParams: ModelParams, - llmService: LLMService, - previousProofVersions?: ProofVersion[] + proof: string, + readonly proofGenerationContext: ProofGenerationContext, + readonly modelParams: ResolvedModelParams, + protected readonly llmServiceInternal: LLMServiceInternalType, + previousProofVersions: ProofVersion[] = [] ) { - // Sometimes, expecially when 0-shot prompting, - // Gpt wraps the proof in a tone of comments and other text. - // The code block is somewhere in the middle. - // This method extracts the code block from the message. - proof = this.parseProofFromMessage(proof); - - this.llmService = llmService; - this.modelParams = modelParams; - - this.proofGenerationContext = proofGenerationContext; - // Makes a copy of the previous proof versions - this.proofVersions = [...(previousProofVersions ?? [])]; + // Make a copy of the previous proof versions + this.proofVersions = [...previousProofVersions]; + + // Save newly generated `proof` this.proofVersions.push({ - proof: proof, + proof: this.removeProofQedIfNeeded(proof), diagnostic: undefined, }); this.maxRoundsNumber = this.modelParams.multiroundProfile.maxRoundsNumber; if (this.maxRoundsNumber < this.proofVersions.length) { - throw new Error( - `proof cannot be generated: max rounds number (${this.maxRoundsNumber}) was already reached` + throw Error( + `proof cannot be instantiated: max rounds number (${this.maxRoundsNumber}) was already reached` ); } } - private lastProofVersion(): ProofVersion { - return this.proofVersions[this.proofVersions.length - 1]; + /** + * @returns proof of the latest version for this `GeneratedProofImpl`. + */ + proof(): string { + return this.lastProofVersion().proof; } - proof(): Proof { - return this.lastProofVersion().proof; + protected lastProofVersion(): ProofVersion { + return this.proofVersions[this.proofVersions.length - 1]; } - // starts with one, then +1 for each version + /** + * Initially generated proofs have version number equal to 1. + * Each generation round creates `GeneratedProofs` with version = `this.versionNumber() + 1`. + * + * @returns version number of this `GeneratedProofImpl`. + */ versionNumber(): number { return this.proofVersions.length; } - protected async generateNextVersion( - chat: ChatHistory, - choices: number - ): Promise { - if (!this.nextVersionCanBeGenerated() || choices <= 0) { - return []; - } - const newProofs = await this.llmService.generateFromChat( - chat, - this.modelParams, - choices - ); - return newProofs.map((proof: string) => - this.llmService.constructGeneratedProof( - proof, - this.proofGenerationContext, - this.modelParams, - this.proofVersions - ) - ); + /** + * This method doesn't check `ModelParams.multiroundProfile.fixedProofChoices`, + * because they can be overriden via the function's parameters at the call. + * + * @returns whether this `GeneratedProofImpl` is allowed to be fixed at least once. + */ + canBeFixed(): Boolean { + return this.nextVersionCanBeGenerated(); } /** - * `modelParams.multiroundProfile.fixedProofChoices` can be overriden here - * with `choices` parameter + * @returns whether `maxRoundsNumber` allows to generate a newer version of this proof. + */ + protected nextVersionCanBeGenerated(): Boolean { + return this.versionNumber() < this.maxRoundsNumber; + } + + /** + * Generates new `GeneratedProofImpl`-s as fixes for the latest version of the current one. 
+ * This method performs errors-handling and logging the same way as `LLMServiceImpl`'s methods do. + * + * When this method is called, the `diagnostic` of the latest proof version + * is overwritten with the `diagnostic` parameter of the call. + * + * The default implementation is based on the generation from chat, namely, + * it calls `LLMServiceInternal.generateFromChatImpl`. + * If it is not the desired way, `fixProof` should be overriden. + * + * @param diagnostic diagnostic received from the compiler. + * @param choices if specified, overrides `ModelParams.multiroundProfile.defaultProofFixChoices`. */ async fixProof( diagnostic: string, - choices: number = this.modelParams.multiroundProfile.proofFixChoices - ): Promise { - if (choices <= 0 || !this.canBeFixed()) { - return []; - } - - const lastProofVersion = this.lastProofVersion(); - assert.ok(lastProofVersion.diagnostic === undefined); - lastProofVersion.diagnostic = diagnostic; - - const chat = buildProofFixChat( - this.proofGenerationContext, - this.proofVersions, - this.modelParams + choices: number = this.modelParams.multiroundProfile + .defaultProofFixChoices, + errorsHandlingMode: ErrorsHandlingMode = ErrorsHandlingMode.LOG_EVENTS_AND_SWALLOW_ERRORS + ): Promise { + return this.llmServiceInternal.generateFromChatWrapped( + this.modelParams, + choices, + errorsHandlingMode, + () => { + if (!this.canBeFixed()) { + throw new ConfigurationError( + `this \`GeneratedProofImpl\` could not be fixed: version ${this.versionNumber()} >= max rounds number ${this.maxRoundsNumber}` + ); + } + this.lastProofVersion().diagnostic = diagnostic; + return buildProofFixChat( + this.proofGenerationContext, + this.proofVersions, + this.modelParams + ); + }, + (proof: string) => + this.llmServiceInternal.constructGeneratedProof( + proof, + this.proofGenerationContext, + this.modelParams, + this.proofVersions + ) ); - - return this.generateNextVersion(chat, choices); } - parseProofFromMessage(message: string): string { - const regex = /Proof(.*?)Qed\./s; - const match = regex.exec(message); + private readonly coqProofBlockPattern = /Proof\.\s*(.*?)\s*Qed\./s; + + private removeProofQedIfNeeded(message: string): string { + const match = this.coqProofBlockPattern.exec(message); if (match) { - return match[0]; + return match[1]; } else { return message; } } - - protected nextVersionCanBeGenerated(): Boolean { - return this.versionNumber() < this.maxRoundsNumber; - } - - // doesn't check this.modelParams.multiroundProfile.fixedProofChoices, because they can be overriden - canBeFixed(): Boolean { - return this.nextVersionCanBeGenerated(); - } } diff --git a/src/llm/llmServices/llmServiceInternal.ts b/src/llm/llmServices/llmServiceInternal.ts new file mode 100644 index 00000000..4134ac88 --- /dev/null +++ b/src/llm/llmServices/llmServiceInternal.ts @@ -0,0 +1,346 @@ +import { EventLogger, Severity } from "../../logging/eventLogger"; +import { + ConfigurationError, + GenerationFailedError, + LLMServiceError, +} from "../llmServiceErrors"; +import { ProofGenerationContext } from "../proofGenerationContext"; +import { UserModelParams } from "../userModelParams"; + +import { AnalyzedChatHistory, ChatHistory } from "./chat"; +import { + ErrorsHandlingMode, + GeneratedProofImpl, + LLMServiceImpl, + LLMServiceRequest, + LLMServiceRequestFailed, + LLMServiceRequestSucceeded, + ProofVersion, +} from "./llmService"; +import { ModelParams } from "./modelParams"; +import { GenerationsLogger } from "./utils/generationsLogger/generationsLogger"; + +/** + * This class 
represents the inner resources and implementations of `LLMServiceImpl`.
+ *
+ * Its main goals are to:
+ * - separate the actual logic and implementation wrappers from the facade `LLMServiceImpl` class;
+ * - make `GeneratedProofImpl` effectively an inner class of `LLMServiceImpl`,
+ *   capable of reaching its internal resources.
+ *
+ * Also, `LLMServiceInternal` is capable of
+ * maintaining the `LLMServiceImpl`'s resources and disposing them in the end.
+ */
+export abstract class LLMServiceInternal<
+    ResolvedModelParams extends ModelParams,
+    LLMServiceType extends LLMServiceImpl<
+        UserModelParams,
+        ResolvedModelParams,
+        LLMServiceType,
+        GeneratedProofType,
+        LLMServiceInternalType
+    >,
+    GeneratedProofType extends GeneratedProofImpl<
+        ResolvedModelParams,
+        LLMServiceType,
+        GeneratedProofType,
+        LLMServiceInternalType
+    >,
+    LLMServiceInternalType extends LLMServiceInternal<
+        ResolvedModelParams,
+        LLMServiceType,
+        GeneratedProofType,
+        LLMServiceInternalType
+    >,
+> {
+    readonly eventLogger: EventLogger | undefined;
+    readonly generationsLogger: GenerationsLogger;
+    readonly debug: DebugWrappers;
+
+    constructor(
+        readonly llmService: LLMServiceType,
+        eventLoggerGetter: () => EventLogger | undefined,
+        generationsLoggerBuilder: () => GenerationsLogger
+    ) {
+        this.eventLogger = eventLoggerGetter();
+        this.generationsLogger = generationsLoggerBuilder();
+        this.debug = new DebugWrappers(
+            llmService.serviceName,
+            this.eventLogger
+        );
+    }
+
+    /**
+     * Basically, this method should just call the constructor
+     * of the corresponding implementation of the `GeneratedProofImpl`.
+     * It is needed only to link the service and its proof properly.
+     */
+    abstract constructGeneratedProof(
+        proof: string,
+        proofGenerationContext: ProofGenerationContext,
+        modelParams: ResolvedModelParams,
+        previousProofVersions?: ProofVersion[]
+    ): GeneratedProofType;
+
+    /**
+     * This method should be mostly a pure implementation of
+     * the generation from chat, namely, its happy path.
+     * This function doesn't need to handle errors!
+     *
+     * In case something goes wrong on the side of the external API, any error can be thrown.
+     *
+     * However, if the generation failed due to the invalid configuration of the request
+     * on CoqPilot's side (for example: an invalid token in `params`),
+     * this implementation should throw a `ConfigurationError` whenever possible.
+     * This is not mandatory, but that way the user will be notified in a clearer way.
+     *
+     * Important note: `ResolvedModelParams` are expected to be already validated by `LLMServiceImpl.resolveParameters`,
+     * so there is no need to perform these checks again. Report a `ConfigurationError` only if something goes wrong at generation runtime.
+     *
+     * Side note: most likely, you'd like to call `this.validateChoices` to validate the `choices` parameter.
+     * Since it overrides the `choices`-like parameters of the already validated `params`, it might have any numeric value.
+     */
+    abstract generateFromChatImpl(
+        chat: ChatHistory,
+        params: ResolvedModelParams,
+        choices: number
+    ): Promise<string[]>;
+
+    /**
+     * All the resources that `LLMServiceInternal` is responsible for should be disposed.
+     * But only them!
+     * For example, `this.generationsLogger` is created and maintained by `LLMServiceInternal`,
+     * so it should be disposed in this method.
+     * On the contrary, `this.eventLogger` is maintained by external classes and
+     * is only passed to the `LLMServiceImpl`; thus, it should not be disposed here.
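+     *
+     * For instance, a subclass owning extra resources would override this method and call `super.dispose()`
+     * (a hypothetical sketch; `connection` is not part of this PR):
+     * ```
+     * dispose(): void {
+     *     this.connection.close(); // dispose the subclass's own resources first
+     *     super.dispose();         // then dispose `generationsLogger`
+     * }
+     * ```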
+ */ + dispose(): void { + this.generationsLogger.dispose(); + } + + /** + * Helper function that wraps `LLMServiceInternal.generateFromChatImpl` call with + * logging and errors handling. + * + * To know more about the latter, + * check `LLMServiceInternal.logGenerationAndHandleErrors` docs. + */ + readonly generateFromChatWrapped = async ( + params: ResolvedModelParams, + choices: number, + errorsHandlingMode: ErrorsHandlingMode, + buildAndValidateChat: () => AnalyzedChatHistory, + wrapRawProof: (proof: string) => T + ): Promise => { + return this.logGenerationAndHandleErrors( + params, + choices, + errorsHandlingMode, + (request) => { + request.analyzedChat = buildAndValidateChat(); + }, + async (request) => { + return this.generateFromChatImpl( + request.analyzedChat!.chat, + params, + choices + ); + }, + wrapRawProof + ); + }; + + /** + * This is a helper function that wraps the implementation calls, + * providing generation logging and errors handling. + * Many `LLMServiceImpl` invariants are provided by this function; + * thus, its implementation is final. + * It should be called only in `LLMServiceImpl` or `GeneratedProofImpl`, + * to help with overriding the default public methods implementation. + * + * Invariants TL;DR: + * - any thrown error will be of `LLMServiceError` type: if the error is not of that type originally, it'd be wrapped; + * - errors are rethrown only in case of `RETHROW_ERRORS`; + * - `this.generationsLogger` logs every success and only `GenerationFailedError`-s (not `ConfigurationError`-s, for example); + * - `this.eventLogger` logs every success and in case of `LOG_EVENTS_AND_SWALLOW_ERRORS` logs any error; + * in case of success / failure event's `data` is the `LLMServiceRequestSucceeded` / `LLMServiceRequestFailed` object respectively. + * + * Invariants, the full version. + * - `completeAndValidateRequest` should fill the received request (for example, with `AnalyzedChatHistory`) and validate its properties; + * it is allowed to throw any error: + * - if it is not `ConfigurationError` already, its message will be wrapped into `ConfigurationError`; + * - then, it will be handled according to `errorsHandlingMode` `(*)`; + * - If the request is successfully built, then the proofs generation will be performed. + * - If no error is thrown: + * - generation will be logged as successful one via `this.generationsLogger`; + * - `LLMService.requestSucceededEvent` (with `LLMServiceRequestSucceeded` as data) will be logged via `this.eventLogger`. + * - If error is thrown: + * - it will be wrapped into `GenerationFailedError`, if it is not of `LLMServiceError` type already; + * - if it's an instance of `GenerationFailedError`, it will be logged via `this.generationsLogger`; + * - finally, it will be handled according to `errorsHandlingMode` `(*)`. + * + * `(*)` means: + * - if `errorsHandlingMode === ErrorsHandlingMode.LOG_EVENTS_AND_SWALLOW_ERRORS`, + * - `LLMService.requestFailedEvent` (with `LLMServiceRequestFailed` as data + * containing the error wrapped into `LLMServiceError`) will be logged via `this.eventLogger`; + * - the error will not be rethrown. + * - if `errorsHandlingMode === ErrorsHandlingMode.RETHROW_ERRORS`, + * - the error will be rethrown. 
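+     *
+     * For example, `PredefinedProofsService.generateProof` later in this diff calls this wrapper directly,
+     * bypassing chat building; schematically:
+     * ```
+     * this.internal.logGenerationAndHandleErrors(
+     *     params,
+     *     choices,
+     *     errorsHandlingMode,
+     *     (_request) => {
+     *         // validate `choices` against the number of `params.tactics`
+     *     },
+     *     async (_request) =>
+     *         params.tactics.slice(0, choices).map((tactic) => `Proof. ${tactic} Qed.`),
+     *     (proof) =>
+     *         this.internal.constructGeneratedProof(proof, proofGenerationContext, params)
+     * );
+     * ```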
+ */ + readonly logGenerationAndHandleErrors = async ( + params: ResolvedModelParams, + choices: number, + errorsHandlingMode: ErrorsHandlingMode, + completeAndValidateRequest: (request: LLMServiceRequest) => void, + generateProofs: (request: LLMServiceRequest) => Promise, + wrapRawProof: (proof: string) => T + ): Promise => { + const request: LLMServiceRequest = { + llmService: this.llmService, + params: params, + choices: choices, + }; + try { + completeAndValidateRequest(request); + } catch (e) { + const error = LLMServiceInternal.asErrorOrRethrow(e); + const configurationError = + error instanceof ConfigurationError + ? error + : new ConfigurationError(error.message); + this.logAndHandleError( + configurationError, + errorsHandlingMode, + request + ); + return []; + } + try { + const proofs = await generateProofs(request); + this.logSuccess(request, proofs); + return proofs.map(wrapRawProof); + } catch (e) { + const error = LLMServiceInternal.asErrorOrRethrow(e); + this.logAndHandleError(error, errorsHandlingMode, request); + return []; + } + }; + + /** + * Helper function to handle unsupported method properly. + */ + unsupportedMethod( + message: string, + params: ResolvedModelParams, + choices: number, + errorsHandlingMode: ErrorsHandlingMode + ) { + const request: LLMServiceRequest = { + llmService: this.llmService, + params: params, + choices: choices, + }; + this.logAndHandleError( + new ConfigurationError(message), + errorsHandlingMode, + request + ); + } + + /** + * Helper function to validate `choices` are positive. + * + * It is not used in the default implementations, since services + * might handle negative or zero `choices` in some special way. + * However, this validation is most likely needed in any normal `LLMServiceInternal` implementation. + */ + validateChoices(choices: number) { + if (choices <= 0) { + throw new ConfigurationError("choices number should be positive"); + } + } + + private logSuccess( + request: LLMServiceRequest, + generatedRawProofs: string[] + ) { + const requestSucceeded: LLMServiceRequestSucceeded = { + ...request, + generatedRawProofs: generatedRawProofs, + }; + this.generationsLogger.logGenerationSucceeded(requestSucceeded); + this.eventLogger?.logLogicEvent( + LLMServiceImpl.requestSucceededEvent, + requestSucceeded + ); + } + + private static asErrorOrRethrow(e: any): Error { + const error = e as Error; + if (error === null) { + throw e; + } + return error; + } + + private logAndHandleError( + error: Error, + errorsHandlingMode: ErrorsHandlingMode, + request: LLMServiceRequest + ) { + const requestFailed: LLMServiceRequestFailed = { + ...request, + llmServiceError: + error instanceof LLMServiceError + ? 
error + : new GenerationFailedError(error), + }; + if (requestFailed.llmServiceError instanceof GenerationFailedError) { + this.generationsLogger.logGenerationFailed(requestFailed); + } + this.logAsEventOrRethrow(requestFailed, errorsHandlingMode); + } + + private logAsEventOrRethrow( + requestFailed: LLMServiceRequestFailed, + errorsHandlingMode: ErrorsHandlingMode + ) { + switch (errorsHandlingMode) { + case ErrorsHandlingMode.LOG_EVENTS_AND_SWALLOW_ERRORS: + if (!this.eventLogger) { + throw Error("cannot log events: no `eventLogger` provided"); + } + this.eventLogger.logLogicEvent( + LLMServiceImpl.requestFailedEvent, + requestFailed + ); + return; + case ErrorsHandlingMode.RETHROW_ERRORS: + throw requestFailed.llmServiceError; + default: + throw Error( + `unsupported \`ErrorsHandlingMode\`: ${errorsHandlingMode}` + ); + } + } +} + +/** + * Helper object that provides wrappers to write debug logs shorter. + * + * Its instance is available inside `LLMServiceInternal` and + * could be passed into other classes of the internal implementation. + */ +export class DebugWrappers { + constructor( + private readonly serviceName: string, + private readonly eventLogger?: EventLogger + ) {} + + /** + * Helper method that provides debug logging in a shorter way. + */ + logEvent(message: string, data?: any) { + this.eventLogger?.log(this.serviceName, message, data, Severity.DEBUG); + } +} diff --git a/src/llm/llmServices/lmStudio/lmStudioModelParamsResolver.ts b/src/llm/llmServices/lmStudio/lmStudioModelParamsResolver.ts new file mode 100644 index 00000000..e72c0960 --- /dev/null +++ b/src/llm/llmServices/lmStudio/lmStudioModelParamsResolver.ts @@ -0,0 +1,28 @@ +import { LMStudioUserModelParams } from "../../userModelParams"; +import { LMStudioModelParams, lmStudioModelParamsSchema } from "../modelParams"; +import { BasicModelParamsResolver } from "../utils/paramsResolvers/basicModelParamsResolvers"; +import { ValidParamsResolverImpl } from "../utils/paramsResolvers/paramsResolverImpl"; + +export class LMStudioModelParamsResolver + extends BasicModelParamsResolver< + LMStudioUserModelParams, + LMStudioModelParams + > + implements + ValidParamsResolverImpl +{ + constructor() { + super(lmStudioModelParamsSchema, "LMStudioModelParams"); + } + + readonly temperature = this.resolveParam("temperature") + .requiredToBeConfigured() + .validateAtRuntimeOnly(); + + readonly port = this.resolveParam("port") + .requiredToBeConfigured() + .validate([ + (value) => value >= 0 && value <= 65535, + "be a valid port value, i.e. 
in range between 0 and 65535", + ]); +} diff --git a/src/llm/llmServices/lmStudio/lmStudioService.ts b/src/llm/llmServices/lmStudio/lmStudioService.ts index 9ffd8537..cc288236 100644 --- a/src/llm/llmServices/lmStudio/lmStudioService.ts +++ b/src/llm/llmServices/lmStudio/lmStudioService.ts @@ -1,19 +1,80 @@ -import { EventLogger, Severity } from "../../../logging/eventLogger"; +import { EventLogger } from "../../../logging/eventLogger"; import { ProofGenerationContext } from "../../proofGenerationContext"; +import { LMStudioUserModelParams } from "../../userModelParams"; import { ChatHistory } from "../chat"; -import { GeneratedProof, LLMService, Proof, ProofVersion } from "../llmService"; +import { + GeneratedProofImpl, + LLMServiceImpl, + ProofVersion, +} from "../llmService"; +import { LLMServiceInternal } from "../llmServiceInternal"; import { LMStudioModelParams } from "../modelParams"; -export class LMStudioService extends LLMService { - constructor(readonly eventLogger?: EventLogger) { - super(eventLogger); +import { LMStudioModelParamsResolver } from "./lmStudioModelParamsResolver"; + +export class LMStudioService extends LLMServiceImpl< + LMStudioUserModelParams, + LMStudioModelParams, + LMStudioService, + LMStudioGeneratedProof, + LMStudioServiceInternal +> { + protected readonly internal: LMStudioServiceInternal; + protected readonly modelParamsResolver = new LMStudioModelParamsResolver(); + + constructor( + eventLogger?: EventLogger, + debugLogs: boolean = false, + generationsLogsFilePath?: string + ) { + super( + "LMStudioService", + eventLogger, + debugLogs, + generationsLogsFilePath + ); + this.internal = new LMStudioServiceInternal( + this, + this.eventLoggerGetter, + this.generationsLoggerBuilder + ); } +} - constructGeneratedProof( +export class LMStudioGeneratedProof extends GeneratedProofImpl< + LMStudioModelParams, + LMStudioService, + LMStudioGeneratedProof, + LMStudioServiceInternal +> { + constructor( proof: string, proofGenerationContext: ProofGenerationContext, modelParams: LMStudioModelParams, + llmServiceInternal: LMStudioServiceInternal, previousProofVersions?: ProofVersion[] + ) { + super( + proof, + proofGenerationContext, + modelParams, + llmServiceInternal, + previousProofVersions + ); + } +} + +class LMStudioServiceInternal extends LLMServiceInternal< + LMStudioModelParams, + LMStudioService, + LMStudioGeneratedProof, + LMStudioServiceInternal +> { + constructGeneratedProof( + proof: string, + proofGenerationContext: ProofGenerationContext, + modelParams: LMStudioModelParams, + previousProofVersions?: ProofVersion[] | undefined ): LMStudioGeneratedProof { return new LMStudioGeneratedProof( proof, @@ -24,20 +85,17 @@ export class LMStudioService extends LLMService { ); } - async generateFromChat( + async generateFromChatImpl( chat: ChatHistory, params: LMStudioModelParams, choices: number ): Promise { - this.eventLogger?.log( - "lm-studio-fetch-started", - "Completion from LmStudio requested", - { history: chat }, - Severity.DEBUG - ); + this.validateChoices(choices); let attempts = choices * 2; const completions: string[] = []; + this.debug.logEvent("Completion requested", { history: chat }); + let lastErrorThrown: Error | undefined = undefined; while (completions.length < choices && attempts > 0) { try { const responce = await fetch(this.endpoint(params), { @@ -47,23 +105,27 @@ export class LMStudioService extends LLMService { }); if (responce.ok) { const res = await responce.json(); - completions.push(res.choices[0].message.content); + const 
newCompletion = res.choices[0].message.content; + completions.push(newCompletion); + this.debug.logEvent("Completion succeeded", { + newCompletion: newCompletion, + }); } - this.eventLogger?.log( - "lm-studio-fetch-success", - "Completion from LmStudio succeeded", - { completions: completions } - ); } catch (err) { - this.eventLogger?.log( - "lm-studio-fetch-failed", - "Completion from LmStudio failed", - { error: err } - ); + this.debug.logEvent("Completion failed", { + error: err, + }); + if ((err as Error) === null) { + throw err; + } + lastErrorThrown = err as Error; } attempts--; } + if (completions.length < choices) { + throw lastErrorThrown; + } return completions; } @@ -72,13 +134,13 @@ export class LMStudioService extends LLMService { "Content-Type": "application/json", }; - private body(messages: ChatHistory, params: LMStudioModelParams): any { + private body(messages: ChatHistory, params: LMStudioModelParams): string { return JSON.stringify({ messages: messages, stream: false, temperature: params.temperature, // eslint-disable-next-line @typescript-eslint/naming-convention - max_tokens: params.newMessageMaxTokens, + max_tokens: params.maxTokensToGenerate, }); } @@ -86,21 +148,3 @@ export class LMStudioService extends LLMService { return `http://localhost:${params.port}/v1/chat/completions`; } } - -export class LMStudioGeneratedProof extends GeneratedProof { - constructor( - proof: Proof, - proofGenerationContext: ProofGenerationContext, - modelParams: LMStudioModelParams, - llmService: LMStudioService, - previousProofVersions?: ProofVersion[] - ) { - super( - proof, - proofGenerationContext, - modelParams, - llmService, - previousProofVersions - ); - } -} diff --git a/src/llm/llmServices/modelParams.ts b/src/llm/llmServices/modelParams.ts index f71addf0..9712f3ec 100644 --- a/src/llm/llmServices/modelParams.ts +++ b/src/llm/llmServices/modelParams.ts @@ -1,32 +1,153 @@ +import { JSONSchemaType } from "ajv"; +import { PropertiesSchema } from "ajv/dist/types/json-schema"; + export interface MultiroundProfile { maxRoundsNumber: number; - proofFixChoices: number; + /** + * Is handled the same way as `ModelParams.defaultChoices` is, i.e. `defaultProofFixChoices` is used + * only as a default `choices` value in the corresponding `fixProof` facade method. + * + * Do not use it inside the implementation, use the `choices` instead. + */ + defaultProofFixChoices: number; proofFixPrompt: string; } export interface ModelParams { - modelName: string; + modelId: string; systemPrompt: string; - newMessageMaxTokens: number; + maxTokensToGenerate: number; tokensLimit: number; multiroundProfile: MultiroundProfile; + + /** + * Always overriden by the `choices` parameter at the call site, if one is specified. + * I.e. `defaultChoices` is used only as a default `choices` value in the corresponding facade methods. + * + * Do not use it inside the implementation, use the `choices` instead. + */ + defaultChoices: number; +} + +export interface PredefinedProofsModelParams extends ModelParams { + tactics: string[]; } export interface OpenAiModelParams extends ModelParams { + modelName: string; temperature: number; apiKey: string; } export interface GrazieModelParams extends ModelParams { + modelName: string; apiKey: string; } -export interface PredefinedProofsModelParams extends ModelParams { - // A list of tactics to try to solve the goal with. 
- tactics: string[]; -} - export interface LMStudioModelParams extends ModelParams { temperature: number; port: number; } + +export interface ModelsParams { + predefinedProofsModelParams: PredefinedProofsModelParams[]; + openAiParams: OpenAiModelParams[]; + grazieParams: GrazieModelParams[]; + lmStudioParams: LMStudioModelParams[]; +} + +export const multiroundProfileSchema: JSONSchemaType = { + type: "object", + properties: { + maxRoundsNumber: { type: "number" }, + defaultProofFixChoices: { type: "number" }, + proofFixPrompt: { type: "string" }, + }, + required: ["maxRoundsNumber", "defaultProofFixChoices", "proofFixPrompt"], + additionalProperties: false, +}; + +export const modelParamsSchema: JSONSchemaType = { + type: "object", + properties: { + modelId: { type: "string" }, + + systemPrompt: { type: "string" }, + + maxTokensToGenerate: { type: "number" }, + tokensLimit: { type: "number" }, + + multiroundProfile: { + type: "object", + oneOf: [multiroundProfileSchema], + }, + + defaultChoices: { type: "number" }, + }, + required: [ + "modelId", + "systemPrompt", + "maxTokensToGenerate", + "tokensLimit", + "multiroundProfile", + "defaultChoices", + ], + additionalProperties: false, +}; + +export const predefinedProofsModelParamsSchema: JSONSchemaType = + { + title: "predefinedProofsModelsParameters", + type: "object", + properties: { + tactics: { + type: "array", + items: { type: "string" }, + }, + ...(modelParamsSchema.properties as PropertiesSchema), + }, + required: ["tactics", ...modelParamsSchema.required], + additionalProperties: false, + }; + +export const openAiModelParamsSchema: JSONSchemaType = { + title: "openAiModelsParameters", + type: "object", + properties: { + modelName: { type: "string" }, + temperature: { type: "number" }, + apiKey: { type: "string" }, + ...(modelParamsSchema.properties as PropertiesSchema), + }, + required: [ + "modelName", + "temperature", + "apiKey", + ...modelParamsSchema.required, + ], + additionalProperties: false, +}; + +export const grazieModelParamsSchema: JSONSchemaType = { + title: "grazieModelsParameters", + type: "object", + properties: { + modelName: { type: "string" }, + apiKey: { type: "string" }, + ...(modelParamsSchema.properties as PropertiesSchema), + }, + required: ["modelName", "apiKey", ...modelParamsSchema.required], + additionalProperties: false, +}; + +export const lmStudioModelParamsSchema: JSONSchemaType = { + title: "lmStudioModelsParameters", + type: "object", + properties: { + temperature: { type: "number" }, + port: { type: "number" }, + ...(modelParamsSchema.properties as PropertiesSchema), + }, + required: ["temperature", "port", ...modelParamsSchema.required], + additionalProperties: false, +}; diff --git a/src/llm/llmServices/openai/openAiModelParamsResolver.ts b/src/llm/llmServices/openai/openAiModelParamsResolver.ts new file mode 100644 index 00000000..664032ae --- /dev/null +++ b/src/llm/llmServices/openai/openAiModelParamsResolver.ts @@ -0,0 +1,173 @@ +import { OpenAiUserModelParams } from "../../userModelParams"; +import { OpenAiModelParams, openAiModelParamsSchema } from "../modelParams"; +import { BasicModelParamsResolver } from "../utils/paramsResolvers/basicModelParamsResolvers"; +import { ValidationRules } from "../utils/paramsResolvers/builders"; +import { ValidParamsResolverImpl } from "../utils/paramsResolvers/paramsResolverImpl"; + +export class OpenAiModelParamsResolver + extends BasicModelParamsResolver + implements + ValidParamsResolverImpl +{ + constructor() { + super(openAiModelParamsSchema, 
"OpenAiModelParams"); + } + + readonly modelName = this.resolveParam("modelName") + .requiredToBeConfigured() + .validateAtRuntimeOnly(); + + readonly temperature = this.resolveParam("temperature") + .requiredToBeConfigured() + .validate([ + (value) => value >= 0 && value <= 2, + "be in range between 0 and 2", + ]); + + readonly apiKey = this.resolveParam("apiKey") + .requiredToBeConfigured() + .validateAtRuntimeOnly(); + + readonly tokensLimit = this.resolveParam("tokensLimit") + .default((inputParams) => + OpenAiModelParamsResolver._modelToTokensLimit.get( + inputParams.modelName + ) + ) + .validate(ValidationRules.bePositiveNumber, [ + (value, inputParams) => { + const actualTokensLimit = + OpenAiModelParamsResolver._modelToTokensLimit.get( + inputParams.modelName + ); + if ( + actualTokensLimit === undefined || + value <= actualTokensLimit + ) { + return true; + } + return false; + }, + (inputParams) => + `be not greater than the known tokens limit (${OpenAiModelParamsResolver._modelToTokensLimit.get(inputParams.modelName)}) for the "${inputParams.modelName}" model`, + ]); + + /** + * Since the actual maximum numbers of tokens that the models can generate are sometimes equal to their token limits, + * a special algorithm to suggest a proper practical default value is used. + * - If `actualTokensLimit` is twice or more times greater than `actualMaxTokensToGenerate`, return the actual value. + * - Otherwise, return minimum of `actualTokensLimit` / 2 and 4096. + * + * Of course, if the model is unknown to the resolver, no default resolution will happen. + */ + readonly maxTokensToGenerate = this.resolveParam( + "maxTokensToGenerate" + ) + .default((inputParams) => { + const actualMaxTokensToGenerate = + OpenAiModelParamsResolver._modelToMaxTokensToGenerate.get( + inputParams.modelName + ); + const actualTokensLimit = + inputParams.tokensLimit ?? + OpenAiModelParamsResolver._modelToTokensLimit.get( + inputParams.modelName + ); + if ( + actualMaxTokensToGenerate === undefined || + actualTokensLimit === undefined + ) { + return undefined; + } + if (2 * actualMaxTokensToGenerate < actualTokensLimit) { + return actualMaxTokensToGenerate; + } + const halfTokensLimit = Math.floor(actualTokensLimit / 2); + return Math.min(halfTokensLimit, 4096); + }) + .validate(ValidationRules.bePositiveNumber, [ + (value, inputParams) => { + const actualMaxTokensToGenerate = + OpenAiModelParamsResolver._modelToMaxTokensToGenerate.get( + inputParams.modelName + ); + if ( + actualMaxTokensToGenerate === undefined || + value <= actualMaxTokensToGenerate + ) { + return true; + } + return false; + }, + (inputParams) => + `be not greater than the known max tokens to generate limit (${OpenAiModelParamsResolver._modelToMaxTokensToGenerate.get(inputParams.modelName)}) for the "${inputParams.modelName}" model`, + ]); + + /* + * About default tokens parameters (both `_modelToTokensLimit` and `_modelToMaxTokensToGenerate`). + * The values are taken mostly from the official OpenAI docs: https://platform.openai.com/docs/models. + * However, the information there is incomplete. Thus, external resources were used for some models. + * Such records are marked with the comments containing the reference to the source. + * - (*) = the post from the OpenAI community: https://community.openai.com/t/request-query-for-a-models-max-tokens/161891/8. + * - (sources) = Python OpenAI api sources: https://github.com/langchain-ai/langchain/blob/master/libs/community/langchain_community/llms/openai.py. 
+ * - (microsoft) = Azure OpenAI Service models: https://learn.microsoft.com/en-us/azure/ai-services/openai/concepts/models#gpt-4o-and-gpt-4-turbo. + */ + + static readonly _modelToTokensLimit: Map = new Map([ + ["gpt-4o", 4096], + ["gpt-4o-2024-05-13", 4096], + ["gpt-4-turbo", 128_000], + ["gpt-4-turbo-2024-04-09", 128_000], + ["gpt-4-turbo-preview", 128_000], + ["gpt-4-0125-preview", 128_000], + ["gpt-4-1106-preview", 128_000], + ["gpt-4-vision-preview", 128_000], + ["gpt-4-1106-vision-preview", 128_000], + ["gpt-4", 8192], + ["gpt-4-0314", 8192], // (*), (microsoft) + ["gpt-4-0613", 8192], + ["gpt-4-32k", 32_768], + ["gpt-4-32k-0314", 32_768], // (*), (microsoft) + ["gpt-4-32k-0613", 32_768], + ["gpt-3.5-turbo-0125", 16_385], + ["gpt-3.5-turbo", 16_385], + ["gpt-3.5-turbo-1106", 16_385], + ["gpt-3.5-turbo-instruct", 4096], + ["gpt-3.5-turbo-16k", 16_385], + ["gpt-3.5-turbo-16k-0613", 16_385], + ["gpt-3.5-turbo-0613", 4096], + ["gpt-3.5-turbo-0301", 4096], + ]); + + /** + * These are the actual maximum numbers of tokens that these models can generate. + * However, sometimes these values are equal to the corresponding tokens limits, + * so it would be impractical to set `maxTokensToGenerate` to their values. + * Thus, the default resolver should check this case and suggest smaller values if possible. + */ + static readonly _modelToMaxTokensToGenerate: Map = new Map([ + ["gpt-4o", 4096], // (microsoft) + ["gpt-4o-2024-05-13", 4096], // (microsoft) + ["gpt-4-turbo", 4096], // (microsoft) + ["gpt-4-turbo-2024-04-09", 4096], // (microsoft) + ["gpt-4-turbo-preview", 4096], + ["gpt-4-0125-preview", 4096], + ["gpt-4-1106-preview", 4096], + ["gpt-4-vision-preview", 4096], + ["gpt-4-1106-vision-preview", 4096], + ["gpt-4", 8192], // (*), (sources) + ["gpt-4-0314", 8192], // (*), (sources), (microsoft) + ["gpt-4-0613", 8192], // (*), (sources), (microsoft) + ["gpt-4-32k", 32_768], // (*), (sources) + ["gpt-4-32k-0314", 32_768], // (*), (sources), (microsoft) + ["gpt-4-32k-0613", 32_768], // (*), (sources), (microsoft) + ["gpt-3.5-turbo-0125", 4096], + ["gpt-3.5-turbo", 4096], + ["gpt-3.5-turbo-1106", 4096], + ["gpt-3.5-turbo-instruct", 4096], // (sources) + ["gpt-3.5-turbo-16k", 16_385], // (*), (sources) + ["gpt-3.5-turbo-16k-0613", 16_385], // (*), (sources), (microsoft) + ["gpt-3.5-turbo-0613", 4096], // (*), (sources), (microsoft) + ["gpt-3.5-turbo-0301", 4096], // (*), (sources), (microsoft) + ]); +} diff --git a/src/llm/llmServices/openai/openAiService.ts b/src/llm/llmServices/openai/openAiService.ts index e41ccf4a..9717690e 100644 --- a/src/llm/llmServices/openai/openAiService.ts +++ b/src/llm/llmServices/openai/openAiService.ts @@ -1,76 +1,189 @@ import OpenAI from "openai"; -import { EventLogger, Severity } from "../../../logging/eventLogger"; +import { EventLogger } from "../../../logging/eventLogger"; +import { + ConfigurationError, + RemoteConnectionError, +} from "../../llmServiceErrors"; import { ProofGenerationContext } from "../../proofGenerationContext"; +import { OpenAiUserModelParams } from "../../userModelParams"; import { ChatHistory } from "../chat"; -import { GeneratedProof, LLMService, ProofVersion } from "../llmService"; -import { Proof } from "../llmService"; -import { ModelParams, OpenAiModelParams } from "../modelParams"; +import { + GeneratedProofImpl, + LLMServiceImpl, + ProofVersion, +} from "../llmService"; +import { LLMServiceInternal } from "../llmServiceInternal"; +import { OpenAiModelParams } from "../modelParams"; -export class OpenAiService extends LLMService { 
- constructor(eventLogger?: EventLogger) { - super(eventLogger); +import { OpenAiModelParamsResolver } from "./openAiModelParamsResolver"; + +export class OpenAiService extends LLMServiceImpl< + OpenAiUserModelParams, + OpenAiModelParams, + OpenAiService, + OpenAiGeneratedProof, + OpenAiServiceInternal +> { + protected readonly internal: OpenAiServiceInternal; + protected readonly modelParamsResolver = new OpenAiModelParamsResolver(); + + constructor( + eventLogger?: EventLogger, + debugLogs: boolean = false, + generationsLogsFilePath?: string + ) { + super("OpenAiService", eventLogger, debugLogs, generationsLogsFilePath); + this.internal = new OpenAiServiceInternal( + this, + this.eventLoggerGetter, + this.generationsLoggerBuilder + ); } +} - constructGeneratedProof( +export class OpenAiGeneratedProof extends GeneratedProofImpl< + OpenAiModelParams, + OpenAiService, + OpenAiGeneratedProof, + OpenAiServiceInternal +> { + constructor( proof: string, proofGenerationContext: ProofGenerationContext, - modelParams: ModelParams, + modelParams: OpenAiModelParams, + llmServiceInternal: OpenAiServiceInternal, previousProofVersions?: ProofVersion[] - ): GeneratedProof { + ) { + super( + proof, + proofGenerationContext, + modelParams, + llmServiceInternal, + previousProofVersions + ); + } +} + +class OpenAiServiceInternal extends LLMServiceInternal< + OpenAiModelParams, + OpenAiService, + OpenAiGeneratedProof, + OpenAiServiceInternal +> { + constructGeneratedProof( + proof: string, + proofGenerationContext: ProofGenerationContext, + modelParams: OpenAiModelParams, + previousProofVersions?: ProofVersion[] | undefined + ): OpenAiGeneratedProof { return new OpenAiGeneratedProof( proof, proofGenerationContext, - modelParams as OpenAiModelParams, + modelParams, this, previousProofVersions ); } - async generateFromChat( + async generateFromChatImpl( chat: ChatHistory, - params: ModelParams, + params: OpenAiModelParams, choices: number ): Promise { - if (choices <= 0) { - return []; + this.validateChoices(choices); + + const openai = new OpenAI({ apiKey: params.apiKey }); + this.debug.logEvent("Completion requested", { history: chat }); + + try { + const completion = await openai.chat.completions.create({ + messages: chat, + model: params.modelName, + n: choices, + temperature: params.temperature, + // eslint-disable-next-line @typescript-eslint/naming-convention + max_tokens: params.maxTokensToGenerate, + }); + return completion.choices.map((choice) => { + const content = choice.message.content; + if (content === null) { + throw Error("response message content is null"); + } + return content; + }); + } catch (e) { + throw OpenAiServiceInternal.repackKnownError(e, params); } - const openAiParams = params as OpenAiModelParams; - const openai = new OpenAI({ apiKey: openAiParams.apiKey }); - - this.eventLogger?.log( - "openai-fetch-started", - "Generate with OpenAI", - { history: chat }, - Severity.DEBUG - ); - const completion = await openai.chat.completions.create({ - messages: chat, - model: openAiParams.modelName, - n: choices, - temperature: openAiParams.temperature, - // eslint-disable-next-line @typescript-eslint/naming-convention - max_tokens: openAiParams.newMessageMaxTokens, - }); - - return completion.choices.map((choice: any) => choice.message.content); } -} -export class OpenAiGeneratedProof extends GeneratedProof { - constructor( - proof: Proof, - proofGenerationContext: ProofGenerationContext, - modelParams: OpenAiModelParams, - llmService: OpenAiService, - previousProofVersions?: ProofVersion[] 
- ) { - super( - proof, - proofGenerationContext, - modelParams, - llmService, - previousProofVersions + private static repackKnownError( + caughtObject: any, + params: OpenAiModelParams + ): any { + const error = caughtObject as Error; + if (error === null) { + return caughtObject; + } + const errorMessage = error.message; + + if (this.matchesPattern(this.unknownModelNamePattern, errorMessage)) { + return new ConfigurationError( + `invalid model name "${params.modelName}", such model does not exist or you do not have access to it` + ); + } + if (this.matchesPattern(this.incorrectApiKeyPattern, errorMessage)) { + return new ConfigurationError( + `incorrect api key "${params.apiKey}" (check your API key at https://platform.openai.com/account/api-keys)` + ); + } + const contextExceeded = this.parsePattern( + this.maximumContextLengthExceededPattern, + errorMessage ); + if (contextExceeded !== undefined) { + const [ + modelsMaxContextLength, + requestedTokens, + requestedMessagesTokens, + maxTokensToGenerate, + ] = contextExceeded; + const intro = + "`tokensLimit` and `maxTokensToGenerate` are too large together"; + const explanation = `model's maximum context length is ${modelsMaxContextLength} tokens, but was requested ${requestedTokens} tokens = ${requestedMessagesTokens} in the messages + ${maxTokensToGenerate} in the completion`; + return new ConfigurationError(`${intro}; ${explanation}`); + } + if (this.matchesPattern(this.connectionErrorPattern, errorMessage)) { + return new RemoteConnectionError( + "failed to reach OpenAI remote service" + ); + } + return error; } + + private static matchesPattern(pattern: RegExp, text: string): boolean { + return text.match(pattern) !== null; + } + + private static parsePattern( + pattern: RegExp, + text: string + ): string[] | undefined { + const match = text.match(pattern); + if (!match) { + return undefined; + } + return match.slice(1); + } + + private static readonly unknownModelNamePattern = + /^The model `(.*)` does not exist or you do not have access to it\.$/; + + private static readonly incorrectApiKeyPattern = + /^Incorrect API key provided: (.*)\.(.*)$/; + + private static readonly maximumContextLengthExceededPattern = + /^This model's maximum context length is ([0-9]+) tokens\. 
However, you requested ([0-9]+) tokens \(([0-9]+) in the messages, ([0-9]+) in the completion\)\..*$/; + + private static readonly connectionErrorPattern = /^Connection error\.$/; } diff --git a/src/llm/llmServices/predefinedProofs/predefinedProofsModelParamsResolver.ts b/src/llm/llmServices/predefinedProofs/predefinedProofsModelParamsResolver.ts new file mode 100644 index 00000000..d77679a6 --- /dev/null +++ b/src/llm/llmServices/predefinedProofs/predefinedProofsModelParamsResolver.ts @@ -0,0 +1,67 @@ +import { PredefinedProofsUserModelParams } from "../../userModelParams"; +import { + MultiroundProfile, + PredefinedProofsModelParams, + predefinedProofsModelParamsSchema, +} from "../modelParams"; +import { BasicModelParamsResolver } from "../utils/paramsResolvers/basicModelParamsResolvers"; +import { ValidParamsResolverImpl } from "../utils/paramsResolvers/paramsResolverImpl"; + +export class PredefinedProofsModelParamsResolver + extends BasicModelParamsResolver< + PredefinedProofsUserModelParams, + PredefinedProofsModelParams + > + implements + ValidParamsResolverImpl< + PredefinedProofsUserModelParams, + PredefinedProofsModelParams + > +{ + constructor() { + super(predefinedProofsModelParamsSchema, "PredefinedProofsModelParams"); + } + + readonly tactics = this.resolveParam("tactics") + .requiredToBeConfigured() + .validate([(value) => value.length > 0, "be non-empty"]); + + readonly systemPrompt = this.resolveParam( + "systemPrompt" + ).overrideWithMock(() => ""); + + readonly maxTokensToGenerate = this.resolveParam( + "maxTokensToGenerate" + ).overrideWithMock((inputParams) => + Math.max(0, ...inputParams.tactics.map((tactic) => tactic.length)) + ); + + readonly tokensLimit = this.resolveParam( + "tokensLimit" + ).overrideWithMock(() => Number.MAX_SAFE_INTEGER); + + readonly multiroundProfile = this.resolveParam( + "multiroundProfile" + ).overrideWithMock(() => { + return { + maxRoundsNumber: 1, + defaultProofFixChoices: 0, + proofFixPrompt: "", + }; + }); + + readonly defaultChoices = this.resolveParam("choices") + .override( + (inputParams) => inputParams.tactics.length, + `always equals to the total number of \`tactics\`` + ) + .requiredToBeConfigured() + .validate( + [(value) => value >= 0, "be non-negative"], + [ + (value, inputParams) => value <= inputParams.tactics.length, + (inputParams) => + `be less than or equal to the total number of \`tactics\` (${inputParams.tactics.length} for the specified \`tactics\`)`, + ] + ); +} diff --git a/src/llm/llmServices/predefinedProofs/predefinedProofsService.ts b/src/llm/llmServices/predefinedProofs/predefinedProofsService.ts index 5dfa7dde..e4aff5ec 100644 --- a/src/llm/llmServices/predefinedProofs/predefinedProofsService.ts +++ b/src/llm/llmServices/predefinedProofs/predefinedProofsService.ts @@ -1,65 +1,78 @@ import { EventLogger } from "../../../logging/eventLogger"; +import { ConfigurationError } from "../../llmServiceErrors"; import { ProofGenerationContext } from "../../proofGenerationContext"; -import { - PredefinedProofsUserModelParams, - UserModelParams, -} from "../../userModelParams"; +import { PredefinedProofsUserModelParams } from "../../userModelParams"; import { ChatHistory } from "../chat"; -import { GeneratedProof, Proof, ProofVersion } from "../llmService"; -import { LLMService } from "../llmService"; -import { ModelParams, PredefinedProofsModelParams } from "../modelParams"; +import { + ErrorsHandlingMode, + GeneratedProofImpl, + ProofVersion, +} from "../llmService"; +import { LLMServiceImpl } from "../llmService"; 
+import { LLMServiceInternal } from "../llmServiceInternal"; +import { PredefinedProofsModelParams } from "../modelParams"; +import { Time, timeZero } from "../utils/time"; -export class PredefinedProofsService extends LLMService { - constructor(eventLogger?: EventLogger) { - super(eventLogger); - } +import { PredefinedProofsModelParamsResolver } from "./predefinedProofsModelParamsResolver"; - constructGeneratedProof( - proof: string, - proofGenerationContext: ProofGenerationContext, - modelParams: ModelParams, - _previousProofVersions?: ProofVersion[] - ): GeneratedProof { - return new PredefinedProof( - proof, - proofGenerationContext, - modelParams as PredefinedProofsModelParams, - this - ); - } +export class PredefinedProofsService extends LLMServiceImpl< + PredefinedProofsUserModelParams, + PredefinedProofsModelParams, + PredefinedProofsService, + PredefinedProof, + PredefinedProofsServiceInternal +> { + protected readonly internal: PredefinedProofsServiceInternal; + protected readonly modelParamsResolver = + new PredefinedProofsModelParamsResolver(); - generateFromChat( - _chat: ChatHistory, - _params: ModelParams, - _choices: number - ): Promise { - throw new Error( - "PredefinedProofsService does not support generation from chat" + constructor( + eventLogger?: EventLogger, + debugLogs: boolean = false, + generationsLogsFilePath?: string + ) { + super( + "PredefinedProofsService", + eventLogger, + debugLogs, + generationsLogsFilePath + ); + this.internal = new PredefinedProofsServiceInternal( + this, + this.eventLoggerGetter, + this.generationsLoggerBuilder ); } async generateProof( proofGenerationContext: ProofGenerationContext, - params: ModelParams, - choices: number - ): Promise { - if (choices <= 0) { - return []; - } - const predefinedProofsParams = params as PredefinedProofsModelParams; - const tactics = predefinedProofsParams.tactics; - if (choices > tactics.length) { - throw Error( - `invalid choices ${choices}: there are only ${tactics.length} predefined tactics available` - ); - } - return this.formatCoqSentences(tactics.slice(0, choices)).map( - (tactic) => - new PredefinedProof( - `Proof. ${tactic} Qed.`, + params: PredefinedProofsModelParams, + choices: number = params.defaultChoices, + errorsHandlingMode: ErrorsHandlingMode = ErrorsHandlingMode.LOG_EVENTS_AND_SWALLOW_ERRORS + ): Promise { + return this.internal.logGenerationAndHandleErrors( + params, + choices, + errorsHandlingMode, + (_request) => { + this.internal.validateChoices(choices); + const tactics = params.tactics; + if (choices > tactics.length) { + throw new ConfigurationError( + `requested ${choices} choices, there are only ${tactics.length} predefined tactics available` + ); + } + }, + async (_request) => { + return this.formatCoqSentences( + params.tactics.slice(0, choices) + ).map((tactic) => `Proof. 
${tactic} Qed.`); + }, + (proof) => + this.internal.constructGeneratedProof( + proof, proofGenerationContext, - predefinedProofsParams, - this + params ) ); } @@ -74,51 +87,73 @@ export class PredefinedProofsService extends LLMService { }); } - resolveParameters(params: UserModelParams): ModelParams { - const castedParams = params as PredefinedProofsUserModelParams; - if (castedParams.tactics.length === 0) { - throw Error( - "no tactics are selected in the PredefinedProofsModelParams" - ); - } - const modelParams: PredefinedProofsModelParams = { - modelName: params.modelName, - newMessageMaxTokens: Math.max( - ...castedParams.tactics.map((tactic) => tactic.length) - ), - tokensLimit: Number.POSITIVE_INFINITY, - systemPrompt: "", - multiroundProfile: { - maxRoundsNumber: 1, - proofFixChoices: 0, - proofFixPrompt: "", - }, - tactics: castedParams.tactics, - }; - return modelParams; + estimateTimeToBecomeAvailable(): Time { + return timeZero; // predefined proofs are always available } } -export class PredefinedProof extends GeneratedProof { +export class PredefinedProof extends GeneratedProofImpl< + PredefinedProofsModelParams, + PredefinedProofsService, + PredefinedProof, + PredefinedProofsServiceInternal +> { constructor( - proof: Proof, + proof: string, proofGenerationContext: ProofGenerationContext, modelParams: PredefinedProofsModelParams, - llmService: PredefinedProofsService + llmServiceInternal: PredefinedProofsServiceInternal ) { - super(proof, proofGenerationContext, modelParams, llmService); + super(proof, proofGenerationContext, modelParams, llmServiceInternal); } - protected generateNextVersion( - _chat: ChatHistory, - _choices: number - ): Promise { - throw new Error( - "PredefinedProof does not support next version generation" + async fixProof( + _diagnostic: string, + choices: number = this.modelParams.multiroundProfile + .defaultProofFixChoices, + errorsHandlingMode: ErrorsHandlingMode + ): Promise { + this.llmServiceInternal.unsupportedMethod( + "`PredefinedProof` cannot be fixed", + this.modelParams, + choices, + errorsHandlingMode + ); + return []; + } + + canBeFixed(): Boolean { + return false; + } +} + +class PredefinedProofsServiceInternal extends LLMServiceInternal< + PredefinedProofsModelParams, + PredefinedProofsService, + PredefinedProof, + PredefinedProofsServiceInternal +> { + constructGeneratedProof( + proof: string, + proofGenerationContext: ProofGenerationContext, + modelParams: PredefinedProofsModelParams, + _previousProofVersions?: ProofVersion[] + ): PredefinedProof { + return new PredefinedProof( + proof, + proofGenerationContext, + modelParams as PredefinedProofsModelParams, + this ); } - fixProof(_diagnostic: string, _choices: number): Promise { - throw new Error("PredefinedProof cannot be fixed"); + generateFromChatImpl( + _chat: ChatHistory, + _params: PredefinedProofsModelParams, + _choices: number + ): Promise { + throw new ConfigurationError( + "`PredefinedProofsService` does not support generation from chat" + ); } } diff --git a/src/llm/llmServices/utils/chatFactory.ts b/src/llm/llmServices/utils/chatFactory.ts index ab3498f2..0b84d5de 100644 --- a/src/llm/llmServices/utils/chatFactory.ts +++ b/src/llm/llmServices/utils/chatFactory.ts @@ -1,8 +1,7 @@ -import * as assert from "assert"; - import { Theorem } from "../../../coqParser/parsedTypes"; +import { ConfigurationError } from "../../llmServiceErrors"; import { ProofGenerationContext } from "../../proofGenerationContext"; -import { ChatHistory, ChatMessage } from "../chat"; +import { 
AnalyzedChatHistory, ChatHistory, ChatMessage } from "../chat"; import { ProofVersion } from "../llmService"; import { ModelParams } from "../modelParams"; @@ -12,6 +11,7 @@ import { chatItemToContent, itemizedChatToHistory, } from "./chatUtils"; +import { modelName } from "./modelParamsAccessors"; export function validateChat(chat: ChatHistory): [boolean, string] { if (chat.length < 1) { @@ -47,107 +47,131 @@ export function buildChat( let chat: ChatHistory = []; chat = chat.concat(...chats); const [isValid, errorMessage] = validateChat(chat); - assert.ok(isValid, errorMessage); + if (!isValid) { + throw new ConfigurationError( + `built chat is invalid: ${errorMessage};\n\`${chat}\`` + ); + } return chat; } -export function buildProofGenerationChat( - proofGenerationContext: ProofGenerationContext, - modelParams: ModelParams -): ChatHistory { - const fitter = new ChatTokensFitter( - modelParams.modelName, - modelParams.newMessageMaxTokens, - modelParams.tokensLimit - ); - - const systemMessage: ChatMessage = { - role: "system", - content: modelParams.systemPrompt, - }; - fitter.fitRequiredMessage(systemMessage); - - const completionTargetMessage: ChatMessage = { - role: "user", - content: proofGenerationContext.completionTarget, +export function buildAndAnalyzeChat( + fitter: ChatTokensFitter, + ...chats: (ChatHistory | ChatMessage)[] +): AnalyzedChatHistory { + return { + chat: buildChat(...chats), + estimatedTokens: fitter.estimateTokens(), }; - fitter.fitRequiredMessage(completionTargetMessage); +} - const fittedContextTheorems = fitter.fitOptionalObjects( - proofGenerationContext.contextTheorems, - (theorem) => chatItemToContent(theoremToChatItem(theorem)) +function withFitter( + modelParams: ModelParams, + block: (fitter: ChatTokensFitter) => T +): T { + const fitter = new ChatTokensFitter( + modelParams.maxTokensToGenerate, + modelParams.tokensLimit, + modelName(modelParams) ); - const contextTheoremsChat = buildTheoremsChat(fittedContextTheorems); + try { + return block(fitter); + } finally { + fitter.dispose(); + } +} - return buildChat( - systemMessage, - contextTheoremsChat, - completionTargetMessage - ); +export function buildProofGenerationChat( + proofGenerationContext: ProofGenerationContext, + modelParams: ModelParams +): AnalyzedChatHistory { + return withFitter(modelParams, (fitter) => { + const systemMessage: ChatMessage = { + role: "system", + content: modelParams.systemPrompt, + }; + fitter.fitRequiredMessage(systemMessage); + + const completionTargetMessage: ChatMessage = { + role: "user", + content: proofGenerationContext.completionTarget, + }; + fitter.fitRequiredMessage(completionTargetMessage); + + const fittedContextTheorems = fitter.fitOptionalObjects( + proofGenerationContext.contextTheorems, + (theorem) => chatItemToContent(theoremToChatItem(theorem)) + ); + const contextTheoremsChat = buildTheoremsChat(fittedContextTheorems); + + return buildAndAnalyzeChat( + fitter, + systemMessage, + contextTheoremsChat, + completionTargetMessage + ); + }); } export function buildProofFixChat( proofGenerationContext: ProofGenerationContext, proofVersions: ProofVersion[], modelParams: ModelParams -): ChatHistory { - const fitter = new ChatTokensFitter( - modelParams.modelName, - modelParams.newMessageMaxTokens, - modelParams.tokensLimit - ); - - const systemMessage: ChatMessage = { - role: "system", - content: modelParams.systemPrompt, - }; - fitter.fitRequiredMessage(systemMessage); - - const completionTargetMessage: ChatMessage = { - role: "user", - content: 
proofGenerationContext.completionTarget, - }; - const lastProofVersion = proofVersions[proofVersions.length - 1]; - const proofMessage: ChatMessage = { - role: "assistant", - content: lastProofVersion.proof, - }; - const proofFixMessage: ChatMessage = { - role: "user", - content: createProofFixMessage( - lastProofVersion.diagnostic!, - modelParams.multiroundProfile.proofFixPrompt - ), - }; - fitter.fitRequiredMessage(completionTargetMessage); - fitter.fitRequiredMessage(proofMessage); - fitter.fitRequiredMessage(proofFixMessage); - - const fittedProofVersions = fitter.fitOptionalObjects( - proofVersions.slice(0, proofVersions.length - 1), - (proofVersion) => - chatItemToContent(proofVersionToChatItem(proofVersion)) - ); - const previousProofVersionsChat = - buildPreviousProofVersionsChat(fittedProofVersions); - - const fittedContextTheorems = fitter.fitOptionalObjects( - proofGenerationContext.contextTheorems, - (theorem) => chatItemToContent(theoremToChatItem(theorem)) - ); - const contextTheoremsChat = buildTheoremsChat(fittedContextTheorems); - - return buildChat( - systemMessage, - contextTheoremsChat, - completionTargetMessage, - previousProofVersionsChat, - proofMessage, - proofFixMessage - ); +): AnalyzedChatHistory { + return withFitter(modelParams, (fitter) => { + const systemMessage: ChatMessage = { + role: "system", + content: modelParams.systemPrompt, + }; + fitter.fitRequiredMessage(systemMessage); + + const completionTargetMessage: ChatMessage = { + role: "user", + content: proofGenerationContext.completionTarget, + }; + const lastProofVersion = proofVersions[proofVersions.length - 1]; + const proofMessage: ChatMessage = { + role: "assistant", + content: lastProofVersion.proof, + }; + const proofFixMessage: ChatMessage = { + role: "user", + content: createProofFixMessage( + lastProofVersion.diagnostic!, + modelParams.multiroundProfile.proofFixPrompt + ), + }; + fitter.fitRequiredMessage(completionTargetMessage); + fitter.fitRequiredMessage(proofMessage); + fitter.fitRequiredMessage(proofFixMessage); + + const fittedProofVersions = fitter.fitOptionalObjects( + proofVersions.slice(0, proofVersions.length - 1), + (proofVersion) => + chatItemToContent(proofVersionToChatItem(proofVersion)) + ); + const previousProofVersionsChat = + buildPreviousProofVersionsChat(fittedProofVersions); + + const fittedContextTheorems = fitter.fitOptionalObjects( + proofGenerationContext.contextTheorems, + (theorem) => chatItemToContent(theoremToChatItem(theorem)) + ); + const contextTheoremsChat = buildTheoremsChat(fittedContextTheorems); + + return buildAndAnalyzeChat( + fitter, + systemMessage, + contextTheoremsChat, + completionTargetMessage, + previousProofVersionsChat, + proofMessage, + proofFixMessage + ); + }); } -function createProofFixMessage( +export function createProofFixMessage( diagnostic: string, proofFixPrompt: string ): string { @@ -165,6 +189,10 @@ export function buildTheoremsChat(theorems: Theorem[]): ChatHistory { return itemizedChatToHistory(theorems.map(theoremToChatItem)); } +/** + * Note: be careful, the order of the roles should be the opposite (when built as a chat), + * i.e. first goes the proof as `assistant` message and then the diagnostic as a `user` one. 
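+ *
+ * For illustration only (the proof and diagnostic strings below are hypothetical, not taken from the codebase):
+ * a `ProofVersion` with proof `"Proof. auto. Qed."` and diagnostic `"Unable to unify."`
+ * should eventually appear in the built chat as
+ * `{ role: "assistant", content: "Proof. auto. Qed." }` followed by
+ * `{ role: "user", content: "Unable to unify." }`.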
+ */ export function proofVersionToChatItem( proofVersion: ProofVersion ): UserAssistantChatItem { diff --git a/src/llm/llmServices/utils/chatTokensFitter.ts b/src/llm/llmServices/utils/chatTokensFitter.ts index 937fb850..2ffae105 100644 --- a/src/llm/llmServices/utils/chatTokensFitter.ts +++ b/src/llm/llmServices/utils/chatTokensFitter.ts @@ -1,39 +1,53 @@ import { Tiktoken, TiktokenModel, encoding_for_model } from "tiktoken"; -import { ChatMessage } from "../chat"; +import { ConfigurationError } from "../../llmServiceErrors"; +import { ChatMessage, EstimatedTokens } from "../chat"; export class ChatTokensFitter { readonly tokensLimit: number; private tokens: number = 0; + private encoder: Tiktoken | undefined; private readonly countTokens: (text: string) => number; constructor( - modelName: string, - newMessageMaxTokens: number, - tokensLimit: number + private readonly maxTokensToGenerate: number, + tokensLimit: number, + modelName?: string ) { this.tokensLimit = tokensLimit; - if (this.tokensLimit < newMessageMaxTokens) { - throw Error( - `tokens limit ${this.tokensLimit} is not enough to generate a new message that needs up to ${newMessageMaxTokens}` + if (this.tokensLimit < this.maxTokensToGenerate) { + throw new ConfigurationError( + `tokens limit ${this.tokensLimit} is not enough for the model to generate a new message that needs up to ${maxTokensToGenerate}` ); } - this.tokens += newMessageMaxTokens; + this.tokens += this.maxTokensToGenerate; - let encoder: Tiktoken | undefined = undefined; + this.encoder = undefined; try { - encoder = encoding_for_model(modelName as TiktokenModel); + this.encoder = encoding_for_model(modelName as TiktokenModel); } catch (e) {} this.countTokens = (text: string) => { - if (encoder) { - return encoder.encode(text).length; + if (this.encoder) { + return this.encoder.encode(text).length; } else { - return (text.length / 4) >> 0; + return Math.floor(text.length / 4); } }; } + dispose() { + this.encoder?.free(); + } + + estimateTokens(): EstimatedTokens { + return { + messagesTokens: this.tokens - this.maxTokensToGenerate, + maxTokensToGenerate: this.maxTokensToGenerate, + maxTokensInTotal: this.tokens, + }; + } + fitRequiredMessage(message: ChatMessage) { this.fitRequired(message.content); } @@ -59,7 +73,7 @@ export class ChatTokensFitter { private fitRequired(...contents: string[]) { const contentTokens = this.countContentTokens(...contents); if (this.tokens + contentTokens > this.tokensLimit) { - throw Error( + throw new ConfigurationError( `required content cannot be fitted into tokens limit: '${contents}' require ${contentTokens} + previous ${this.tokens} tokens > max ${this.tokensLimit}` ); } diff --git a/src/llm/llmServices/utils/defaultAvailabilityEstimator.ts b/src/llm/llmServices/utils/defaultAvailabilityEstimator.ts new file mode 100644 index 00000000..cc609d5c --- /dev/null +++ b/src/llm/llmServices/utils/defaultAvailabilityEstimator.ts @@ -0,0 +1,92 @@ +import { LoggerRecord } from "./generationsLogger/loggerRecord"; +import { + Time, + millisToTime, + nowTimestampMillis, + time, + timeToMillis, + timeZero, +} from "./time"; + +export const defaultHeuristicEstimationsMillis = [ + time(1, "second"), + time(5, "second"), + time(10, "second"), + time(30, "second"), + time(1, "minute"), + time(5, "minute"), + time(10, "minute"), + time(30, "minute"), + time(1, "hour"), + time(1, "day"), +].map((time) => timeToMillis(time)); + +/** + * Estimates the expected time for service to become available. 
+ * To do this, analyzes the logs since the last success and computes the time.
+ * The default algorithm does the following:
+ * - if the last attempt is successful => don't wait;
+ * - if there is only one failed attempt => wait 1 second;
+ * - otherwise, find the maximum time interval between two consecutive failures;
+ * - then, find the first heuristic time estimation that is greater than it;
+ * - return the difference between this estimation and the time since the last attempt
+ *   (if the time since the last attempt is greater => there is no need to wait).
+ * - P.S. In case the time since the last attempt is small enough (<10% of the estimation),
+ *   the estimation itself is returned.
+ */
+export function estimateTimeToBecomeAvailableDefault(
+    logsSinceLastSuccess: LoggerRecord[],
+    nowMillis: number = nowTimestampMillis()
+): Time {
+    const failures = validateInputLogsAreFailures(logsSinceLastSuccess);
+
+    if (failures.length === 0) {
+        return timeZero;
+    }
+    if (failures.length === 1) {
+        return time(1, "second");
+    }
+
+    const intervals: number[] = [];
+    for (let i = 1; i < failures.length; i++) {
+        intervals.push(
+            failures[i].timestampMillis - failures[i - 1].timestampMillis
+        );
+    }
+    const maxInterval = Math.max(...intervals);
+    let currentEstimationIndex = 0;
+    while (
+        currentEstimationIndex < defaultHeuristicEstimationsMillis.length - 1 &&
+        maxInterval >= defaultHeuristicEstimationsMillis[currentEstimationIndex]
+    ) {
+        currentEstimationIndex++;
+    }
+    const currentEstimation =
+        defaultHeuristicEstimationsMillis[currentEstimationIndex];
+
+    const timeFromLastAttempt =
+        nowMillis - failures[failures.length - 1].timestampMillis;
+
+    if (timeFromLastAttempt < currentEstimation) {
+        // if `timeFromLastAttempt` is small enough, return the estimation itself
+        // (to avoid awkward values that are very close to the heuristic estimations)
+        if (timeFromLastAttempt / currentEstimation < 0.1) {
+            return millisToTime(currentEstimation);
+        }
+        return millisToTime(currentEstimation - timeFromLastAttempt);
+    }
+    return timeZero;
+}
+
+function validateInputLogsAreFailures(
+    logsSinceLastSuccess: LoggerRecord[]
+): LoggerRecord[] {
+    for (const record of logsSinceLastSuccess) {
+        if (record.responseStatus !== "FAILURE") {
+            throw Error(
+                `invalid input logs: a non-first record is not a failed one;\n\`${record}\``
+            );
+        }
+    }
+    return logsSinceLastSuccess;
+}
diff --git a/src/llm/llmServices/utils/generationsLogger/generationsLogger.ts b/src/llm/llmServices/utils/generationsLogger/generationsLogger.ts
new file mode 100644
index 00000000..b0753aa2
--- /dev/null
+++ b/src/llm/llmServices/utils/generationsLogger/generationsLogger.ts
@@ -0,0 +1,185 @@
+import {
+    GenerationFailedError,
+    LLMServiceError,
+} from "../../../llmServiceErrors";
+import {
+    LLMServiceRequestFailed,
+    LLMServiceRequestSucceeded,
+} from "../../llmService";
+import { ModelParams } from "../../modelParams";
+import { nowTimestampMillis } from "../time";
+
+import { DebugLoggerRecord, LoggedError, LoggerRecord } from "./loggerRecord";
+import { SyncFile } from "./syncFile";
+
+export interface GenerationsLoggerSettings {
+    debug: boolean;
+    paramsPropertiesToCensor: Object;
+    cleanLogsOnStart: boolean;
+}
+
+/**
+ * This class is responsible for logging the actual generations.
+ * I.e. errors caused by the user or the extension are not the target ones.
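+ *
+ * A minimal usage sketch (the file path and settings values below are illustrative, not defaults):
+ * `const logger = new GenerationsLogger("logs/generations.log", { debug: false, paramsPropertiesToCensor: { apiKey: GenerationsLogger.censorString }, cleanLogsOnStart: true });`
+ * After that, `logger.readLogsSinceLastSuccess()` returns the records used to estimate availability.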
+ * + * The core functionality of `GenerationLogger` is to keep the logs since the last success, + * in order to provide them for the analysis of the time + * needed to `LLMService` to become available again. + * + * Also, due to the `debug` mode, `GenerationLogger` can be used for debug purposes. + * + * *Implementation note:* the `GenerationsLogger` currently writes logs to a file as plain text, + * which could theoretically result in performance overhead. However, in production mode, + * logs are cleaned after each successful generation, keeping the file size small most of the time. + * Thus, the overhead from handling larger files is negligible. + * Although some costs of working with plain text may remain, tests have not shown any performance + * degradation in practice. If performance issues arise in the future, consider modifying the + * string serialization/deserialization within `SyncFile`. + */ +export class GenerationsLogger { + private readonly logsFile: SyncFile; + private static readonly recordsDelim = "@@@ "; + + static readonly censorString = "***censored***"; + + /** + * About settings. + * + * - When `debug` is false, logs only the necessary info: + * timestamp, model name, response status and basic request info (choices and number of tokens sent). + * Logs are being cleaned every time the last request succeeds. + * - When `debug` is true, logs chat history, received completions and params of the model additionally. + * Also, the logs are never cleaned automatically. + * + * `paramsPropertiesToCensor` specifies properties of `ModelParams` (or its extension) + * that will be replaced with the corresponding given values in logs. + * An example `paramsPropertiesToCensor`: `{apiKey: GenerationsLogger.censorString}`. + * To disable censorship, pass an empty object: `{}`. + */ + constructor( + readonly filePath: string, + private readonly settings: GenerationsLoggerSettings + ) { + this.logsFile = new SyncFile(this.filePath); + if (!this.logsFile.exists() || this.settings.cleanLogsOnStart) { + this.resetLogs(); + } + } + + logGenerationSucceeded(request: LLMServiceRequestSucceeded) { + let record = new LoggerRecord( + nowTimestampMillis(), + request.params.modelId, + "SUCCESS", + request.choices, + request.analyzedChat?.estimatedTokens + ); + if (this.settings.debug) { + record = new DebugLoggerRecord( + record, + request.analyzedChat?.chat, + this.censorParamsProperties(request.params), + request.generatedRawProofs + ); + } + + const newLog = `${GenerationsLogger.recordsDelim}${record.serializeToString()}\n`; + if (this.settings.debug) { + this.logsFile.append(newLog); + } else { + this.logsFile.write(newLog); + } + } + + logGenerationFailed(request: LLMServiceRequestFailed) { + let record = new LoggerRecord( + nowTimestampMillis(), + request.params.modelId, + "FAILURE", + request.choices, + request.analyzedChat?.estimatedTokens, + this.toLoggedError( + this.extractAndValidateCause(request.llmServiceError) + ) + ); + if (this.settings.debug) { + record = new DebugLoggerRecord( + record, + request.analyzedChat?.chat, + this.censorParamsProperties(request.params) + ); + } + + const newLog = `${GenerationsLogger.recordsDelim}${record.serializeToString()}\n`; + this.logsFile.append(newLog); + } + + readLogs(): LoggerRecord[] { + const rawData = this.logsFile.read(); + const rawRecords = rawData + .split(GenerationsLogger.recordsDelim) + .slice(1); + return rawRecords.map((rawRecord) => + this.settings.debug + ? 
DebugLoggerRecord.deserealizeFromString(rawRecord)[0] + : LoggerRecord.deserealizeFromString(rawRecord)[0] + ); + } + + /** + * This method returns logs since the last success exclusively! + * In other words, the last success record (if it exists) is not included in the result. + */ + readLogsSinceLastSuccess(): LoggerRecord[] { + const records = this.readLogs(); + const invertedRow = []; + for (let i = records.length - 1; i >= 0; i--) { + if (records[i].responseStatus === "SUCCESS") { + break; + } + invertedRow.push(records[i]); + } + return invertedRow.reverse(); + } + + /** + * Clears the logs file or creates it if it doesn't exist. + */ + resetLogs() { + this.logsFile.createReset(); + } + + dispose() { + this.logsFile.delete(); + } + + private extractAndValidateCause(llmServiceError: LLMServiceError): Error { + if (!(llmServiceError instanceof GenerationFailedError)) { + throw Error( + `\`GenerationsLogger\` is capable of logging only generation errors, but got: "${this.toLoggedError(llmServiceError)}"` + ); + } + const cause = llmServiceError.cause; + if (cause instanceof LLMServiceError) { + throw Error( + `received doubled-wrapped error to log, cause is instance of \`LLMServiceError\`: "${this.toLoggedError(llmServiceError)}"` + ); + } + return cause; + } + + private censorParamsProperties(params: T): T { + // no need in deep copies, but we shall not overwrite original params + return { + ...params, + ...this.settings.paramsPropertiesToCensor, + }; + } + + private toLoggedError(error: Error): LoggedError { + return { + typeName: error.name, + message: error.message, + }; + } +} diff --git a/src/llm/llmServices/utils/generationsLogger/loggerRecord.ts b/src/llm/llmServices/utils/generationsLogger/loggerRecord.ts new file mode 100644 index 00000000..554a9796 --- /dev/null +++ b/src/llm/llmServices/utils/generationsLogger/loggerRecord.ts @@ -0,0 +1,465 @@ +import { ChatHistory, ChatRole, EstimatedTokens } from "../../chat"; +import { ModelParams } from "../../modelParams"; + +export type ResponseStatus = "SUCCESS" | "FAILURE"; + +export interface LoggedError { + typeName: string; + message: string; +} + +export class ParsingError extends Error { + constructor(message: string, rawParsingData: string) { + const parsingDataInfo = `\n>> \`${rawParsingData}\``; + super(`failed to parse log record: ${message}${parsingDataInfo}`); + } +} + +export class LoggerRecord { + /** + * Even though this value is in millis, effectively it represents seconds. + * I.e. this value is always floored to the seconds (`value % 1000 === 0`). + * + * The reason is that, unfortunately, the current serialization-deserialization + * cycle neglects milliseconds. + */ + readonly timestampMillis: number; + + protected static floorMillisToSeconds(millis: number): number { + return millis - (millis % 1000); + } + + constructor( + timestampMillis: number, + readonly modelId: string, + readonly responseStatus: ResponseStatus, + readonly choices: number, + readonly estimatedTokens: EstimatedTokens | undefined = undefined, + readonly error: LoggedError | undefined = undefined + ) { + this.timestampMillis = + LoggerRecord.floorMillisToSeconds(timestampMillis); + } + + protected static readonly introLinePattern = + /^\[(.*)\] `(.*)` model: (.*)$/; + + protected static readonly loggedErrorHeader = "! error occurred:"; + protected static readonly loggedErrorPattern = + /^! 
error occurred: \[(.*)\] "(.*)"$/; + + protected static readonly choicesPattern = /^- requested choices: (.*)$/; + + protected static readonly requestTokensHeader = "- request's tokens:"; + protected static readonly requestTokensPattern = + /^- request's tokens: ([0-9]+) = ([0-9]+) \(chat messages\) \+ ([0-9]+) \(max to generate\)$/; + + serializeToString(): string { + const introInfo = this.buildStatusLine(); + const errorInfo = this.buildErrorInfo(); + const requestInfo = this.buildRequestInfo(); + return `${introInfo}${errorInfo}${requestInfo}`; + } + + toString(): string { + return this.serializeToString(); + } + + protected buildStatusLine(): string { + const timestamp = new Date(this.timestampMillis).toLocaleString(); + return `[${timestamp}] \`${this.modelId}\` model: ${this.responseStatus}\n`; + } + + protected buildErrorInfo(): string { + if (this.error === undefined) { + return ""; + } + return `${LoggerRecord.loggedErrorHeader} [${this.error.typeName}] "${LoggerRecord.escapeNewlines(this.error.message)}"\n`; + } + + protected buildRequestInfo(): string { + const choicesRequested = `- requested choices: ${this.choices}\n`; + const requestTokens = + this.estimatedTokens !== undefined + ? `${LoggerRecord.requestTokensHeader} ${this.estimatedTokensToString()}\n` + : ""; + return `${choicesRequested}${requestTokens}`; + } + + private estimatedTokensToString(): string { + return `${this.estimatedTokens!.maxTokensInTotal} = ${this.estimatedTokens!.messagesTokens} (chat messages) + ${this.estimatedTokens!.maxTokensToGenerate} (max to generate)`; + } + + protected static escapeNewlines(text: string): string { + return text.replace("\n", "\\n"); + } + + static deserealizeFromString(rawRecord: string): [LoggerRecord, string] { + const [rawTimestamp, modelId, rawResponseStatus, afterIntroRawRecord] = + this.parseFirstLineByRegex( + this.introLinePattern, + rawRecord, + "intro line" + ); + const timestampMillis = this.parseTimestampMillis(rawTimestamp); + const responseStatus = this.parseAsType( + rawResponseStatus, + "response status" + ); + + const [error, afterLoggedErrorRawRecord] = this.parseOptional( + this.loggedErrorHeader, + (text) => this.parseLoggedError(text), + afterIntroRawRecord + ); + + const [rawChoices, afterChoicesRawRecord] = this.parseFirstLineByRegex( + this.choicesPattern, + afterLoggedErrorRawRecord, + "requested choices" + ); + const [estimatedTokens, afterTokensRawRecord] = this.parseOptional( + this.requestTokensHeader, + (text) => this.parseRequestTokens(text), + afterChoicesRawRecord + ); + + return [ + new LoggerRecord( + timestampMillis, + modelId, + responseStatus, + this.parseIntValue(rawChoices, "requested choices"), + estimatedTokens, + error + ), + afterTokensRawRecord, + ]; + } + + private static parseLoggedError(text: string): [LoggedError, string] { + const [errorTypeName, rawErrorMessage, restRawRecord] = + this.parseFirstLineByRegex( + this.loggedErrorPattern, + text, + "logged error" + ); + return [ + { + typeName: errorTypeName, + message: LoggerRecord.unescapeNewlines(rawErrorMessage), + }, + restRawRecord, + ]; + } + + private static parseRequestTokens(text: string): [EstimatedTokens, string] { + let [ + maxTokensInTotal, + messagesTokens, + maxTokensToGenerate, + restRawRecord, + ] = this.parseFirstLineByRegex( + this.requestTokensPattern, + text, + "request's tokens header" + ); + return [ + { + messagesTokens: this.parseIntValue( + messagesTokens, + "messages tokens" + ), + maxTokensToGenerate: this.parseIntValue( + maxTokensToGenerate, + "max 
tokens to generate" + ), + maxTokensInTotal: this.parseIntValue( + maxTokensInTotal, + "max tokens in total" + ), + }, + restRawRecord, + ]; + } + + protected static parseOptional( + header: string, + parse: (text: string) => [T, string], + text: string + ): [T | undefined, string] { + if (!text.startsWith(header)) { + return [undefined, text]; + } + return parse(text); + } + + protected static splitByFirstLine(text: string): [string, string] { + const firstLineEndIndex = text.indexOf("\n"); + if (firstLineEndIndex === -1) { + throw new ParsingError("line expected", text); + } + return [ + text.substring(0, firstLineEndIndex), + text.substring(firstLineEndIndex + 1), + ]; + } + + protected static parseAsType(rawValue: string, valueName: string): T { + const parsedValue = rawValue as T; + if (parsedValue === null) { + throw new ParsingError(`invalid ${valueName}`, rawValue); + } + return parsedValue; + } + + protected static parseTimestampMillis(rawTimestamp: string): number { + try { + return new Date(rawTimestamp).getTime(); + } catch (e) { + throw new ParsingError("invalid timestampt", rawTimestamp); + } + } + + protected static parseIntValue( + rawValue: string, + valueName: string + ): number { + try { + return parseInt(rawValue); + } catch (e) { + throw new ParsingError(`invalid ${valueName}`, rawValue); + } + } + + protected static parseIntValueOrUndefined( + rawValue: string, + valueName: string + ): number | undefined { + if (rawValue === "undefined") { + return undefined; + } + return this.parseIntValue(rawValue, valueName); + } + + protected static parseByRegex( + pattern: RegExp, + text: string, + valueName: string + ): string[] { + const match = text.match(pattern); + if (!match) { + throw new ParsingError(`invalid ${valueName}`, text); + } + return match.slice(1); + } + + protected static parseFirstLineByRegex( + pattern: RegExp, + text: string, + valueName: string + ): string[] { + const [firstLine, restText] = this.splitByFirstLine(text); + const parsedLine = this.parseByRegex(pattern, firstLine, valueName); + return [...parsedLine, restText]; + } + + protected static unescapeNewlines(text: string): string { + return text.replace("\\n", "\n"); + } +} + +export class DebugLoggerRecord extends LoggerRecord { + constructor( + baseRecord: LoggerRecord, + readonly chat: ChatHistory | undefined, + readonly params: ModelParams, + readonly generatedProofs: string[] | undefined = undefined + ) { + super( + baseRecord.timestampMillis, + baseRecord.modelId, + baseRecord.responseStatus, + baseRecord.choices, + baseRecord.estimatedTokens, + baseRecord.error + ); + } + + protected static readonly subItemIndent = "\t"; + protected static readonly subItemDelimIndented = `${this.subItemIndent}> `; + protected static readonly jsonStringifyIndent = 2; + + protected static readonly emptyListLine = `${this.subItemIndent}~ empty`; + protected static readonly emptyListPattern = /^\t~ empty$/; + + protected static readonly chatHeader = "- chat sent:"; + protected static readonly chatHeaderPattern = /^- chat sent:$/; + protected static readonly chatMessagePattern = /^\t> \[(.*)\]: `(.*)`$/; + + protected static readonly generatedProofsHeader = "- generated proofs:"; + protected static readonly generatedProofsHeaderPattern = + /^- generated proofs:$/; + protected static readonly generatedProofPattern = /^\t> `(.*)`$/; + + protected static readonly paramsHeader = "- model's params:"; + protected static readonly paramsHeaderPattern = /^- model's params:$/; + + serializeToString(): string { + const 
baseInfo = super.serializeToString(); + const extraInfo = this.buildExtraInfo(); + return `${baseInfo}${extraInfo}`; + } + + private buildExtraInfo(): string { + const chatInfo = + this.chat !== undefined + ? `${DebugLoggerRecord.chatHeader}\n${this.chatToExtraLogs()}\n` + : ""; + const generatedProofs = + this.generatedProofs !== undefined + ? `${DebugLoggerRecord.generatedProofsHeader}\n${this.proofsToExtraLogs()}\n` + : ""; + const paramsInfo = `${DebugLoggerRecord.paramsHeader}\n${this.paramsToExtraLogs()}\n`; + return `${chatInfo}${generatedProofs}${paramsInfo}`; + } + + private chatToExtraLogs(): string { + return this.chat!.length === 0 + ? DebugLoggerRecord.emptyListLine + : this.chat!.map( + (message) => + `${DebugLoggerRecord.subItemDelimIndented}[${message.role}]: \`${LoggerRecord.escapeNewlines(message.content)}\`` + ).join("\n"); + } + + private proofsToExtraLogs(): string { + return this.generatedProofs!.length === 0 + ? DebugLoggerRecord.emptyListLine + : this.generatedProofs!.map( + (proof) => + `${DebugLoggerRecord.subItemDelimIndented}\`${LoggerRecord.escapeNewlines(proof)}\`` + ).join("\n"); + } + + private paramsToExtraLogs(): string { + return JSON.stringify( + this.params, + null, + DebugLoggerRecord.jsonStringifyIndent + ); + } + + static deserealizeFromString( + rawRecord: string + ): [DebugLoggerRecord, string] { + const [baseRecord, afterBaseRawRecord] = super.deserealizeFromString( + rawRecord + ); + const [chat, afterChatRawRecord] = this.parseOptional( + this.chatHeader, + (text) => this.parseChatHistory(text), + afterBaseRawRecord + ); + const [generatedProofs, afterProofsRawRecord] = this.parseOptional( + this.generatedProofsHeader, + (text) => this.parseGeneratedProofs(text), + afterChatRawRecord + ); + const [params, unparsedData] = + this.parseModelParams(afterProofsRawRecord); + + return [ + new DebugLoggerRecord(baseRecord, chat, params, generatedProofs), + unparsedData, + ]; + } + + private static parseChatHistory(text: string): [ChatHistory, string] { + let [restRawRecord] = this.parseFirstLineByRegex( + this.chatHeaderPattern, + text, + "chat history header" + ); + const chat: ChatHistory = []; + if (restRawRecord.startsWith(this.emptyListLine)) { + return [ + chat, + this.parseFirstLineByRegex( + this.emptyListPattern, + restRawRecord, + "empty chat history keyword" + )[0], + ]; + } + while (restRawRecord.startsWith(this.subItemDelimIndented)) { + const [rawRole, rawContent, newRestRawRecord] = + this.parseFirstLineByRegex( + this.chatMessagePattern, + restRawRecord, + "chat history's message" + ); + chat.push({ + role: this.parseAsType(rawRole, "chat role"), + content: this.unescapeNewlines(rawContent), + }); + restRawRecord = newRestRawRecord; + } + return [chat, restRawRecord]; + } + + private static parseGeneratedProofs(text: string): [string[], string] { + let [restRawRecord] = this.parseFirstLineByRegex( + this.generatedProofsHeaderPattern, + text, + "generated proofs header" + ); + const generatedProofs: string[] = []; + if (restRawRecord.startsWith(this.emptyListLine)) { + return [ + generatedProofs, + this.parseFirstLineByRegex( + this.emptyListPattern, + restRawRecord, + "empty generated proofs keyword" + )[0], + ]; + } + while (restRawRecord.startsWith(this.subItemDelimIndented)) { + const [rawGeneratedProof, newRestRawRecord] = + this.parseFirstLineByRegex( + this.generatedProofPattern, + restRawRecord, + "generated proof" + ); + generatedProofs.push(this.unescapeNewlines(rawGeneratedProof)); + restRawRecord = newRestRawRecord; + } + 
return [generatedProofs, restRawRecord]; + } + + private static parseModelParams(text: string): [ModelParams, string] { + let [restRawRecord] = this.parseFirstLineByRegex( + this.paramsHeaderPattern, + text, + "model's params header" + ); + const params = this.parseAsType( + JSON.parse(restRawRecord), + "model's params" + ); + + restRawRecord = restRawRecord.slice( + JSON.stringify(params, null, this.jsonStringifyIndent).length + ); + if (!restRawRecord.startsWith("\n")) { + throw new ParsingError( + `invalid model's params suffix`, + restRawRecord + ); + } + restRawRecord = restRawRecord.slice(1); + + return [params, restRawRecord]; + } +} diff --git a/src/llm/llmServices/utils/generationsLogger/syncFile.ts b/src/llm/llmServices/utils/generationsLogger/syncFile.ts new file mode 100644 index 00000000..ddd03486 --- /dev/null +++ b/src/llm/llmServices/utils/generationsLogger/syncFile.ts @@ -0,0 +1,49 @@ +import * as fs from "fs"; +import * as path from "path"; + +/** + * Since `SyncFile` methods are not `async`, + * they are expected to be effectively "synchronized". + * This means that despite the concurrent nature of some parts of the system + * (for example, completing several "admit"-s concurrently), + * this class indeed provides a concurrent-safe way to work with a file. + */ +export class SyncFile { + constructor( + public readonly filePath: string, + public readonly encoding: string = "utf-8" + ) {} + + exists(): boolean { + return fs.existsSync(this.filePath); + } + + write(data: string) { + fs.writeFileSync( + this.filePath, + data, + this.encoding as fs.WriteFileOptions + ); + } + + append(data: string) { + fs.appendFileSync( + this.filePath, + data, + this.encoding as fs.WriteFileOptions + ); + } + + read(): string { + return fs.readFileSync(this.filePath, this.encoding as BufferEncoding); + } + + createReset() { + fs.mkdirSync(path.dirname(this.filePath), { recursive: true }); + fs.writeFileSync(this.filePath, ""); + } + + delete() { + fs.unlinkSync(this.filePath); + } +} diff --git a/src/llm/llmServices/utils/modelParamsAccessors.ts b/src/llm/llmServices/utils/modelParamsAccessors.ts new file mode 100644 index 00000000..e9fc1563 --- /dev/null +++ b/src/llm/llmServices/utils/modelParamsAccessors.ts @@ -0,0 +1,5 @@ +import { ModelParams } from "../modelParams"; + +export function modelName(params: ModelParams): string | undefined { + return "modelName" in params ? (params.modelName as string) : ""; +} diff --git a/src/llm/llmServices/utils/paramsResolvers/abstractResolvers.ts b/src/llm/llmServices/utils/paramsResolvers/abstractResolvers.ts new file mode 100644 index 00000000..9a915907 --- /dev/null +++ b/src/llm/llmServices/utils/paramsResolvers/abstractResolvers.ts @@ -0,0 +1,134 @@ +/** + * Represents the identifier of `T`'s property that can be used to access + * the value of the property. That is, in a sense, just the name of the property. + */ +export type PropertyKey = keyof T; + +/** + * Core interface for an object that is capable of resolving `InputType` into `ResolveToType`. + * + * The `_resolverId` member serves as a discriminator for the `ParamsResolver` type. + */ +export interface ParamsResolver { + resolve(inputParams: InputType): ParamsResolutionResult; + + /** + * Should be set to "ParamsResolver" value in any implementation object. + */ + _resolverId: "ParamsResolver"; +} + +/** + * Contains both the resolved parameters object `resolved` and the resolution logs `resolutionLogs`. + * If resolution does not succeed, `resolved` is undefined. 
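+ *
+ * A typical (illustrative) way for a caller to consume it:
+ * `const result = resolver.resolve(inputParams);`
+ * `if (result.resolved === undefined) { reportFailure(result.resolutionLogs); } else { useParams(result.resolved); }`
+ * Here `resolver`, `reportFailure`, and `useParams` are hypothetical names on the caller's side.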
+ */ +export interface ParamsResolutionResult { + resolved?: ResolveToType; + resolutionLogs: SingleParamResolutionResult[]; +} + +/** + * Interface for an object capable of resolving a single parameter. + * + * Practically, it can be easily represented as `ParamsResolver`. + * However, since interfaces cannot have a default method implementation, + * this logic is only available through `AbstractSingleParamResolver`, which + * implements both interfaces. + */ +export interface SingleParamResolver { + resolveParam(inputParams: InputType): SingleParamResolutionResult; +} + +/** + * Interface that stores information about the resolution of a single parameter. + */ +export interface SingleParamResolutionResult { + /** + * Undefined if the parameter was overriden with a mock value. + * Otherwise contains the name of the input parameter to resolve. + */ + inputParamName?: string; + + /** + * Contains the resulting parameter value after its resolution + * if successful. On failure, `resultValue` is undefined. + */ + resultValue?: T; + + /** + * If resolution fails, contains the message explaining the cause. + * Otherwise, it is undefined. + */ + isInvalidCause?: string; + + /** + * `inputReadCorrectly.wasPerformed` is true iff the parameter's input value + * is read as a defined value (of the correct type, if it is verifiable). + */ + inputReadCorrectly: ResolutionActionResult; + + /** + * `overriden.wasPerformed` is true iff the parameter's input value is overriden with a new value. + * I.e. if an override attempt is made with the same value as the input, + * `overriden.wasPerformed` will be false. + */ + overriden: ResolutionActionDetailedResult; + + /** + * `resolvedWithDefault.wasPerformed` is true iff the default resolver is called, + * i.e. iff the parameter's value is undefined after input read and potential override. + * Note: even if the default resolver returned undefined, `resolvedWithDefault.wasPerformed` will be true. + */ + resolvedWithDefault: ResolutionActionResult; +} + +/** + * `withValue` is set only if `wasPerformed` is true. However, it can be set with undefined too. + */ +export interface ResolutionActionResult { + wasPerformed: boolean; + withValue?: T; +} + +/** + * The same as `ResolutionActionResult`, but provides an explanation message + * if the action was performed. + */ +export interface ResolutionActionDetailedResult + extends ResolutionActionResult { + message?: string; +} + +/** + * Checks whether `object` is of `ParamsResolver` type. + * To do this, the implementation uses the `ParamsResolver._resolverId` discriminator only. + */ +export function isParamsResolver( + object: any +): object is ParamsResolver { + return object._resolverId === "ParamsResolver"; +} + +/** + * This abstract class extends the `SingleParamResolver` implementation + * with a default `resolve` method implementation to implement the `ParamsResolver` interface. + * + * If you plan to implement `SingleParamResolver`, you most likely want to extend this class. 
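+ *
+ * In practice, such resolvers are usually obtained from the builders in `builders.ts` rather than
+ * written by hand; an illustrative chain (the default value is hypothetical):
+ * `resolveParam("tokensLimit").default(() => 4096).validate(ValidationRules.bePositiveNumber)`
+ * produces an object extending this class.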
+ */ +export abstract class AbstractSingleParamResolver + implements SingleParamResolver, ParamsResolver +{ + abstract resolveParam( + inputParams: InputType + ): SingleParamResolutionResult; + + resolve(inputParams: InputType): ParamsResolutionResult { + const paramResolutionResult = this.resolveParam(inputParams); + return { + resolved: paramResolutionResult.resultValue, + resolutionLogs: [paramResolutionResult], + }; + } + + _resolverId: "ParamsResolver" = "ParamsResolver"; +} diff --git a/src/llm/llmServices/utils/paramsResolvers/basicModelParamsResolvers.ts b/src/llm/llmServices/utils/paramsResolvers/basicModelParamsResolvers.ts new file mode 100644 index 00000000..ac255f3d --- /dev/null +++ b/src/llm/llmServices/utils/paramsResolvers/basicModelParamsResolvers.ts @@ -0,0 +1,90 @@ +import { + UserModelParams, + UserMultiroundProfile, +} from "../../../userModelParams"; +import { + ModelParams, + MultiroundProfile, + multiroundProfileSchema, +} from "../../modelParams"; + +import { ValidationRules } from "./builders"; +import { + ParamsResolverImpl, + ValidParamsResolverImpl, +} from "./paramsResolverImpl"; + +export class BasicMultiroundProfileResolver + extends ParamsResolverImpl + implements + ValidParamsResolverImpl +{ + constructor() { + super(multiroundProfileSchema, "MultiroundProfile"); + } + + readonly maxRoundsNumber = this.resolveParam("maxRoundsNumber") + .default(() => defaultMultiroundProfile.maxRoundsNumber) + .validate(ValidationRules.bePositiveNumber); + + readonly defaultProofFixChoices = this.resolveParam( + "proofFixChoices" + ) + .default(() => defaultMultiroundProfile.defaultProofFixChoices) + .validate(ValidationRules.bePositiveNumber); + + readonly proofFixPrompt = this.resolveParam("proofFixPrompt") + .default(() => defaultMultiroundProfile.proofFixPrompt) + .noValidationNeeded(); +} + +/** + * Properties of `defaultMultiroundProfile` can be used separately. + * - Multiround is disabled by default. + * - 1 fix version per proof by default. + * - Default `proofFixPrompt` includes `${diagnostic}` message. + */ +export const defaultMultiroundProfile: MultiroundProfile = { + maxRoundsNumber: 1, + defaultProofFixChoices: 1, + proofFixPrompt: + "Unfortunately, the last proof is not correct. Here is the compiler's feedback: `${diagnostic}`. Please, fix the proof.", +}; + +export class BasicModelParamsResolver< + InputType extends UserModelParams, + ResolveToType extends ModelParams, + > + extends ParamsResolverImpl + implements ValidParamsResolverImpl +{ + readonly modelId = this.resolveParam("modelId") + .requiredToBeConfigured() + .noValidationNeeded(); + + readonly systemPrompt = this.resolveParam("systemPrompt") + .default(() => defaultSystemMessageContent) + .noValidationNeeded(); + + readonly maxTokensToGenerate = this.resolveParam( + "maxTokensToGenerate" + ) + .requiredToBeConfigured() + .validate(ValidationRules.bePositiveNumber); + + readonly tokensLimit = this.resolveParam("tokensLimit") + .requiredToBeConfigured() + .validate(ValidationRules.bePositiveNumber); + + readonly multiroundProfile = this.resolveNestedParams( + "multiroundProfile", + new BasicMultiroundProfileResolver() + ); + + readonly defaultChoices = this.resolveParam("choices") + .requiredToBeConfigured() + .validate(ValidationRules.bePositiveNumber); +} + +export const defaultSystemMessageContent: string = + "Generate proof of the theorem from user input in Coq. You should only generate proofs in Coq. Never add special comments to the proof. Your answer should be a valid Coq proof. 
It should start with 'Proof.' and end with 'Qed.'."; diff --git a/src/llm/llmServices/utils/paramsResolvers/builders.ts b/src/llm/llmServices/utils/paramsResolvers/builders.ts new file mode 100644 index 00000000..f8116b06 --- /dev/null +++ b/src/llm/llmServices/utils/paramsResolvers/builders.ts @@ -0,0 +1,493 @@ +import { + stringifyAnyValue, + stringifyDefinedValue, +} from "../../../../utils/printers"; + +import { AbstractSingleParamResolver, PropertyKey } from "./abstractResolvers"; +import { SingleParamResolutionResult } from "./abstractResolvers"; + +/** + * Facade function to create a single-parameter-resolver builder. + * Then one can define the resolution of the `inputParamKey` property + * of the input `InputType` object. + */ +export function resolveParam( + inputParamKey: PropertyKey +): SingleParamResolverBuilder { + return new SingleParamResolverBuilderImpl(inputParamKey); +} + +/** + * Facade function to create a single-parameter-resolver builder. + * Then one can define the resolution of the new property based on + * the input `InputType` object. + */ +export function newParam( + valueBuilder: StrictValueBuilder +): SingleParamWithValueResolverBuilder { + return new SingleParamWithValueResolverBuilderImpl( + undefined, + { valueBuilder: valueBuilder }, + undefined + ); +} + +/** + * First single-parameter-resolver builder interface. + * It allows to define whether the parameter's value should be overriden, + * resolved with default, or just required to be configured. + */ +export interface SingleParamResolverBuilder { + /** + * Specifies that the parameter's value should be overriden. + * Override actually happens only if the constructed value + * differs from the input read one. + * + * Implementation notes: + * - any conditions can be specified inside `valueBuilder`, + * so it is possible to make override conditional; + * - `valueBuilder` can return undefined, forcing the following + * resolution with default. + * + * @param valueBuilder lambda to build the value to override with. + * @param paramMessage optional message to explain the override to the user. + */ + override( + valueBuilder: ValueBuilder, + paramMessage?: Message + ): SingleParamResolverBuilder; + + /** + * Have the same functionality as `override` method, + * but describes overriding with a mock value, i.e. the value + * that will never be actually used. Thus, this override always succeeds + * and cannot return undefined. + * + * Also, this override is never shown in logs (`overriden.wasPerformed` is false), + * since mock resolutions should not be tracked the same as real once. + * *TODO: support mock resolutions in logs with a separate property.* + */ + overrideWithMock( + valueBuilder: StrictValueBuilder + ): AbstractSingleParamResolver; + + /** + * Specifies that the parameter's value should be resolved with default if it is undefined. + * Resolution with default is performed only if the previous steps result in the parameter's value + * to be undefined (i.e. input was initially undefined and not overriden / was overriden with `undefined`). + * + * Implementation notes: + * - any conditions can be specified inside `valueBuilder`, + * so it is possible to make resolution with default conditional; + * - `valueBuilder` can return undefined, forcing the parameter's resolution to fail + * (since default resolution is the last step). + * + * @param valueBuilder lambda to build the default value. 
+ * @param helpMessageIfNotResolved optional message to show to the user if `valueBuilder` have not built a default value for the parameter (i.e. returned `undefined`). + */ + default( + valueBuilder: ValueBuilder, + helpMessageIfNotResolved?: Message + ): SingleParamWithValueResolverBuilder; + + /** + * Specifies that the parameter's value is required to be configured in the input `InputType` object, + * i.e. does not have any value resolvers. Practically, this method just skips the step of specifying the value resolvers. + */ + requiredToBeConfigured(): SingleParamWithValueResolverBuilder; +} + +/** + * Second single-parameter-resolver builder interface. + * It allows to define validation rules for the resolved parameter's value. + */ +export interface SingleParamWithValueResolverBuilder { + /** + * Specifies validation rules to verify the resolved parameter's value with. + * The rules will be checked in the order they appear in the arguments. + */ + validate( + ...validationRules: ValidationRule[] + ): AbstractSingleParamResolver; + + /** + * Specifies that the resolved parameter's value does not need validation. + * Practically, this method just skips the step of specifying the validation rules. + */ + noValidationNeeded(): AbstractSingleParamResolver; + + /** + * Specifies that the resolved parameter's value does need validation, + * but it will be performed by the caller code at runtime only. + * Practically, this method just skips the step of specifying the validation rules. + */ + validateAtRuntimeOnly(): AbstractSingleParamResolver; +} + +/** + * Builds the value to resolve the parameter with. + * Return undefined value to skip the corresponding resolution step. + */ +export type ValueBuilder = ( + inputParams: InputType +) => T | undefined; + +/** + * Builds the value to resolve the parameter with. Cannot return undefined values. + */ +export type StrictValueBuilder = (inputParams: InputType) => T; + +/** + * Accepts `inputParams` and builds the message. + */ +export type MessageBuilder = (inputParams: InputType) => string; + +/** + * Represents a message to use in the resolution logs. Can be both a simple string and + * a builder, which builds the message from `inputParams`. + */ +export type Message = string | MessageBuilder; + +function buildMessage( + message: Message | undefined, + inputParams: InputType +): string | undefined { + return typeof message === "function" ? message(inputParams) : message; +} + +/** + * Specification of a single validation rule. + * The first tuple element is a validator function, while the second one + * is the message explaining the check. Generally, the message should be in such a format, + * that the following phrase will make sense: `parameter's value should ${message}`. + */ +export type ValidationRule = [ + Validator, + Message, +]; + +/** + * Predicate to validate the resolved parameter's value with. + */ +export type Validator = ( + value: T, + inputParams: InputType +) => boolean; + +/** + * Namespace that provides some common validation rules. 
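+ *
+ * Custom rules can be passed to `validate(...)` in the same tuple shape; a hypothetical example:
+ * `.validate([(value: string) => value.length > 0, "be non-empty"])`.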
+ */ +export namespace ValidationRules { + export const bePositiveNumber: ValidationRule = [ + (value: number) => value > 0, + "be positive", + ]; +} + +/** + * Builders' implementation below >>>>>>>>>>>>>>>>>>>>>> + */ + +interface Overrider { + valueBuilder: ValueBuilder; + explanationMessage?: Message; +} + +interface DefaultResolver { + valueBuilder: ValueBuilder; + noDefaultValueHelpMessage?: Message; +} + +class SingleParamResolverBuilderImpl + implements SingleParamResolverBuilder +{ + private overrider: Overrider | undefined = undefined; + constructor(private readonly inputParamKey: PropertyKey) {} + + override( + valueBuilder: ValueBuilder, + paramMessage?: Message + ): SingleParamResolverBuilder { + if (this.overrider !== undefined) { + throw new Error( + `parameter \'${String(this.inputParamKey)}\'is overriden multiple times` + ); + } + this.overrider = { + valueBuilder: valueBuilder, + explanationMessage: paramMessage, + }; + return this; + } + + overrideWithMock( + valueBuilder: StrictValueBuilder + ): AbstractSingleParamResolver { + return new SingleParamResolverImpl( + this.inputParamKey, + { + valueBuilder: valueBuilder, + }, + undefined, + [], + true + ); + } + + default( + valueBuilder: ValueBuilder, + noDefaultValueHelpMessage?: Message + ): SingleParamWithValueResolverBuilder { + return new SingleParamWithValueResolverBuilderImpl( + this.inputParamKey, + this.overrider, + { + valueBuilder: valueBuilder, + noDefaultValueHelpMessage: noDefaultValueHelpMessage, + } + ); + } + + requiredToBeConfigured(): SingleParamWithValueResolverBuilder< + InputType, + T + > { + return new SingleParamWithValueResolverBuilderImpl( + this.inputParamKey, + this.overrider, + undefined + ); + } +} + +class SingleParamWithValueResolverBuilderImpl + implements SingleParamWithValueResolverBuilder +{ + constructor( + private readonly inputParamKey?: PropertyKey, + private readonly overrider?: Overrider, + private readonly defaultResolver?: DefaultResolver + ) {} + + validate( + ...validationRules: ValidationRule[] + ): AbstractSingleParamResolver { + return new SingleParamResolverImpl( + this.inputParamKey, + this.overrider, + this.defaultResolver, + validationRules + ); + } + + noValidationNeeded(): AbstractSingleParamResolver { + return new SingleParamResolverImpl( + this.inputParamKey, + this.overrider, + this.defaultResolver, + [] + ); + } + + validateAtRuntimeOnly(): AbstractSingleParamResolver { + return this.noValidationNeeded(); + } +} + +class SingleParamResolverImpl extends AbstractSingleParamResolver< + InputType, + T +> { + constructor( + private readonly inputParamKey?: PropertyKey, + private readonly overrider?: Overrider, + private readonly defaultResolver?: DefaultResolver, + private readonly validationRules: ValidationRule[] = [], + private readonly overridenWithMockValue: boolean = false + ) { + super(); + } + + /** + * Unfortunately, since the language does not allow to validate the type of the parameter properly, + * no actual type checking is performed. + */ + resolveParam(inputParams: InputType): SingleParamResolutionResult { + const result: SingleParamResolutionResult = { + inputParamName: + this.inputParamKey === undefined + ? 
undefined + : String(this.inputParamKey), + resultValue: undefined, + inputReadCorrectly: { + wasPerformed: false, + }, + overriden: { + wasPerformed: false, + }, + resolvedWithDefault: { + wasPerformed: false, + }, + }; + + let value: T | undefined = undefined; + let resultIsComplete = false; + + value = this.tryToReadInputValue(inputParams, result); + + [value, resultIsComplete] = this.tryToResolveWithOverrider( + inputParams, + result, + value + ); + if (resultIsComplete) { + return result; + } + + [value, resultIsComplete] = this.tryToResolveWithDefault( + inputParams, + result, + value + ); + if (resultIsComplete) { + return result; + } + + // failed to resolve value + if (value === undefined) { + result.isInvalidCause = this.noValueMessage(); + return result; + } + + const valueIsValid = this.validateDefinedValue( + inputParams, + result, + value + ); + if (!valueIsValid) { + return result; + } + + result.resultValue = value; + return result; + } + + protected tryToReadInputValue( + inputParams: InputType, + result: SingleParamResolutionResult + ): T | undefined { + if (this.inputParamKey === undefined) { + return undefined; + } + const userValue = inputParams[this.inputParamKey]; + if (userValue === undefined) { + return undefined; + } + // if user specified a value, then take it + const userValueAsT = userValue as T; + if (userValueAsT !== null) { + result.inputReadCorrectly = { + wasPerformed: true, + withValue: userValueAsT, + }; + return userValueAsT; + } else { + // unfortunately, this case is unreachable: TypeScript does not provide the way to check that `userValue` is of the `T` type indeed + throw Error( + `cast of \`any\` to generic \`T\` type should always succeed, value = ${stringifyAnyValue(userValue)} for ${this.quotedName()} parameter` + ); + } + } + + /** + * @returns a new `value` value and true if `result` is complete, false otherwise + */ + private tryToResolveWithOverrider( + inputParams: InputType, + result: SingleParamResolutionResult, + value: T | undefined + ): [T | undefined, boolean] { + if (this.overrider === undefined) { + return [value, false]; + } + const { valueBuilder, explanationMessage } = this.overrider; + const valueToOverrideWith = valueBuilder(inputParams); + if (this.overridenWithMockValue) { + // no checks and logs are needed, just return the mock value + result.resultValue = valueToOverrideWith; + if (valueToOverrideWith === undefined) { + throw Error( + `${this.quotedName()} is expected to be a mock value, but its builder resolved with "undefined"` + ); + } + return [valueToOverrideWith, true]; + } + if (value === valueToOverrideWith) { + return [value, false]; + } + result.overriden = { + wasPerformed: true, + withValue: valueToOverrideWith, + message: buildMessage(explanationMessage, inputParams), + }; + return [valueToOverrideWith, false]; + } + + /** + * @returns a new `value` value and true if `result` is complete, false otherwise + */ + private tryToResolveWithDefault( + inputParams: InputType, + result: SingleParamResolutionResult, + value: T | undefined + ): [T | undefined, boolean] { + // if user value is still undefined after overriden resolution, resolve with default + if (value !== undefined || this.defaultResolver === undefined) { + return [value, false]; + } + const { valueBuilder, noDefaultValueHelpMessage } = + this.defaultResolver; + value = valueBuilder(inputParams); + result.resolvedWithDefault = { + wasPerformed: true, + withValue: value, + }; + // failed to resolve value because default value was not found (but could 
potentially) + if (value === undefined) { + const helpMessageSuffix = + noDefaultValueHelpMessage === undefined + ? "" + : `. ${buildMessage(noDefaultValueHelpMessage, inputParams)}`; + result.isInvalidCause = `${this.noValueMessage()}${helpMessageSuffix}`; + return [value, true]; + } + return [value, false]; + } + + /** + * @returns true if `value` is valid, false otherwise + */ + private validateDefinedValue( + inputParams: InputType, + result: SingleParamResolutionResult, + value: T + ): boolean { + for (const [validateValue, paramShouldMessage] of this + .validationRules) { + const validationResult = validateValue(value, inputParams); + if (!validationResult) { + result.isInvalidCause = `${this.quotedName()} should ${buildMessage(paramShouldMessage, inputParams)}, but has value ${stringifyDefinedValue(value)}`; + return false; + } + } + return true; + } + + private quotedName(): string { + return `\`${String(this.inputParamKey)}\``; + } + + private noValueMessage(): string { + return `${this.quotedName()} is required, but neither a user value nor a default one is specified`; + } +} diff --git a/src/llm/llmServices/utils/paramsResolvers/paramsResolverImpl.ts b/src/llm/llmServices/utils/paramsResolvers/paramsResolverImpl.ts new file mode 100644 index 00000000..266d4b4e --- /dev/null +++ b/src/llm/llmServices/utils/paramsResolvers/paramsResolverImpl.ts @@ -0,0 +1,240 @@ +import { DefinedError, JSONSchemaType, ValidateFunction } from "ajv"; + +import { + AjvMode, + ajvErrorsAsString, + buildAjv, +} from "../../../../utils/ajvErrorsHandling"; + +import { + ParamsResolutionResult, + ParamsResolver, + isParamsResolver, +} from "./abstractResolvers"; +import { SingleParamResolutionResult } from "./abstractResolvers"; +import { PropertyKey } from "./abstractResolvers"; +import { + SingleParamResolverBuilder, + SingleParamWithValueResolverBuilder, + StrictValueBuilder, + newParam, + resolveParam, +} from "./builders"; + +/** + * Generic type check that returns `any` if `T` has no optional properties, + * i.e. properties that could be set with `undefined` value. + * Otherwise, returns `never`. + * + * This check is supposed to be used as follows: `T extends NoOptionalProperties`. + * However, it is already used to restrict `ResolveToType` of `ParamsResolverImpl`, + * so most likely you don't need to use it by yourself. + */ +export type NoOptionalProperties = [ + { + [K in keyof T]-?: undefined extends T[K] ? any : never; + }[keyof T], +] extends [never] + ? any + : never; + +/** + * Implement this type every time you develop a new `ParamsResolverImpl`. + * It checks that you have specified the correct resolvers for all `ResolveToType` properties. + * + * Unfortunately, it can only be used for statically known types, so the base `ParamsResolverImpl` class cannot extend it. + */ +export type ValidParamsResolverImpl = { + [K in keyof ResolveToType]: ParamsResolver; +}; + +/** + * The base class that implements the parameters resolving algorithm and provides + * a convenient way to declare your custom parameters resolvers. + * + * How to use it. + * 1. Declare a custom class that extends `ParamsResolverImpl` with the right generic types: + * the `InputType` is the type of the input object to resolve and + * the `ResolveToType` is the type of the output resolved object. + * + * 2. Specify each property resolver in the following format: + * ``` + * readonly resolveToParamKey = this.resolveParam(inputParamKey) + * . 
// continue building the single parameter resolver with the hints
+ * ```
+ * Notes:
+ * * `resolveToParamKey` should be a name of one of the `ResolveToType` properties (and should not start with "_");
+ * * `inputParamKey` should be a name of one of the `InputType` properties;
+ * * you can also use `this.insertParam(...)` and `resolveNestedParams(...)`
+ *   instead of `this.resolveParam` to start building the parameter resolver. Check their docs for more details.
+ *
+ * Once you finish building the parameter resolver, it should implement `ParamsResolver`
+ * (i.e. should be a finished parameter resolver) — most likely, it will be of
+ * the `AbstractSingleParamResolver` class.
+ *
+ * Implementation note:
+ * * if you need to declare any utility properties in your class,
+ *   it is okay to do so by starting the property name with an underscore (for example, `_utilityProp`).
+ *
+ * 3. Check that all resolvers for the `ResolveToType` properties are specified correctly. To do this, make your
+ *    parameters resolver class implement `ValidParamsResolverImpl`. It will check exactly this contract.
+ *
+ * 4. Specify an Ajv JSON schema for `ResolveToType` and pass it to the `super(...)` constructor. So far it is needed to
+ *    properly validate the properties inside the resulting resolved object. *TODO: make `ParamsResolverImpl` generate such a schema.*
+ *
+ * 5. Call the `resolve(inputParams: InputType): ParamsResolutionResult` method of your custom class
+ *    to validate the `inputParams` parameters robustly and efficiently. The parameters resolver has no visible inner state,
+ *    so it can be called multiple times with no side-effects.
+ */
+export abstract class ParamsResolverImpl<
+    InputType,
+    ResolveToType extends NoOptionalProperties,
+> implements ParamsResolver
+{
+    private readonly _resolveToTypeValidator: ValidateFunction;
+    protected readonly _resolveToTypeName: string;
+
+    constructor(
+        resolvedParamsSchema: JSONSchemaType,
+        resolveToTypeName: string
+    ) {
+        this._resolveToTypeName = resolveToTypeName;
+        this._resolveToTypeValidator = buildAjv(
+            AjvMode.COLLECT_ALL_ERRORS
+        ).compile(resolvedParamsSchema) as ValidateFunction;
+    }
+
+    /**
+     * Creates a builder of the resolver that resolves `InputType[inputParamKey]`.
+     */
+    protected resolveParam(
+        inputParamKey: PropertyKey
+    ): SingleParamResolverBuilder {
+        return resolveParam(inputParamKey);
+    }
+
+    /**
+     * Creates a builder of the resolver that resolves a new property.
+     * In other words, this is the way to build the property of the `ResolveToType` object
+     * that does not have a corresponding property inside `InputType`.
+     */
+    protected insertParam(
+        valueBuilder: StrictValueBuilder
+    ): SingleParamWithValueResolverBuilder {
+        return newParam(valueBuilder);
+    }
+
+    /**
+     * Similarly to the `resolveParam` method, creates a resolver that resolves `InputType[inputParamKey]`.
+     * However, `inputParamKey` will be resolved via the standalone `nestedParamsResolver` resolver.
+     * Practically, this is the proper way to resolve the properties of the nested object separately,
+     * instead of resolving them all together as a single parameter having the nested object value.
+     *
+     * Moreover, since this method returns a finished resolver instead of a builder,
+     * so far there is no way to declare whole-nested-object resolution rules (overrides and defaults)
+     * at the same time as using a standalone parameters resolver for the nested parameters.
+     * Thus, you need to choose whether to resolve the nested object as a single object or parameter-by-parameter.
+     *
+     * Notes on the nested parameters resolution algorithm.
+     * - If `InputType[inputParamKey]` turns out to be undefined, an empty object `{}` will be passed to
+     *   the `nestedParamsResolver.resolve` method as the input.
+     * - All names of the parameters of the nested `ParamInputType` will be prepended with `${inputParamKey}.`
+     *   in the resolution logs.
+     * - The resolution logs of the nested `ParamInputType` properties will be merged with the resolution logs
+     *   of the other `InputType` properties and returned in the `resolve` method.
+     */
+    protected resolveNestedParams(
+        inputParamKey: PropertyKey,
+        nestedParamsResolver: ParamsResolver
+    ): ParamsResolver {
+        return new (class {
+            resolve(inputParams: InputType): ParamsResolutionResult {
+                const paramInputValue = (inputParams[inputParamKey] ??
+                    {}) as ParamInputType;
+                const paramResolutionResult =
+                    nestedParamsResolver.resolve(paramInputValue);
+                return {
+                    resolved: paramResolutionResult.resolved,
+                    resolutionLogs: paramResolutionResult.resolutionLogs.map(
+                        (paramLog) => {
+                            return {
+                                ...paramLog,
+                                inputParamName: `${String(inputParamKey)}.${paramLog.inputParamName}`,
+                            };
+                        }
+                    ),
+                };
+            }
+            _resolverId: "ParamsResolver" = "ParamsResolver";
+        })();
+    }
+
+    /**
+     * Core method of the parameters resolver that actually resolves the input parameters object `inputParams`
+     * into the resolved parameters object of the `ResolveToType` type.
+     *
+     * However, `resolve` does not return the resolved parameters object directly; instead,
+     * it returns the `ParamsResolutionResult` object containing both the resolved object `resolved`
+     * (in case of a failure it is undefined) and the resolution logs `resolutionLogs`.
+     * You can find more details in the `ParamsResolutionResult`'s docs.
+     *
+     * This `resolve` method can be called multiple times; there is no inner state preventing that.
+     *
+     * Finally, the `resolve` method throws errors only if it is not configured correctly
+     * (has properties of the wrong format or the single-parameter resolvers provided cannot produce a valid `ResolveToType` object `(*)`).
+     * If the resolution failure is caused by the `inputParams` parameters themselves, the cause will be described in the returned logs
+     * and no error will be thrown.
+     *
+     * Note on `(*)`: unfortunately, an error can happen in one more case, namely, if the `inputParams` object
+     * has properties of wrong types (masked with `any`). In this case, the specified validation checks for these parameters
+     * might throw, or, if the checks luckily pass, the schema of `ResolveToType` will cause an error to be thrown in the end.
+     * Thus, be careful with using unsafe type casts or consider verifying `inputParams` with its Ajv JSON schema first.
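+     *
+     * For illustration, a minimal usage sketch; `myParamsResolver` stands for an instance of your
+     * custom resolver class and is an illustrative name only, the accessed properties are the documented ones:
+     * ```
+     * const result = myParamsResolver.resolve(inputParams);
+     * if (result.resolved === undefined) {
+     *     // resolution failed: the causes are kept in the logs
+     *     const causes = result.resolutionLogs
+     *         .filter((paramLog) => paramLog.isInvalidCause !== undefined)
+     *         .map((paramLog) => paramLog.isInvalidCause);
+     * } else {
+     *     // `result.resolved` is a valid `ResolveToType` object
+     * }
+     * ```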
+ */ + resolve(inputParams: InputType): ParamsResolutionResult { + const resolvedParamsObject: { [key: string]: any } = {}; + const resolutionLogs: SingleParamResolutionResult[] = []; + let resolutionFailed = false; + + for (const prop in this) { + if ( + !Object.prototype.hasOwnProperty.call(this, prop) || + prop.startsWith("_") + ) { + continue; + } + const paramResolver = this[prop] as ParamsResolver; + // no generic parametrization check in runtime is possible, unfortunately + if (!isParamsResolver(paramResolver)) { + throw Error( + `\`ParamsResolver\` is configured incorrectly because of \`${prop}\`: all properties should be built up to \`ParamsResolver\` type` + ); + } + const paramResolutionResult = paramResolver.resolve(inputParams); + resolutionLogs.push(...paramResolutionResult.resolutionLogs); + if (paramResolutionResult.resolved === undefined) { + resolutionFailed = true; + } else { + resolvedParamsObject[prop] = paramResolutionResult.resolved; + } + } + + if (resolutionFailed) { + return { + resolutionLogs: resolutionLogs, + }; + } + + const resolvedParams = resolvedParamsObject as ResolveToType; + if (!this._resolveToTypeValidator(resolvedParams)) { + throw Error( + `\`ParamsResolver\` is most likely configured incorrectly. Resulting object could not be interpreted as \`${this._resolveToTypeName}\`: ${ajvErrorsAsString(this._resolveToTypeValidator.errors as DefinedError[])}.` + ); + } + return { + resolved: resolvedParams, + resolutionLogs: resolutionLogs, + }; + } + + _resolverId: "ParamsResolver" = "ParamsResolver"; +} diff --git a/src/llm/llmServices/utils/time.ts b/src/llm/llmServices/utils/time.ts new file mode 100644 index 00000000..f2cb4aef --- /dev/null +++ b/src/llm/llmServices/utils/time.ts @@ -0,0 +1,98 @@ +export function nowTimestampMillis(): number { + return new Date().getTime(); +} + +export type TimeUnit = "millisecond" | "second" | "minute" | "hour" | "day"; + +export interface Time { + millis: number; + seconds: number; + minutes: number; + hours: number; + days: number; +} + +export function millisToTime(totalMillis: number): Time { + const totalSeconds = Math.floor(totalMillis / 1000); + const totalMinutes = Math.floor(totalSeconds / 60); + const totalHours = Math.floor(totalMinutes / 60); + const totalDays = Math.floor(totalHours / 24); + return { + millis: totalMillis % 1000, + seconds: totalSeconds % 60, + minutes: totalMinutes % 60, + hours: totalHours % 24, + days: totalDays, + }; +} + +export function timeToMillis(time: Time): number { + return ( + (((time.hours + time.days * 24) * 60 + time.minutes) * 60 + + time.seconds) * + 1000 + + time.millis + ); +} + +export function time(value: number, unit: TimeUnit): Time { + return millisToTime(timeInUnitsToMillis(value, unit)); +} + +export const timeZero: Time = { + millis: 0, + seconds: 0, + minutes: 0, + hours: 0, + days: 0, +}; + +export function timeToString(time: Time): string { + if (time === timeZero) { + return "0 ms"; + } + const resolvedTime = millisToTime(timeToMillis(time)); + + const days = `${resolvedTime.days} d`; + const hours = `${resolvedTime.hours} h`; + const minutes = `${resolvedTime.minutes} m`; + const seconds = `${resolvedTime.seconds} s`; + const millis = `${resolvedTime.millis} ms`; + + const orderedTimeItems = [ + [resolvedTime.days, days], + [resolvedTime.hours, hours], + [resolvedTime.minutes, minutes], + [resolvedTime.seconds, seconds], + [resolvedTime.millis, millis], + ]; + const fromIndex = orderedTimeItems.findIndex( + ([timeValue, _timeString]) => timeValue !== 0 + ); 
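+    // `fromIndex`/`toIndex` delimit the span from the most significant non-zero
+    // component down to the least significant one; only that span is printed
+    // below (e.g. "1 h, 0 m, 30 s"), so leading and trailing zero units are dropped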
+ const toIndex = + orderedTimeItems.length - + orderedTimeItems + .reverse() + .findIndex(([timeValue, _timeString]) => timeValue !== 0); + + return orderedTimeItems + .reverse() + .slice(fromIndex, toIndex) + .map(([_timeValue, timeString]) => timeString) + .join(", "); +} + +function timeInUnitsToMillis(value: number, unit: TimeUnit = "second"): number { + switch (unit) { + case "millisecond": + return value; + case "second": + return value * 1000; + case "minute": + return value * 1000 * 60; + case "hour": + return value * 1000 * 60 * 60; + case "day": + return value * 1000 * 60 * 60 * 24; + } +} diff --git a/src/llm/userModelParams.ts b/src/llm/userModelParams.ts index dd78e41e..6e7a0b8e 100644 --- a/src/llm/userModelParams.ts +++ b/src/llm/userModelParams.ts @@ -2,54 +2,70 @@ import { JSONSchemaType } from "ajv"; import { PropertiesSchema } from "ajv/dist/types/json-schema"; export interface UserMultiroundProfile { - // cannot be overriden: proof will always be updated no more than `maxRoundsNumber` times + /** + * Cannot be overriden in calls, i.e. + * proof will always be regenerated no more than `maxRoundsNumber` times. + */ maxRoundsNumber?: number; - // can be overriden in the `fixProof` call with the `choices` parameter + /** + * Can be overriden in the `fixProof` call with the `choices` parameter. + */ proofFixChoices?: number; - // use `${diagnostic}` syntax to include a diagnostic message into the prompt + /** + * Use `${diagnostic}` syntax to include a diagnostic message into the prompt. + */ proofFixPrompt?: string; } export interface UserModelParams { - modelName: string; + /** + * Can be any string, but must be unique for each specified model. + * It is used only to distinguish models from each other. + */ + modelId: string; + + /** + * Can be overriden in the generation-method call with the `choices` parameter. + */ choices?: number; systemPrompt?: string; - newMessageMaxTokens?: number; + maxTokensToGenerate?: number; + /** + * Includes tokens that the model generates as an answer message, + * i.e. should be greater than or equal to `maxTokensToGenerate`. + */ tokensLimit?: number; multiroundProfile?: UserMultiroundProfile; } +export interface PredefinedProofsUserModelParams extends UserModelParams { + /** + * List of tactics to try to solve the goal with. + */ + tactics: string[]; +} + export interface OpenAiUserModelParams extends UserModelParams { + modelName: string; temperature: number; apiKey: string; } export interface GrazieUserModelParams extends UserModelParams { + modelName: string; apiKey: string; } -export interface PredefinedProofsUserModelParams extends UserModelParams { - // A list of tactics to try to solve the goal with. 
- tactics: string[]; -} - export interface LMStudioUserModelParams extends UserModelParams { temperature: number; port: number; } -export interface UserModelsParams { - openAiParams: OpenAiUserModelParams[]; - grazieParams: GrazieUserModelParams[]; - predefinedProofsModelParams: PredefinedProofsUserModelParams[]; - lmStudioParams: LMStudioUserModelParams[]; -} - export const userMultiroundProfileSchema: JSONSchemaType = { type: "object", @@ -65,12 +81,12 @@ export const userMultiroundProfileSchema: JSONSchemaType export const userModelParamsSchema: JSONSchemaType = { type: "object", properties: { - modelName: { type: "string" }, + modelId: { type: "string" }, choices: { type: "number", nullable: true }, systemPrompt: { type: "string", nullable: true }, - newMessageMaxTokens: { type: "number", nullable: true }, + maxTokensToGenerate: { type: "number", nullable: true }, tokensLimit: { type: "number", nullable: true }, multiroundProfile: { @@ -79,44 +95,50 @@ export const userModelParamsSchema: JSONSchemaType = { nullable: true, }, }, - required: ["modelName"], + required: ["modelId"], + additionalProperties: false, }; -export const openAiUserModelParamsSchema: JSONSchemaType = +export const predefinedProofsUserModelParamsSchema: JSONSchemaType = { - title: "openAiModelsParameters", + title: "predefinedProofsModelsParameters", type: "object", properties: { - temperature: { type: "number" }, - apiKey: { type: "string" }, + tactics: { + type: "array", + items: { type: "string" }, + }, ...(userModelParamsSchema.properties as PropertiesSchema), }, - required: ["modelName", "temperature", "apiKey"], + required: ["modelId", "tactics"], + additionalProperties: false, }; -export const grazieUserModelParamsSchema: JSONSchemaType = +export const openAiUserModelParamsSchema: JSONSchemaType = { - title: "grazieModelsParameters", + title: "openAiModelsParameters", type: "object", properties: { + modelName: { type: "string" }, + temperature: { type: "number" }, apiKey: { type: "string" }, ...(userModelParamsSchema.properties as PropertiesSchema), }, - required: ["modelName", "apiKey"], + required: ["modelId", "modelName", "temperature", "apiKey"], + additionalProperties: false, }; -export const predefinedProofsUserModelParamsSchema: JSONSchemaType = +export const grazieUserModelParamsSchema: JSONSchemaType = { - title: "predefinedProofsModelsParameters", + title: "grazieModelsParameters", type: "object", properties: { - tactics: { - type: "array", - items: { type: "string" }, - }, + modelName: { type: "string" }, + apiKey: { type: "string" }, ...(userModelParamsSchema.properties as PropertiesSchema), }, - required: ["modelName", "tactics"], + required: ["modelId", "modelName", "apiKey"], + additionalProperties: false, }; export const lmStudioUserModelParamsSchema: JSONSchemaType = @@ -128,5 +150,6 @@ export const lmStudioUserModelParamsSchema: JSONSchemaType), }, - required: ["modelName", "temperature", "port"], + required: ["modelId", "temperature", "port"], + additionalProperties: false, }; diff --git a/src/logging/eventLogger.ts b/src/logging/eventLogger.ts index dccd3025..f9bcf6b8 100644 --- a/src/logging/eventLogger.ts +++ b/src/logging/eventLogger.ts @@ -1,17 +1,26 @@ -/* eslint-disable @typescript-eslint/naming-convention */ export enum Severity { - INFO = "INFO", - DEBUG = "DEBUG", + LOGIC, + INFO, + DEBUG, } -export const ALL_EVENTS = "all"; -/* eslint-enable @typescript-eslint/naming-convention */ +export const anyEventKeyword = "any"; + +export type SubscriptionId = number; + +interface 
EventSubscription { + id: SubscriptionId; + callback: (message: string, data?: any) => void; + severity: Severity; +} export class EventLogger { - events: { - [key: string]: Array<[(message: string, data?: any) => void, Severity]>; + private events: { + [key: string]: Array; }; + private newSubscriptionId: SubscriptionId = 0; + constructor() { this.events = {}; } @@ -20,11 +29,22 @@ export class EventLogger { event: string, severity: Severity, callback: (message: string, data?: any) => void - ): void { + ): SubscriptionId { if (this.events[event] === undefined) { this.events[event] = []; } - this.events[event].push([callback, severity]); + this.events[event].push({ + id: this.newSubscriptionId, + callback, + severity, + }); + return this.newSubscriptionId++; + } + + unsubscribe(event: string, subscriptionId: SubscriptionId) { + this.events[event] = this.events[event]?.filter((eventSubscription) => { + eventSubscription.id !== subscriptionId; + }); } log( @@ -32,17 +52,30 @@ export class EventLogger { message: string, data?: any, severity: Severity = Severity.INFO - ): void { - this.events[event]?.forEach(([callback, subscribedSeverity]) => { - if (subscribedSeverity === severity) { - callback(message, data); + ) { + this.events[event]?.forEach((eventSubscription) => { + if (eventSubscription.severity === severity) { + eventSubscription.callback(message, data); } }); - this.events[ALL_EVENTS]?.forEach(([callback, subscribedSeverity]) => { - if (subscribedSeverity === severity) { - callback(message, data); + this.events[anyEventKeyword]?.forEach((eventSubscription) => { + if (eventSubscription.severity === severity) { + eventSubscription.callback(message, data); } }); } + + subscribeToLogicEvent( + event: string, + callback: (data?: any) => void + ): SubscriptionId { + return this.subscribe(event, Severity.LOGIC, (_message, data) => + callback(data) + ); + } + + logLogicEvent(event: string, data?: any) { + this.log(event, "", data, Severity.LOGIC); + } } diff --git a/src/test/benchmark/benchmarkingFramework.ts b/src/test/benchmark/benchmarkingFramework.ts index 630a5c2e..ed11a9b2 100644 --- a/src/test/benchmark/benchmarkingFramework.ts +++ b/src/test/benchmark/benchmarkingFramework.ts @@ -3,9 +3,9 @@ import * as assert from "assert"; import { LLMServices } from "../../llm/llmServices"; import { GrazieService } from "../../llm/llmServices/grazie/grazieService"; import { LMStudioService } from "../../llm/llmServices/lmStudio/lmStudioService"; +import { ModelsParams } from "../../llm/llmServices/modelParams"; import { OpenAiService } from "../../llm/llmServices/openai/openAiService"; import { PredefinedProofsService } from "../../llm/llmServices/predefinedProofs/predefinedProofsService"; -import { UserModelsParams } from "../../llm/userModelParams"; import { CoqLspClient } from "../../coqLsp/coqLspClient"; import { CoqLspConfig } from "../../coqLsp/coqLspConfig"; @@ -25,12 +25,14 @@ import { createSourceFileEnvironment } from "../../core/inspectSourceFile"; import { ProofStep, Theorem } from "../../coqParser/parsedTypes"; import { Uri } from "../../utils/uri"; +import { resolveParametersOrThrow } from "../commonTestFunctions/resolveOrThrow"; -import { consoleLog, consoleLogLine } from "./loggingUtils"; +import { InputModelsParams } from "./inputModelsParams"; +import { consoleLog, consoleLogSeparatorLine } from "./loggingUtils"; export async function runTestBenchmark( filePath: string, - modelsParams: UserModelsParams, + inputModelsParams: InputModelsParams, specificTheoremsForBenchmark: 
string[] | undefined, benchmarkFullTheorems: Boolean = true, benchmarkAdmits: Boolean = true, @@ -42,7 +44,7 @@ export async function runTestBenchmark( const [completionTargets, sourceFileEnvironment, processEnvironment] = await prepareForBenchmarkCompletions( - modelsParams, + inputModelsParams, shouldCompleteHole, workspaceRootPath, filePath @@ -62,22 +64,22 @@ export async function runTestBenchmark( ), }; - consoleLogLine("\n"); + consoleLogSeparatorLine("\n"); let admitTargetsResults: BenchmarkResult | undefined = undefined; let theoremTargetsResults: BenchmarkResult | undefined = undefined; if (benchmarkAdmits) { - console.log("try to complete admits\n"); + consoleLog("try to complete admits\n"); admitTargetsResults = await benchmarkTargets( filteredCompletionTargets.admitTargets, sourceFileEnvironment, processEnvironment ); - console.log( + consoleLog( `BENCHMARK RESULT, ADMITS COMPLETED: ${admitTargetsResults}\n` ); - consoleLogLine("\n"); + consoleLogSeparatorLine("\n"); if (requireAllAdmitsCompleted) { assert.ok(admitTargetsResults.allCompleted()); @@ -85,16 +87,16 @@ export async function runTestBenchmark( } if (benchmarkFullTheorems) { - console.log("try to prove theorems\n"); + consoleLog("try to prove theorems\n"); theoremTargetsResults = await benchmarkTargets( filteredCompletionTargets.theoremTargets, sourceFileEnvironment, processEnvironment ); - console.log( + consoleLog( `BENCHMARK RESULT, THEOREMS PROVED: ${theoremTargetsResults}\n` ); - consoleLogLine(); + consoleLogSeparatorLine(); } return { @@ -166,11 +168,11 @@ async function benchmarkCompletionGeneration( processEnvironment: ProcessEnvironment ): Promise { const completionPosition = completionContext.admitEndPosition; - console.log( + consoleLog( `Completion position: ${completionPosition.line}:${completionPosition.character}` ); - console.log(`Theorem name: \`${completionContext.parentTheorem.name}\``); - console.log(`Proof goal: \`${goalToString(completionContext.proofGoal)}\``); + consoleLog(`Theorem name: \`${completionContext.parentTheorem.name}\``); + consoleLog(`Proof goal: \`${goalToString(completionContext.proofGoal)}\``); const sourceFileEnvironmentWithFilteredContext: SourceFileEnvironment = { ...sourceFileEnvironment, @@ -191,19 +193,19 @@ async function benchmarkCompletionGeneration( success = true; } else if (result instanceof FailureGenerationResult) { switch (result.status) { - case FailureGenerationStatus.excededTimeout: + case FailureGenerationStatus.TIMEOUT_EXCEEDED: message = "Timeout"; break; - case FailureGenerationStatus.exception: + case FailureGenerationStatus.ERROR_OCCURRED: message = `Exception: ${result.message}`; break; - case FailureGenerationStatus.searchFailed: + case FailureGenerationStatus.SEARCH_FAILED: message = "Proofs not found"; break; } } consoleLog(message, success ? 
"green" : "red"); - console.log(""); + consoleLog(""); return success; } @@ -212,7 +214,7 @@ function goalToString(proofGoal: Goal): string { } async function prepareForBenchmarkCompletions( - modelsParams: UserModelsParams, + inputModelsParams: InputModelsParams, shouldCompleteHole: (hole: ProofStep) => boolean, workspaceRootPath: string | undefined, filePath: string @@ -241,7 +243,10 @@ async function prepareForBenchmarkCompletions( }; const processEnvironment: ProcessEnvironment = { coqProofChecker: coqProofChecker, - modelsParams: modelsParams, + modelsParams: resolveInputModelsParametersOrThrow( + inputModelsParams, + llmServices + ), services: llmServices, }; @@ -358,3 +363,27 @@ async function resolveProofStepsToCompletionContexts( } return completionContexts; } + +function resolveInputModelsParametersOrThrow( + inputModelsParams: InputModelsParams, + llmServices: LLMServices +): ModelsParams { + return { + predefinedProofsModelParams: + inputModelsParams.predefinedProofsModelParams.map((inputParams) => + resolveParametersOrThrow( + llmServices.predefinedProofsService, + inputParams + ) + ), + openAiParams: inputModelsParams.openAiParams.map((inputParams) => + resolveParametersOrThrow(llmServices.openAiService, inputParams) + ), + grazieParams: inputModelsParams.grazieParams.map((inputParams) => + resolveParametersOrThrow(llmServices.grazieService, inputParams) + ), + lmStudioParams: inputModelsParams.lmStudioParams.map((inputParams) => + resolveParametersOrThrow(llmServices.lmStudioService, inputParams) + ), + }; +} diff --git a/src/test/benchmark/inputModelsParams.ts b/src/test/benchmark/inputModelsParams.ts new file mode 100644 index 00000000..9f6e5973 --- /dev/null +++ b/src/test/benchmark/inputModelsParams.ts @@ -0,0 +1,25 @@ +import { + GrazieUserModelParams, + LMStudioUserModelParams, + OpenAiUserModelParams, + PredefinedProofsUserModelParams, +} from "../../llm/userModelParams"; + +export interface InputModelsParams { + predefinedProofsModelParams: PredefinedProofsUserModelParams[]; + openAiParams: OpenAiUserModelParams[]; + grazieParams: GrazieUserModelParams[]; + lmStudioParams: LMStudioUserModelParams[]; +} + +export const onlyAutoModelsParams: InputModelsParams = { + openAiParams: [], + grazieParams: [], + predefinedProofsModelParams: [ + { + modelId: "Predefined `auto`", + tactics: ["auto."], + }, + ], + lmStudioParams: [], +}; diff --git a/src/test/benchmark/loggingUtils.ts b/src/test/benchmark/loggingUtils.ts index 91b28fa7..dfd2a1b2 100644 --- a/src/test/benchmark/loggingUtils.ts +++ b/src/test/benchmark/loggingUtils.ts @@ -1,9 +1,14 @@ +export const consoleLoggingIsMuted = true; + export type LogColor = "red" | "green" | "blue" | "magenta" | "reset"; export function consoleLog( message: string, color: LogColor | undefined = undefined ) { + if (consoleLoggingIsMuted) { + return; + } if (!color) { console.log(message); return; @@ -32,6 +37,6 @@ export function code(color: LogColor): string { throw Error(`unknown LogColor: ${color}`); } -export function consoleLogLine(suffix: string = "") { - console.log(`----------------------------${suffix}`); +export function consoleLogSeparatorLine(suffix: string = "") { + consoleLog(`----------------------------${suffix}`); } diff --git a/src/test/benchmark/presets.ts b/src/test/benchmark/presets.ts deleted file mode 100644 index 6d7bd1bc..00000000 --- a/src/test/benchmark/presets.ts +++ /dev/null @@ -1,21 +0,0 @@ -import { UserModelsParams } from "../../llm/userModelParams"; - -export const onlyAutoModelsParams: UserModelsParams 
= { - openAiParams: [], - grazieParams: [], - predefinedProofsModelParams: [ - { - modelName: "Predefined `auto`", - choices: undefined, - - systemPrompt: undefined, - - newMessageMaxTokens: undefined, - tokensLimit: undefined, - - multiroundProfile: undefined, - tactics: ["auto."], - }, - ], - lmStudioParams: [], -}; diff --git a/src/test/benchmark/runBenchmark.test.ts b/src/test/benchmark/runBenchmark.test.ts index f5b29e6a..0f478f1a 100644 --- a/src/test/benchmark/runBenchmark.test.ts +++ b/src/test/benchmark/runBenchmark.test.ts @@ -1,16 +1,20 @@ +import { expect } from "earl"; import * as fs from "fs"; import * as path from "path"; -import { UserModelsParams } from "../../llm/userModelParams"; - import { BenchmarkResult, runTestBenchmark } from "./benchmarkingFramework"; -import { code, consoleLogLine } from "./loggingUtils"; -import { onlyAutoModelsParams } from "./presets"; +import { InputModelsParams, onlyAutoModelsParams } from "./inputModelsParams"; +import { + code, + consoleLog, + consoleLogSeparatorLine, + consoleLoggingIsMuted, +} from "./loggingUtils"; interface Benchmark { name: string; items: DatasetItem[]; - modelsParams: UserModelsParams; + inputModelsParams: InputModelsParams; requireAllAdmitsCompleted: Boolean; benchmarkFullTheorems: Boolean; benchmarkAdmits: Boolean; @@ -35,7 +39,7 @@ class DatasetItem { const simpleAutoBenchmark: Benchmark = { name: "Complete simple examples with `auto`", items: [new DatasetItem("auto_benchmark.v")], - modelsParams: onlyAutoModelsParams, + inputModelsParams: onlyAutoModelsParams, requireAllAdmitsCompleted: true, benchmarkFullTheorems: true, benchmarkAdmits: true, @@ -45,7 +49,7 @@ const simpleAutoBenchmark: Benchmark = { const mixedAutoBenchmark: Benchmark = { name: "Complete mixed examples (both simple & hard) with `auto`", items: [new DatasetItem("mixed_benchmark.v")], - modelsParams: onlyAutoModelsParams, + inputModelsParams: onlyAutoModelsParams, requireAllAdmitsCompleted: false, benchmarkFullTheorems: true, benchmarkAdmits: true, @@ -55,6 +59,7 @@ const mixedAutoBenchmark: Benchmark = { const benchmarks: Benchmark[] = [simpleAutoBenchmark, mixedAutoBenchmark]; suite("Benchmark", () => { + expect(consoleLoggingIsMuted).toEqual(true); const datasetDir = getDatasetDir(); for (const benchmark of benchmarks) { @@ -80,7 +85,7 @@ suite("Benchmark", () => { const { admitsCompleted, theoremsProved } = await runTestBenchmark( resolvedFilePath, - benchmark.modelsParams, + benchmark.inputModelsParams, item.specificTheoremForBenchmark, benchmark.benchmarkFullTheorems, benchmark.benchmarkAdmits, @@ -96,15 +101,15 @@ suite("Benchmark", () => { } } - consoleLogLine(); - consoleLogLine("\n"); - console.log( + consoleLogSeparatorLine(); + consoleLogSeparatorLine("\n"); + consoleLog( `${code("magenta")}BENCHMARK REPORT:${code("reset")} ${benchmark.name}` ); - console.log( + consoleLog( `- ADMITS COMPLETED IN TOTAL: ${admitsCompletedInTotal}` ); - console.log( + consoleLog( `- THEOREMS PROVED IN TOTAL: ${theoremsProvedInTotal}\n` ); }).timeout(benchmark.timeoutMinutes * 60 * 1000); diff --git a/src/test/commonTestFunctions/checkProofs.ts b/src/test/commonTestFunctions/checkProofs.ts new file mode 100644 index 00000000..48afe61d --- /dev/null +++ b/src/test/commonTestFunctions/checkProofs.ts @@ -0,0 +1,46 @@ +import { GeneratedProof } from "../../llm/llmServices/llmService"; + +import { + CompletionContext, + getTextBeforePosition, + prepareProofToCheck, +} from "../../core/completionGenerator"; +import { ProofCheckResult } from 
"../../core/coqProofChecker"; + +import { PreparedEnvironment } from "./prepareEnvironment"; + +export async function checkProofs( + proofsToCheck: string[], + completionContext: CompletionContext, + environment: PreparedEnvironment +): Promise { + const sourceFileContentPrefix = getTextBeforePosition( + environment.sourceFileEnvironment.fileLines, + completionContext.prefixEndPosition + ); + return await environment.coqProofChecker.checkProofs( + environment.sourceFileEnvironment.dirPath, + sourceFileContentPrefix, + completionContext.prefixEndPosition, + proofsToCheck + ); +} + +export async function checkTheoremProven( + generatedProofs: GeneratedProof[], + completionContext: CompletionContext, + environment: PreparedEnvironment +) { + const proofsToCheck = generatedProofs.map((generatedProof) => + prepareProofToCheck(generatedProof.proof()) + ); + const checkResults = await checkProofs( + proofsToCheck, + completionContext, + environment + ); + const validProofs = checkResults.filter( + (checkResult) => checkResult.isValid + ).length; + return validProofs >= 1; +} diff --git a/src/test/commonTestFunctions/colorPrinter.ts b/src/test/commonTestFunctions/colorPrinter.ts new file mode 100644 index 00000000..4c506ed7 --- /dev/null +++ b/src/test/commonTestFunctions/colorPrinter.ts @@ -0,0 +1,26 @@ +export type Color = "red" | "green" | "yellow" | "blue" | "magenta" | "reset"; + +export function color(text: string, color: Color): string { + return `${code(color)}${text}${code("reset")}`; +} +function code(color: Color): string { + if (color === "reset") { + return "\x1b[0m"; + } + if (color === "red") { + return "\x1b[31m"; + } + if (color === "green") { + return "\x1b[32m"; + } + if (color === "yellow") { + return "\x1b[33m"; + } + if (color === "blue") { + return "\x1b[34m"; + } + if (color === "magenta") { + return "\x1b[35m"; + } + throw Error(`unknown Color: ${color}`); +} diff --git a/src/test/commonTestFunctions/conditionalTest.ts b/src/test/commonTestFunctions/conditionalTest.ts new file mode 100644 index 00000000..77a6e435 --- /dev/null +++ b/src/test/commonTestFunctions/conditionalTest.ts @@ -0,0 +1,17 @@ +import { color } from "./colorPrinter"; + +export function testIf( + condition: boolean, + testWillBeSkippedCause: string, + suiteName: string, + testName: string, + func: Mocha.Func +): Mocha.Test | undefined { + if (condition) { + return test(testName, func); + } + console.warn( + `${color("WARNING", "yellow")}: test will be skipped: \"${suiteName}\" # \"${testName}\"\n\t> cause: ${testWillBeSkippedCause}` + ); + return undefined; +} diff --git a/src/test/commonTestFunctions/coqFileParser.ts b/src/test/commonTestFunctions/coqFileParser.ts new file mode 100644 index 00000000..188abdf6 --- /dev/null +++ b/src/test/commonTestFunctions/coqFileParser.ts @@ -0,0 +1,25 @@ +import { parseCoqFile } from "../../coqParser/parseCoqFile"; +import { Theorem } from "../../coqParser/parsedTypes"; +import { Uri } from "../../utils/uri"; + +import { createCoqLspClient } from "./coqLspBuilder"; +import { resolveResourcesDir } from "./pathsResolver"; + +export async function parseTheoremsFromCoqFile( + resourcePath: string[], + projectRootPath?: string[] +): Promise { + const [filePath, rootDir] = resolveResourcesDir( + resourcePath, + projectRootPath + ); + + const fileUri = Uri.fromPath(filePath); + const client = createCoqLspClient(rootDir); + + await client.openTextDocument(fileUri); + const document = await parseCoqFile(fileUri, client); + await client.closeTextDocument(fileUri); + + return 
document; +} diff --git a/src/test/commonTestFunctions.ts b/src/test/commonTestFunctions/coqLspBuilder.ts similarity index 54% rename from src/test/commonTestFunctions.ts rename to src/test/commonTestFunctions/coqLspBuilder.ts index 00c10ceb..97954bb9 100644 --- a/src/test/commonTestFunctions.ts +++ b/src/test/commonTestFunctions/coqLspBuilder.ts @@ -1,12 +1,5 @@ -import * as path from "path"; - -import { CoqLspClient } from "../coqLsp/coqLspClient"; -import { CoqLspConfig } from "../coqLsp/coqLspConfig"; - -export function getResourceFolder() { - const dirname = path.dirname(path.dirname(__dirname)); - return path.join(dirname, "src", "test", "resources"); -} +import { CoqLspClient } from "../../coqLsp/coqLspClient"; +import { CoqLspConfig } from "../../coqLsp/coqLspConfig"; export function createCoqLspClient(workspaceRootPath?: string): CoqLspClient { const coqLspServerConfig = CoqLspConfig.createServerConfig(); diff --git a/src/test/commonTestFunctions/defaultLLMServicesBuilder.ts b/src/test/commonTestFunctions/defaultLLMServicesBuilder.ts new file mode 100644 index 00000000..a7e79d84 --- /dev/null +++ b/src/test/commonTestFunctions/defaultLLMServicesBuilder.ts @@ -0,0 +1,71 @@ +import { LLMServices } from "../../llm/llmServices"; +import { GrazieService } from "../../llm/llmServices/grazie/grazieService"; +import { LMStudioService } from "../../llm/llmServices/lmStudio/lmStudioService"; +import { + ModelsParams, + PredefinedProofsModelParams, +} from "../../llm/llmServices/modelParams"; +import { OpenAiService } from "../../llm/llmServices/openai/openAiService"; +import { PredefinedProofsModelParamsResolver } from "../../llm/llmServices/predefinedProofs/predefinedProofsModelParamsResolver"; +import { PredefinedProofsService } from "../../llm/llmServices/predefinedProofs/predefinedProofsService"; +import { PredefinedProofsUserModelParams } from "../../llm/userModelParams"; + +import { resolveOrThrow } from "./resolveOrThrow"; + +export function createDefaultServices(): LLMServices { + const predefinedProofsService = new PredefinedProofsService(); + const openAiService = new OpenAiService(); + const grazieService = new GrazieService(); + const lmStudioService = new LMStudioService(); + return { + predefinedProofsService, + openAiService, + grazieService, + lmStudioService, + }; +} + +export function createTrivialModelsParams( + predefinedProofsModelParams: PredefinedProofsModelParams[] = [] +): ModelsParams { + return { + predefinedProofsModelParams: predefinedProofsModelParams, + openAiParams: [], + grazieParams: [], + lmStudioParams: [], + }; +} + +export function createPredefinedProofsModel( + modelId: string = "predefined-proofs", + predefinedProofs: string[] = [ + "intros.", + "reflexivity.", + "auto.", + "assumption. intros.", + "left. reflexivity.", + ] +): PredefinedProofsModelParams { + const inputModelParams: PredefinedProofsUserModelParams = { + modelId: modelId, + tactics: predefinedProofs, + }; + return resolveOrThrow( + new PredefinedProofsModelParamsResolver(), + inputModelParams + ); +} + +export function createPredefinedProofsModelsParams( + predefinedProofs: string[] = [ + "intros.", + "reflexivity.", + "auto.", + "assumption. intros.", + "left. 
reflexivity.", + ] +): ModelsParams { + return createTrivialModelsParams([ + createPredefinedProofsModel("predefined-proofs", predefinedProofs), + ]); +} diff --git a/src/test/commonTestFunctions/delay.ts b/src/test/commonTestFunctions/delay.ts new file mode 100644 index 00000000..a330e8a3 --- /dev/null +++ b/src/test/commonTestFunctions/delay.ts @@ -0,0 +1,3 @@ +export async function delay(millis: number) { + return new Promise((resolve) => setTimeout(resolve, millis)); +} diff --git a/src/test/commonTestFunctions/pathsResolver.ts b/src/test/commonTestFunctions/pathsResolver.ts new file mode 100644 index 00000000..ba777a85 --- /dev/null +++ b/src/test/commonTestFunctions/pathsResolver.ts @@ -0,0 +1,19 @@ +import * as path from "path"; + +export function getRootDir(): string { + const relativeRoot = path.join(__dirname, "/../../.."); + return path.resolve(relativeRoot); +} + +export function getResourcesDir(): string { + return path.join(getRootDir(), "src", "test", "resources"); +} + +export function resolveResourcesDir( + resourcePath: string[], + projectRootPath?: string[] +): [string, string] { + const filePath = path.join(getResourcesDir(), ...resourcePath); + const rootDir = path.join(getResourcesDir(), ...(projectRootPath ?? [])); + return [filePath, rootDir]; +} diff --git a/src/test/commonTestFunctions/prepareEnvironment.ts b/src/test/commonTestFunctions/prepareEnvironment.ts new file mode 100644 index 00000000..f245a901 --- /dev/null +++ b/src/test/commonTestFunctions/prepareEnvironment.ts @@ -0,0 +1,74 @@ +import { ProofGenerationContext } from "../../llm/proofGenerationContext"; + +import { CoqLspClient } from "../../coqLsp/coqLspClient"; + +import { + CompletionContext, + SourceFileEnvironment, + buildProofGenerationContext, +} from "../../core/completionGenerator"; +import { CoqProofChecker } from "../../core/coqProofChecker"; +import { inspectSourceFile } from "../../core/inspectSourceFile"; + +import { Uri } from "../../utils/uri"; + +import { createCoqLspClient } from "./coqLspBuilder"; +import { resolveResourcesDir } from "./pathsResolver"; + +export interface PreparedEnvironment { + coqLspClient: CoqLspClient; + coqProofChecker: CoqProofChecker; + completionContexts: CompletionContext[]; + sourceFileEnvironment: SourceFileEnvironment; +} +/** + * Note: both paths should be relative to `src/test/resources/` folder. 
+ */ +export async function prepareEnvironment( + resourcePath: string[], + projectRootPath?: string[] +): Promise { + const [filePath, rootDir] = resolveResourcesDir( + resourcePath, + projectRootPath + ); + const fileUri = Uri.fromPath(filePath); + + const client = createCoqLspClient(rootDir); + const coqProofChecker = new CoqProofChecker(client); + + await client.openTextDocument(fileUri); + const [completionContexts, sourceFileEnvironment] = await inspectSourceFile( + 1, + (_hole) => true, + fileUri, + client + ); + await client.closeTextDocument(fileUri); + + return { + coqLspClient: client, + coqProofChecker: coqProofChecker, + completionContexts: completionContexts, + sourceFileEnvironment: sourceFileEnvironment, + }; +} + +export async function prepareEnvironmentWithContexts( + resourcePath: string[], + projectRootPath?: string[] +): Promise< + [PreparedEnvironment, [CompletionContext, ProofGenerationContext][]] +> { + const environment = await prepareEnvironment(resourcePath, projectRootPath); + return [ + environment, + environment.completionContexts.map((completionContext) => [ + completionContext, + buildProofGenerationContext( + completionContext, + environment.sourceFileEnvironment.fileTheorems + ), + ]), + ]; +} diff --git a/src/test/commonTestFunctions/resolveOrThrow.ts b/src/test/commonTestFunctions/resolveOrThrow.ts new file mode 100644 index 00000000..48365db4 --- /dev/null +++ b/src/test/commonTestFunctions/resolveOrThrow.ts @@ -0,0 +1,43 @@ +import { ConfigurationError } from "../../llm/llmServiceErrors"; +import { LLMService } from "../../llm/llmServices/llmService"; +import { ModelParams } from "../../llm/llmServices/modelParams"; +import { ParamsResolutionResult } from "../../llm/llmServices/utils/paramsResolvers/abstractResolvers"; +import { ParamsResolverImpl } from "../../llm/llmServices/utils/paramsResolvers/paramsResolverImpl"; +import { UserModelParams } from "../../llm/userModelParams"; + +import { stringifyAnyValue } from "../../utils/printers"; + +export function resolveOrThrow( + paramsResolver: ParamsResolverImpl, + inputParams: InputType +): ResolveToType { + const resolutionResult = paramsResolver.resolve(inputParams); + return unpackResolvedParamsOrThrow(resolutionResult, inputParams); +} + +export function resolveParametersOrThrow< + InputModelParams extends UserModelParams, + ResolvedModelParams extends ModelParams, +>( + llmService: LLMService, + inputParams: InputModelParams +): ResolvedModelParams { + const resolutionResult = llmService.resolveParameters(inputParams); + return unpackResolvedParamsOrThrow(resolutionResult, inputParams); +} + +function unpackResolvedParamsOrThrow( + resolutionResult: ParamsResolutionResult, + inputParams: InputType +): ResolveToType { + if (resolutionResult.resolved !== undefined) { + return resolutionResult.resolved; + } + const joinedErrorLogs = resolutionResult.resolutionLogs + .filter((paramLog) => paramLog.isInvalidCause !== undefined) + .map((paramLog) => paramLog.isInvalidCause) + .join("; "); + throw new ConfigurationError( + `parameters ${stringifyAnyValue(inputParams)} could not be resolved: ${joinedErrorLogs}` + ); +} diff --git a/src/test/commonTestFunctions/withLLMService.ts b/src/test/commonTestFunctions/withLLMService.ts new file mode 100644 index 00000000..a041c027 --- /dev/null +++ b/src/test/commonTestFunctions/withLLMService.ts @@ -0,0 +1,45 @@ +import { LLMService } from "../../llm/llmServices/llmService"; +import { ModelParams } from "../../llm/llmServices/modelParams"; +import { UserModelParams 
} from "../../llm/userModelParams"; + +import { resolveParametersOrThrow } from "./resolveOrThrow"; + +export async function withLLMService< + InputModelParams extends UserModelParams, + ResolvedModelParams extends ModelParams, + LLMServiceType extends LLMService, + T, +>( + llmService: LLMServiceType, + block: (llmService: LLMServiceType) => Promise +): Promise { + try { + return await block(llmService); + } finally { + llmService.dispose(); + } +} + +export async function withLLMServiceAndParams< + InputModelParams extends UserModelParams, + ResolvedModelParams extends ModelParams, + LLMServiceType extends LLMService, + T, +>( + llmService: LLMServiceType, + inputParams: InputModelParams, + block: ( + llmService: LLMServiceType, + resolvedParams: ResolvedModelParams + ) => Promise +): Promise { + try { + const resolvedParams = resolveParametersOrThrow( + llmService, + inputParams + ); + return await block(llmService, resolvedParams); + } finally { + llmService.dispose(); + } +} diff --git a/src/test/coqLsp/coqLspGetGoals.test.ts b/src/test/coqLsp/coqLspGetGoals.test.ts index 207f35cb..cbf1bb3c 100644 --- a/src/test/coqLsp/coqLspGetGoals.test.ts +++ b/src/test/coqLsp/coqLspGetGoals.test.ts @@ -1,10 +1,10 @@ import { expect } from "earl"; -import * as path from "path"; import { Goal, PpString } from "../../coqLsp/coqLspTypes"; import { Uri } from "../../utils/uri"; -import { createCoqLspClient, getResourceFolder } from "../commonTestFunctions"; +import { createCoqLspClient } from "../commonTestFunctions/coqLspBuilder"; +import { resolveResourcesDir } from "../commonTestFunctions/pathsResolver"; suite("Retrieve goals from Coq file", () => { async function getGoalsAtPoints( @@ -12,15 +12,13 @@ suite("Retrieve goals from Coq file", () => { resourcePath: string[], projectRootPath?: string[] ): Promise<(Goal | Error)[]> { - const filePath = path.join(getResourceFolder(), ...resourcePath); - const rootDir = path.join( - getResourceFolder(), - ...(projectRootPath ?? []) + const [filePath, rootDir] = resolveResourcesDir( + resourcePath, + projectRootPath ); - const fileUri = Uri.fromPath(filePath); - const client = createCoqLspClient(rootDir); + const client = createCoqLspClient(rootDir); await client.openTextDocument(fileUri); const goals = await Promise.all( points.map(async (point) => { @@ -116,7 +114,7 @@ suite("Retrieve goals from Coq file", () => { } }); - test("Retreive goal in project with imports --non-ci", async () => { + test("Retreive goal in project with imports", async () => { const goals = await getGoalsAtPoints( [ { line: 4, character: 4 }, diff --git a/src/test/coqParser/parseCoqFile.test.ts b/src/test/coqParser/parseCoqFile.test.ts index 53e737dc..dc4701c2 100644 --- a/src/test/coqParser/parseCoqFile.test.ts +++ b/src/test/coqParser/parseCoqFile.test.ts @@ -1,34 +1,10 @@ import { expect } from "earl"; -import * as path from "path"; -import { parseCoqFile } from "../../coqParser/parseCoqFile"; -import { Theorem } from "../../coqParser/parsedTypes"; -import { Uri } from "../../utils/uri"; -import { createCoqLspClient, getResourceFolder } from "../commonTestFunctions"; +import { parseTheoremsFromCoqFile } from "../commonTestFunctions/coqFileParser"; suite("Coq file parser tests", () => { - async function getCoqDocument( - resourcePath: string[], - projectRootPath?: string[] - ): Promise { - const filePath = path.join(getResourceFolder(), ...resourcePath); - const rootDir = path.join( - getResourceFolder(), - ...(projectRootPath ?? 
[]) - ); - - const fileUri = Uri.fromPath(filePath); - const client = createCoqLspClient(rootDir); - - await client.openTextDocument(fileUri); - const document = await parseCoqFile(fileUri, client); - await client.closeTextDocument(fileUri); - - return document; - } - test("Parse simple small document", async () => { - const doc = await getCoqDocument(["small_document.v"]); + const doc = await parseTheoremsFromCoqFile(["small_document.v"]); const theoremData = [ { @@ -75,7 +51,7 @@ suite("Coq file parser tests", () => { }); test("Retreive Multiple nested holes", async () => { - const doc = await getCoqDocument(["test_many_admits.v"]); + const doc = await parseTheoremsFromCoqFile(["test_many_admits.v"]); const expectedHoleRanges = [ { @@ -113,7 +89,7 @@ suite("Coq file parser tests", () => { }); test("Test different theorem declarations", async () => { - const doc = await getCoqDocument(["test_parse_proof.v"]); + const doc = await parseTheoremsFromCoqFile(["test_parse_proof.v"]); const theoremData = [ "test_1", @@ -133,8 +109,8 @@ suite("Coq file parser tests", () => { expect(theoremsWithoutProof[0].name).toEqual("test_5"); }); - test("Test parse file which is part of project --non-ci", async () => { - const doc = await getCoqDocument( + test("Test parse file which is part of project", async () => { + const doc = await parseTheoremsFromCoqFile( ["coqProj", "theories", "B.v"], ["coqProj"] ); diff --git a/src/test/core/completionGenerator.test.ts b/src/test/core/completionGenerator.test.ts index 4f4d86c2..89323fc6 100644 --- a/src/test/core/completionGenerator.test.ts +++ b/src/test/core/completionGenerator.test.ts @@ -1,10 +1,6 @@ import { expect } from "earl"; -import * as path from "path"; -import { GrazieService } from "../../llm/llmServices/grazie/grazieService"; -import { LMStudioService } from "../../llm/llmServices/lmStudio/lmStudioService"; -import { OpenAiService } from "../../llm/llmServices/openai/openAiService"; -import { PredefinedProofsService } from "../../llm/llmServices/predefinedProofs/predefinedProofsService"; +import { disposeServices } from "../../llm/llmServices"; import { FailureGenerationResult, @@ -14,11 +10,12 @@ import { } from "../../core/completionGenerator"; import { ProcessEnvironment } from "../../core/completionGenerator"; import { SuccessGenerationResult } from "../../core/completionGenerator"; -import { CoqProofChecker } from "../../core/coqProofChecker"; -import { inspectSourceFile } from "../../core/inspectSourceFile"; -import { Uri } from "../../utils/uri"; -import { createCoqLspClient, getResourceFolder } from "../commonTestFunctions"; +import { + createDefaultServices, + createPredefinedProofsModelsParams, +} from "../commonTestFunctions/defaultLLMServicesBuilder"; +import { prepareEnvironment } from "../commonTestFunctions/prepareEnvironment"; suite("Completion generation tests", () => { async function generateCompletionForAdmitsFromFile( @@ -26,57 +23,31 @@ suite("Completion generation tests", () => { predefinedProofs: string[], projectRootPath?: string[] ): Promise { - const filePath = path.join(getResourceFolder(), ...resourcePath); - const rootDir = path.join( - getResourceFolder(), - ...(projectRootPath ?? 
[]) + const environment = await prepareEnvironment( + resourcePath, + projectRootPath ); - - const fileUri = Uri.fromPath(filePath); - const client = createCoqLspClient(rootDir); - const coqProofChecker = new CoqProofChecker(client); - await client.openTextDocument(fileUri); - const [completionContexts, sourceFileEnvironment] = - await inspectSourceFile(1, (_hole) => true, fileUri, client); - await client.closeTextDocument(fileUri); - - const openAiService = new OpenAiService(); - const grazieService = new GrazieService(); - const predefinedProofsService = new PredefinedProofsService(); - const lmStudioService = new LMStudioService(); - const processEnvironment: ProcessEnvironment = { - coqProofChecker: coqProofChecker, - modelsParams: { - openAiParams: [], - grazieParams: [], - predefinedProofsModelParams: [ - { - modelName: "Doesn't matter", - tactics: predefinedProofs, - }, - ], - lmStudioParams: [], - }, - services: { - openAiService, - grazieService, - predefinedProofsService, - lmStudioService, - }, + coqProofChecker: environment.coqProofChecker, + modelsParams: createPredefinedProofsModelsParams(predefinedProofs), + services: createDefaultServices(), }; - - return Promise.all( - completionContexts.map(async (completionContext) => { - const result = await generateCompletion( - completionContext, - sourceFileEnvironment, - processEnvironment - ); - - return result; - }) - ); + try { + return await Promise.all( + environment.completionContexts.map( + async (completionContext) => { + const result = await generateCompletion( + completionContext, + environment.sourceFileEnvironment, + processEnvironment + ); + return result; + } + ) + ); + } finally { + disposeServices(processEnvironment.services); + } } function unpackProof(text: string): string { @@ -142,11 +113,11 @@ suite("Completion generation tests", () => { ); expect(results[1]).toBeA(FailureGenerationResult); expect((results[1] as FailureGenerationResult).status).toEqual( - FailureGenerationStatus.searchFailed + FailureGenerationStatus.SEARCH_FAILED ); }).timeout(2000); - test("Check generation in project --non-ci", async () => { + test("Check generation in project", async () => { const resourcePath = ["coqProj", "theories", "C.v"]; const predefinedProofs = ["intros.", "auto."]; const projectRootPath = ["coqProj"]; diff --git a/src/test/core/coqProofChecker.test.ts b/src/test/core/coqProofChecker.test.ts index da58c92e..1d9caa42 100644 --- a/src/test/core/coqProofChecker.test.ts +++ b/src/test/core/coqProofChecker.test.ts @@ -5,21 +5,21 @@ import * as path from "path"; import { CoqProofChecker } from "../../core/coqProofChecker"; import { ProofCheckResult } from "../../core/coqProofChecker"; -import { createCoqLspClient, getResourceFolder } from "../commonTestFunctions"; +import { createCoqLspClient } from "../commonTestFunctions/coqLspBuilder"; +import { resolveResourcesDir } from "../commonTestFunctions/pathsResolver"; -suite("Coq Proof Checker tests", () => { +suite("`CoqProofChecker` tests", () => { async function checkProofsForAdmitsFromFile( resourcePath: string[], positions: { line: number; character: number }[], proofsToCheck: string[][], projectRootPath?: string[] ): Promise { - const filePath = path.join(getResourceFolder(), ...resourcePath); - const fileDir = path.dirname(filePath); - const rootDir = path.join( - getResourceFolder(), - ...(projectRootPath ?? 
[]) + const [filePath, rootDir] = resolveResourcesDir( + resourcePath, + projectRootPath ); + const fileDir = path.dirname(filePath); const fileLines = readFileSync(filePath).toString().split("\n"); const client = createCoqLspClient(rootDir); diff --git a/src/test/index.ts b/src/test/index.ts index 87e6ff3d..6adeb516 100644 --- a/src/test/index.ts +++ b/src/test/index.ts @@ -25,7 +25,7 @@ export function run(): Promise { try { mocha.run((failures) => { if (failures > 0) { - e(new Error(`${failures} tests failed.`)); + e(Error(`${failures} tests failed.`)); } else { c(); } diff --git a/src/test/llm/chatTokensFitter.test.ts b/src/test/llm/chatTokensFitter.test.ts deleted file mode 100644 index 10bcb79a..00000000 --- a/src/test/llm/chatTokensFitter.test.ts +++ /dev/null @@ -1,202 +0,0 @@ -import { expect } from "earl"; -import * as path from "path"; -import { TiktokenModel, encoding_for_model } from "tiktoken"; - -import { theoremToChatItem } from "../../llm/llmServices/utils/chatFactory"; -import { ChatTokensFitter } from "../../llm/llmServices/utils/chatTokensFitter"; -import { chatItemToContent } from "../../llm/llmServices/utils/chatUtils"; - -import { parseCoqFile } from "../../coqParser/parseCoqFile"; -import { Theorem } from "../../coqParser/parsedTypes"; -import { Uri } from "../../utils/uri"; -import { createCoqLspClient, getResourceFolder } from "../commonTestFunctions"; - -suite("Chat tokens fitter tests", () => { - function calculateTokensViaTikToken( - text: string, - model: TiktokenModel - ): number { - const encoder = encoding_for_model(model); - const tokens = encoder.encode(text).length; - encoder.free(); - - return tokens; - } - - function approxCalculateTokens(text: string): number { - return (text.length / 4) >> 0; - } - - async function getCoqDocument( - resourcePath: string[], - projectRootPath?: string[] - ): Promise { - const filePath = path.join(getResourceFolder(), ...resourcePath); - const rootDir = path.join( - getResourceFolder(), - ...(projectRootPath ?? 
[]) - ); - - const fileUri = Uri.fromPath(filePath); - const client = createCoqLspClient(rootDir); - - await client.openTextDocument(fileUri); - const document = await parseCoqFile(fileUri, client); - await client.closeTextDocument(fileUri); - - return document; - } - - function countTheoremsPickedFromContext( - systemMessage: string, - completionTarget: string, - theorems: Theorem[], - model: string, - newMessageMaxTokens: number, - tokensLimit: number - ): number { - const fitter = new ChatTokensFitter( - model, - newMessageMaxTokens, - tokensLimit - ); - - fitter.fitRequiredMessage({ - role: "system", - content: systemMessage, - }); - - fitter.fitRequiredMessage({ - role: "user", - content: completionTarget, - }); - - const fittedTheorems = fitter.fitOptionalObjects(theorems, (theorem) => - chatItemToContent(theoremToChatItem(theorem)) - ); - - return fittedTheorems.length; - } - - test("Empty theorems array", async () => { - const theorems: Theorem[] = []; - const answer = countTheoremsPickedFromContext( - "You are a friendly assistant", - "doesn't matter", - theorems, - "openai-gpt", - 100, - 1000 - ); - - expect(answer).toEqual(0); - }); - - test("Two theorems, but overflow", async () => { - const theorems: Theorem[] = await getCoqDocument(["small_document.v"]); - expect(() => { - countTheoremsPickedFromContext( - "You are a friendly assistant", - "doesn't matter", - theorems, - "openai-gpt", - 1000, - 1000 - ); - }).toThrow(); - }); - - test("Two theorems, no overflow", async () => { - const theorems: Theorem[] = await getCoqDocument(["small_document.v"]); - const answer = countTheoremsPickedFromContext( - "You are a friendly assistant", - "doesn't matter", - theorems, - "openai-gpt", - 1000, - 10000 - ); - - expect(answer).toEqual(2); - }); - - test("Two theorems, overflow after first", async () => { - const theorems: Theorem[] = await getCoqDocument(["small_document.v"]); - - const statementTokens = approxCalculateTokens(theorems[0].statement); - const theoremProof = theorems[0].proof?.onlyText() ?? ""; - const proofTokens = approxCalculateTokens(theoremProof); - const answer = countTheoremsPickedFromContext( - "", - "", - theorems, - "invalid-model", - 1000, - 1000 + statementTokens + proofTokens - ); - - expect(answer).toEqual(1); - }); - - test("Two theorems, overflow almost before first", async () => { - const theorems: Theorem[] = await getCoqDocument(["small_document.v"]); - - const statementTokens = approxCalculateTokens(theorems[0].statement); - const theoremProof = theorems[0].proof?.onlyText() ?? ""; - const proofTokens = approxCalculateTokens(theoremProof); - - const answer = countTheoremsPickedFromContext( - "", - "", - theorems, - "invalid-model", - 1000, - 1000 + statementTokens + proofTokens - 1 - ); - - expect(answer).toEqual(0); - }); - - test("Two theorems, overflow after first with tiktoken", async () => { - const theorems: Theorem[] = await getCoqDocument(["small_document.v"]); - const model = "gpt-3.5-turbo-0301"; - - const statementTokens = calculateTokensViaTikToken( - theorems[0].statement, - model - ); - const theoremProof = theorems[0].proof?.onlyText() ?? 
""; - const proofTokens = calculateTokensViaTikToken(theoremProof, model); - const answer = countTheoremsPickedFromContext( - "", - "", - theorems, - model, - 1000, - 1000 + statementTokens + proofTokens - ); - - expect(answer).toEqual(1); - }); - - test("Test if two tokenizers are similar: Small text", async () => { - const model = "gpt-3.5-turbo-0301"; - - const text = "This is a test text"; - const tokens1 = calculateTokensViaTikToken(text, model); - const tokens2 = approxCalculateTokens(text); - - expect(tokens1).toBeCloseTo(tokens2, 2); - }); - - test("Test if two tokenizers are similar: Big text", async () => { - const model = "gpt-3.5-turbo-0301"; - - const text = - "Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum."; - const tokens1 = calculateTokensViaTikToken(text, model); - const tokens2 = approxCalculateTokens(text); - - expect(tokens1).toBeCloseTo(tokens2, 20); - }); -}); diff --git a/src/test/llm/llmIterator.test.ts b/src/test/llm/llmIterator.test.ts index c8a46c97..28c83e05 100644 --- a/src/test/llm/llmIterator.test.ts +++ b/src/test/llm/llmIterator.test.ts @@ -1,68 +1,74 @@ import { expect } from "earl"; import { LLMSequentialIterator } from "../../llm/llmIterator"; -import { GrazieService } from "../../llm/llmServices/grazie/grazieService"; -import { LMStudioService } from "../../llm/llmServices/lmStudio/lmStudioService"; -import { OpenAiService } from "../../llm/llmServices/openai/openAiService"; -import { PredefinedProofsService } from "../../llm/llmServices/predefinedProofs/predefinedProofsService"; +import { disposeServices } from "../../llm/llmServices"; +import { GeneratedProof } from "../../llm/llmServices/llmService"; import { ProofGenerationContext } from "../../llm/proofGenerationContext"; -import { UserModelsParams } from "../../llm/userModelParams"; + +import { + createDefaultServices, + createPredefinedProofsModel, + createTrivialModelsParams, +} from "../commonTestFunctions/defaultLLMServicesBuilder"; suite("LLM Iterator test", () => { - function getProofFromPredefinedCoqSentance(proof: string): string { - return `Proof. ${proof} Qed.`; - } + const predefinedModel1 = createPredefinedProofsModel("first model"); + const predefinedModel2 = createPredefinedProofsModel("second model"); + const modelsParams = createTrivialModelsParams([ + predefinedModel1, + predefinedModel2, + ]); + const tactics = predefinedModel1.tactics; + expect(predefinedModel2.tactics).toEqual(tactics); - test("Simple test of the iterator via predef proofs", async () => { - const openAiService = new OpenAiService(); - const grazieService = new GrazieService(); - const predefinedProofsService = new PredefinedProofsService(); - const lmStudioService = new LMStudioService(); - const predefinedProofs = [ - "intros.", - "reflexivity.", - "auto.", - "assumption. intros.", - "left. 
reflexivity.", - ]; - const modelsParams: UserModelsParams = { - openAiParams: [], - grazieParams: [], - predefinedProofsModelParams: [ - { - modelName: "Doesn't matter", - tactics: predefinedProofs, - }, - ], - lmStudioParams: [], - }; - const services = { - openAiService, - grazieService, - predefinedProofsService, - lmStudioService, - }; - const proofGenerationContext: ProofGenerationContext = { - contextTheorems: [], - completionTarget: "doesn't matter", - }; - const iterator = new LLMSequentialIterator( - proofGenerationContext, - modelsParams, - services - ); + const proofGenerationContext: ProofGenerationContext = { + contextTheorems: [], + completionTarget: "doesn't matter", + }; - let i = 0; - while (true) { - const result = await iterator.nextProof(); - if (result.done) { - break; + test("Test `nextProof` via two predefined-proofs models", async () => { + const services = createDefaultServices(); + try { + const iterator = new LLMSequentialIterator( + proofGenerationContext, + modelsParams, + services + ); + for (let i = 0; i < 2; i++) { + for (let t = 0; t < tactics.length; t++) { + const result = await iterator.nextProof(); + expect(result.done).toBeFalsy(); + const proof = result.value; + expect(proof.proof()).toEqual(tactics[t]); + } } - const proof = result.value; - expect(proof.proof()).toEqual( - getProofFromPredefinedCoqSentance(predefinedProofs[i]) + const result = await iterator.nextProof(); + expect(result.done); + } finally { + disposeServices(services); + } + }); + + test("Test `next` via two predefined-proofs models", async () => { + const services = createDefaultServices(); + try { + const iterator = new LLMSequentialIterator( + proofGenerationContext, + modelsParams, + services ); - i++; + for (let i = 0; i < 2; i++) { + const result = await iterator.next(); + expect(result.done).toBeFalsy(); + const proofsBatch = result.value.map( + (proofObject: GeneratedProof) => proofObject.proof() + ); + expect(proofsBatch).toEqual(tactics); + } + const result = await iterator.next(); + expect(result.done); + } finally { + disposeServices(services); } }); }); diff --git a/src/test/llm/llmServices/grazieService.test.ts b/src/test/llm/llmServices/grazieService.test.ts new file mode 100644 index 00000000..6b50e9dc --- /dev/null +++ b/src/test/llm/llmServices/grazieService.test.ts @@ -0,0 +1,116 @@ +import { expect } from "earl"; + +import { ConfigurationError } from "../../../llm/llmServiceErrors"; +import { GrazieService } from "../../../llm/llmServices/grazie/grazieService"; +import { ErrorsHandlingMode } from "../../../llm/llmServices/llmService"; +import { GrazieModelParams } from "../../../llm/llmServices/modelParams"; +import { defaultSystemMessageContent } from "../../../llm/llmServices/utils/paramsResolvers/basicModelParamsResolvers"; +import { GrazieUserModelParams } from "../../../llm/userModelParams"; + +import { testIf } from "../../commonTestFunctions/conditionalTest"; +import { resolveParametersOrThrow } from "../../commonTestFunctions/resolveOrThrow"; +import { + withLLMService, + withLLMServiceAndParams, +} from "../../commonTestFunctions/withLLMService"; +import { + mockProofGenerationContext, + testModelId, +} from "../llmSpecificTestUtils/constants"; +import { testLLMServiceCompletesAdmitFromFile } from "../llmSpecificTestUtils/testAdmitCompletion"; +import { + defaultUserMultiroundProfile, + testResolveValidCompleteParameters, +} from "../llmSpecificTestUtils/testResolveParameters"; + +suite("[LLMService] Test `GrazieService`", function () { + const apiKey = 
process.env.GRAZIE_API_KEY; + const choices = 15; + const inputFile = ["small_document.v"]; + + const requiredInputParamsTemplate = { + modelId: testModelId, + modelName: "openai-gpt-4", + choices: choices, + maxTokensToGenerate: 2000, + tokensLimit: 4000, + }; + + testIf( + apiKey !== undefined, + "`GRAZIE_API_KEY` is not specified", + this.title, + `Simple generation: 1 request, ${choices} choices`, + async () => { + const inputParams: GrazieUserModelParams = { + ...requiredInputParamsTemplate, + apiKey: apiKey!, + }; + const grazieService = new GrazieService(); + await testLLMServiceCompletesAdmitFromFile( + grazieService, + inputParams, + inputFile, + choices + ); + } + )?.timeout(10000); + + test("Test `resolveParameters` reads & accepts valid params", async () => { + const inputParams: GrazieUserModelParams = { + ...requiredInputParamsTemplate, + apiKey: "undefined", + }; + await withLLMService(new GrazieService(), async (grazieService) => { + testResolveValidCompleteParameters(grazieService, inputParams); + testResolveValidCompleteParameters( + grazieService, + { + ...inputParams, + systemPrompt: defaultSystemMessageContent, + multiroundProfile: defaultUserMultiroundProfile, + }, + true + ); + }); + }); + + test("Resolve parameters with predefined `maxTokensToGenerate`", async () => { + const inputParams: GrazieUserModelParams = { + ...requiredInputParamsTemplate, + apiKey: "undefined", + maxTokensToGenerate: 6666, // should be overriden by GrazieService + }; + withLLMService(new GrazieService(), async (grazieService) => { + const resolvedParams = resolveParametersOrThrow( + grazieService, + inputParams + ); + expect(resolvedParams.maxTokensToGenerate).toEqual( + GrazieService.maxTokensToGeneratePredefined + ); + }); + }); + + test("Test `generateProof` throws on invalid `choices`", async () => { + const inputParams: GrazieUserModelParams = { + ...requiredInputParamsTemplate, + apiKey: "undefined", + }; + await withLLMServiceAndParams( + new GrazieService(), + inputParams, + async (grazieService, resolvedParams: GrazieModelParams) => { + // non-positive choices + expect(async () => { + await grazieService.generateProof( + mockProofGenerationContext, + resolvedParams, + -1, + ErrorsHandlingMode.RETHROW_ERRORS + ); + }).toBeRejectedWith(ConfigurationError, "choices"); + } + ); + }); +}); diff --git a/src/test/llm/llmServices/llmService/availabilityEstimator.test.ts b/src/test/llm/llmServices/llmService/availabilityEstimator.test.ts new file mode 100644 index 00000000..ee8d0b15 --- /dev/null +++ b/src/test/llm/llmServices/llmService/availabilityEstimator.test.ts @@ -0,0 +1,137 @@ +import { expect } from "earl"; + +import { estimateTimeToBecomeAvailableDefault } from "../../../../llm/llmServices/utils/defaultAvailabilityEstimator"; +import { + LoggerRecord, + ResponseStatus, +} from "../../../../llm/llmServices/utils/generationsLogger/loggerRecord"; +import { + Time, + nowTimestampMillis, + time, + timeToMillis, + timeToString, + timeZero, +} from "../../../../llm/llmServices/utils/time"; + +suite("[LLMService] Test default availability estimator", () => { + function buildNextRecord( + timestampMillis: number, + timeDelta: Time, + responseStatus: ResponseStatus = "FAILURE" + ): LoggerRecord { + return new LoggerRecord( + timestampMillis + timeToMillis(timeDelta), + "test model", + responseStatus, + 5, + undefined, + responseStatus === "FAILURE" + ? 
{ + typeName: Error.name, + message: "connection error", + } + : undefined + ); + } + + function testAvailabilityEstimation( + logs: LoggerRecord[], + expectedEstimation: Time, + nowMillis: number = nowTimestampMillis() + ) { + const actualEstimation = estimateTimeToBecomeAvailableDefault( + logs, + nowMillis + ); + expect(actualEstimation).toEqual(expectedEstimation); + } + + const lastSuccessMillis = nowTimestampMillis(); + + test("No failures", () => { + testAvailabilityEstimation([], timeZero); + }); + + [time(100, "millisecond"), time(1, "second"), time(1, "day")].forEach( + (failureTimeDelta) => { + test(`Single failure in <${timeToString(failureTimeDelta)}>`, () => { + const failure = buildNextRecord( + lastSuccessMillis, + failureTimeDelta + ); + testAvailabilityEstimation([failure], time(1, "second")); + }); + } + ); + + [ + [timeZero, timeZero, time(1, "second")], // check zero + // check algorithm's logic + [time(4, "minute"), timeZero, time(5, "minute")], // delay is expected to be > 4 minutes, the closest is 5 minutes + [time(4, "minute"), time(1, "hour"), timeZero], // 5 minutes already passed (1 hour passed), no need to wait + [time(4, "minute"), time(3, "minute"), time(2, "minute")], // 3 out of 5 minutes already passed, need to wait 2 more + [time(4, "minute"), time(1, "second"), time(5, "minute")], // only 1 second passed, let's round estimation to 5 minutes still + // check other heuristic estimations points + [time(40, "minute"), timeZero, time(1, "hour")], + [time(13, "hour"), timeZero, time(1, "day")], + [time(2, "day"), timeZero, time(1, "day")], // check out-of-heuristic-estimations interval + ].forEach(([interval, timeFromLastAttempt, expectedEstimate]) => { + test(`Two failures with <${timeToString(interval)}> interval, <${timeToString(timeFromLastAttempt)}> from last attempt`, () => { + const firstFailure = buildNextRecord( + lastSuccessMillis, + time(1, "second") + ); + const secondFailure = buildNextRecord( + firstFailure.timestampMillis, + interval + ); + testAvailabilityEstimation( + [firstFailure, secondFailure], + expectedEstimate, + secondFailure.timestampMillis + + timeToMillis(timeFromLastAttempt) + ); + }); + }); + + function buildFailureRecordsSequence(timeDeltas: Time[]): LoggerRecord[] { + const records: LoggerRecord[] = [ + buildNextRecord(lastSuccessMillis, time(1, "second")), + ]; + for (const timeDelta of timeDeltas) { + records.push( + buildNextRecord( + records[records.length - 1].timestampMillis, + timeDelta + ) + ); + } + return records; + } + + test(`Many failures`, () => { + const records = buildFailureRecordsSequence([ + time(1, "second"), + time(20, "second"), + time(3, "minute"), + time(1, "second"), + time(1, "second"), + time(4, "minute"), // max interval + time(1, "minute"), + ]); + testAvailabilityEstimation( + records, + time(5, "minute"), // max interval between failures is 4 minutes + records[records.length - 1].timestampMillis // i.e. 
`timeFromLastAttempt` is 0 + ); + }); + + test("Throw on invalid input logs", () => { + expect(() => + estimateTimeToBecomeAvailableDefault([ + buildNextRecord(lastSuccessMillis, timeZero, "SUCCESS"), + ]) + ).toThrow(Error); + }); +}); diff --git a/src/test/llm/llmServices/llmService/generateFromChat.test.ts b/src/test/llm/llmServices/llmService/generateFromChat.test.ts new file mode 100644 index 00000000..a505dfd2 --- /dev/null +++ b/src/test/llm/llmServices/llmService/generateFromChat.test.ts @@ -0,0 +1,69 @@ +import { expect } from "earl"; + +import { ErrorsHandlingMode } from "../../../../llm/llmServices/llmService"; + +import { + mockChat, + proofsToGenerate, +} from "../../llmSpecificTestUtils/constants"; +import { subscribeToTrackMockEvents } from "../../llmSpecificTestUtils/eventsTracker"; +import { expectLogs } from "../../llmSpecificTestUtils/expectLogs"; +import { + MockLLMModelParams, + MockLLMService, +} from "../../llmSpecificTestUtils/mockLLMService"; +import { testFailedGenerationCompletely } from "../../llmSpecificTestUtils/testFailedGeneration"; +import { withMockLLMService } from "../../llmSpecificTestUtils/withMockLLMService"; + +suite("[LLMService] Test `generateFromChat`", () => { + [ + ErrorsHandlingMode.LOG_EVENTS_AND_SWALLOW_ERRORS, + ErrorsHandlingMode.RETHROW_ERRORS, + ].forEach((errorsHandlingMode) => { + test(`Test successful generation: ${errorsHandlingMode}`, async () => { + await withMockLLMService( + async (mockService, basicMockParams, testEventLogger) => { + const eventsTracker = subscribeToTrackMockEvents( + testEventLogger, + mockService, + basicMockParams.modelId, + mockChat + ); + + const generatedProofs = await mockService.generateFromChat( + mockChat, + basicMockParams, + proofsToGenerate.length, + errorsHandlingMode + ); + expect(generatedProofs).toEqual(proofsToGenerate); + + expect(eventsTracker).toEqual({ + mockEventsN: 1, + successfulRequestEventsN: 1, + failedRequestEventsN: 0, + }); + expectLogs([{ status: "SUCCESS" }], mockService); + } + ); + }); + }); + + async function generateFromChat( + mockService: MockLLMService, + mockParams: MockLLMModelParams, + errorsHandlingMode: ErrorsHandlingMode + ): Promise { + return mockService.generateFromChat( + mockChat, + mockParams, + proofsToGenerate.length, + errorsHandlingMode + ); + } + + testFailedGenerationCompletely(generateFromChat, { + expectedChatOfMockEvent: mockChat, + proofsToGenerate: proofsToGenerate, + }); +}); diff --git a/src/test/llm/llmServices/llmService/generateProofIntegrationTesting.test.ts b/src/test/llm/llmServices/llmService/generateProofIntegrationTesting.test.ts new file mode 100644 index 00000000..a76b3be2 --- /dev/null +++ b/src/test/llm/llmServices/llmService/generateProofIntegrationTesting.test.ts @@ -0,0 +1,407 @@ +import { expect } from "earl"; + +import { ConfigurationError } from "../../../../llm/llmServiceErrors"; +import { ErrorsHandlingMode } from "../../../../llm/llmServices/llmService"; + +import { EventLogger } from "../../../../logging/eventLogger"; +import { + mockProofGenerationContext, + proofsToGenerate, +} from "../../llmSpecificTestUtils/constants"; +import { + MockEventsTracker, + subscribeToTrackMockEvents, +} from "../../llmSpecificTestUtils/eventsTracker"; +import { + expectGeneratedProof, + toProofVersion, +} from "../../llmSpecificTestUtils/expectGeneratedProof"; +import { + ExpectedRecord, + expectLogs, +} from "../../llmSpecificTestUtils/expectLogs"; +import { + MockLLMGeneratedProof, + MockLLMModelParams, + MockLLMService, +} from 
"../../llmSpecificTestUtils/mockLLMService"; +import { + testFailedGenerationCompletely, + testFailureAtChatBuilding, +} from "../../llmSpecificTestUtils/testFailedGeneration"; +import { enhanceMockParams } from "../../llmSpecificTestUtils/transformParams"; +import { withMockLLMService } from "../../llmSpecificTestUtils/withMockLLMService"; + +/* + * Note: fitting context (theorems, diagnostics) into chats is tested in + * `chatFactory.test.ts` and `chatTokensFitter.test.ts`. + * Therefore, in this suite testing of context-fitting will be omitted. + */ +suite("[LLMService] Integration testing of `generateProof`", () => { + test("Test success, 1 round and default settings", async () => { + await withMockLLMService( + async (mockService, basicMockParams, testEventLogger) => { + const eventsTracker = subscribeToTrackMockEvents( + testEventLogger, + mockService, + basicMockParams.modelId + ); + + const generatedProofs = await mockService.generateProof( + mockProofGenerationContext, + basicMockParams, + proofsToGenerate.length + ); + + expect(generatedProofs).toHaveLength(proofsToGenerate.length); + for (let i = 0; i < generatedProofs.length; i++) { + expectGeneratedProof(generatedProofs[i], { + proof: proofsToGenerate[i], + proofVersions: [toProofVersion(proofsToGenerate[i])], + versionNumber: 1, + canBeFixed: false, + }); + } + + expect(eventsTracker).toEqual({ + mockEventsN: 1, + successfulRequestEventsN: 1, + failedRequestEventsN: 0, + }); + expectLogs([{ status: "SUCCESS" }], mockService); + } + ); + }); + + async function generateProof( + mockService: MockLLMService, + mockParams: MockLLMModelParams, + errorsHandlingMode: ErrorsHandlingMode + ): Promise { + return ( + await mockService.generateProof( + mockProofGenerationContext, + mockParams, + proofsToGenerate.length, + errorsHandlingMode + ) + ).map((generatedProof) => generatedProof.proof()); + } + + testFailedGenerationCompletely(generateProof, { + proofsToGenerate: proofsToGenerate, + }); + + testFailureAtChatBuilding(generateProof, { + proofsToGenerate: proofsToGenerate, + }); + + test("Test successful 2-round generation, default settings", async () => { + await withMockLLMService( + async (mockService, basicMockParams, testEventLogger) => { + const eventsTracker = subscribeToTrackMockEvents( + testEventLogger, + mockService, + basicMockParams.modelId + ); + + const withFixesMockParams = enhanceMockParams(basicMockParams, { + maxRoundsNumber: 2, + defaultProofFixChoices: 3, + + // makes MockLLMService generate `Fixed.` proofs if is found in sent chat + proofFixPrompt: MockLLMService.proofFixPrompt, + }); + + const generatedProofs = await mockService.generateProof( + mockProofGenerationContext, + withFixesMockParams, + proofsToGenerate.length + ); + expect(generatedProofs).toHaveLength(proofsToGenerate.length); + + const diagnostic = "Proof is incorrect..."; + for (const generatedProof of generatedProofs) { + expect(generatedProof.canBeFixed()).toBeTruthy(); + const fixedGeneratedProofs = + await generatedProof.fixProof(diagnostic); + expect(fixedGeneratedProofs).toHaveLength( + withFixesMockParams.multiroundProfile + .defaultProofFixChoices + ); + + fixedGeneratedProofs.forEach((fixedGeneratedProof) => { + expectGeneratedProof(fixedGeneratedProof, { + proof: MockLLMService.fixedProofString, + proofVersions: [ + toProofVersion( + generatedProof.proof(), + diagnostic + ), + toProofVersion(MockLLMService.fixedProofString), + ], + versionNumber: 2, + canBeFixed: false, + }); + }); + } + + const generationsN = 1 + 
generatedProofs.length; + expect(eventsTracker).toEqual({ + mockEventsN: generationsN, + successfulRequestEventsN: generationsN, + failedRequestEventsN: 0, + }); + expectLogs( + new Array(generationsN).fill({ status: "SUCCESS" }), + mockService + ); + } + ); + }); + + function tossCoin(trueProbability: number): boolean { + return Math.random() < trueProbability; + } + + function throwErrorOnNextGeneration( + probability: number, + mockService: MockLLMService, + error: Error, + workerId: number + ): Error | undefined { + const coin = tossCoin(probability); + if (coin) { + mockService.throwErrorOnNextGeneration(error, workerId); + } + return coin ? error : undefined; + } + + function updateExpectations( + errorWasThrown: Error | undefined, + generatedProofs: MockLLMGeneratedProof[], + expectedProofsLength: number, + expectedEvents: MockEventsTracker, + expectedLogs?: ExpectedRecord[] + ) { + expectedEvents.mockEventsN += 1; + if (errorWasThrown !== undefined) { + expect(generatedProofs).toHaveLength(0); + expectedEvents.failedRequestEventsN += 1; + expectedLogs?.push({ + status: "FAILURE", + error: errorWasThrown, + }); + } else { + expect(generatedProofs).toHaveLength(expectedProofsLength); + expectedEvents.successfulRequestEventsN += 1; + expectedLogs?.push({ status: "SUCCESS" }); + } + } + + function checkExpectations( + actualEvents: MockEventsTracker, + expectedEvents: MockEventsTracker, + expectedLogs: ExpectedRecord[], + mockService: MockLLMService + ) { + expect(actualEvents).toEqual(expectedEvents); + expectLogs(expectedLogs, mockService); + } + + interface StressTestParams { + workersN: number; + iterationsPerWorker: number; + newProofsOnEachIteration: number; + proofFixChoices: number; + tryToFixProbability: number; + failedGenerationProbability: number; + } + + async function stressTest( + testParams: StressTestParams, + mockService: MockLLMService, + basicMockParams: MockLLMModelParams, + testEventLogger: EventLogger, + expectLogsAndCheckExpectations: boolean + ): Promise< + [MockEventsTracker, MockEventsTracker, ExpectedRecord[] | undefined] + > { + const actualEvents = subscribeToTrackMockEvents( + testEventLogger, + mockService, + basicMockParams.modelId + ); + const expectedEvents: MockEventsTracker = { + mockEventsN: 0, + successfulRequestEventsN: 0, + failedRequestEventsN: 0, + }; + const expectedLogs: ExpectedRecord[] | undefined = + expectLogsAndCheckExpectations ? 
[] : undefined; + + expect(testParams.newProofsOnEachIteration).toBeLessThanOrEqual( + basicMockParams.proofsToGenerate.length + ); + basicMockParams.multiroundProfile.defaultProofFixChoices = + testParams.proofFixChoices; + + const connectionError = Error("failed to reach server"); + const diagnostic = "Proof is incorrect."; + + const workers: Promise[] = []; + for (let w = 0; w < testParams.workersN; w++) { + const work = async () => { + const workerMockParams: MockLLMModelParams = { + ...basicMockParams, + workerId: w, + }; + + let toFixCandidates: MockLLMGeneratedProof[] = []; + for (let i = 0; i < testParams.iterationsPerWorker; i++) { + const throwError = throwErrorOnNextGeneration( + testParams.failedGenerationProbability, + mockService, + connectionError, + w + ); + const generatedProofs = await mockService.generateProof( + mockProofGenerationContext, + workerMockParams, + testParams.newProofsOnEachIteration + ); + updateExpectations( + throwError, + generatedProofs, + testParams.newProofsOnEachIteration, + expectedEvents, + expectedLogs + ); + if (expectedLogs !== undefined) { + checkExpectations( + actualEvents, + expectedEvents, + expectedLogs, + mockService + ); + } + + toFixCandidates = [toFixCandidates, generatedProofs] + .flat() + .filter((_generatedProof) => { + tossCoin(testParams.tryToFixProbability); + }); + + const newlyGeneratedProofs = []; + for (const generatedProofToFix of toFixCandidates) { + if (!generatedProofToFix.canBeFixed()) { + expect( + async () => + await generatedProofToFix.fixProof( + diagnostic + ) + ).toBeRejectedWith( + ConfigurationError, + "could not be fixed" + ); + } else { + const throwError = throwErrorOnNextGeneration( + testParams.failedGenerationProbability, + mockService, + connectionError, + w + ); + const fixedGeneratedProofs = + await generatedProofToFix.fixProof(diagnostic); + + updateExpectations( + throwError, + fixedGeneratedProofs, + basicMockParams.multiroundProfile + .defaultProofFixChoices, + expectedEvents, + expectedLogs + ); + if (expectedLogs !== undefined) { + checkExpectations( + actualEvents, + expectedEvents, + expectedLogs, + mockService + ); + } + newlyGeneratedProofs.push(...fixedGeneratedProofs); + } + } + toFixCandidates = newlyGeneratedProofs; + } + return "done"; + }; + workers.push(work()); + } + + const results = await Promise.all(workers); + expect(results).toEqual(new Array(testParams.workersN).fill("done")); + return [actualEvents, expectedEvents, expectedLogs]; + } + + test("Stress test with sync worker (multiround with random failures, default settings)", async () => { + await withMockLLMService( + async (mockService, basicMockParams, testEventLogger) => { + const [_actualEvents, _expectedEvents, _expectedLogs] = + await stressTest( + { + workersN: 1, + iterationsPerWorker: 1000, + newProofsOnEachIteration: 10, + proofFixChoices: 4, + tryToFixProbability: 0.5, + failedGenerationProbability: 0.5, + }, + mockService, + basicMockParams, + testEventLogger, + true + ); + } + ); + }).timeout(15000); + + test("Stress test with async workers (multiround with random failures, default settings)", async () => { + await withMockLLMService( + async (mockService, basicMockParams, testEventLogger) => { + const [actualEvents, expectedEvents, _undefined] = + await stressTest( + { + workersN: 10, + iterationsPerWorker: 100, + newProofsOnEachIteration: 10, + proofFixChoices: 4, + tryToFixProbability: 0.5, + failedGenerationProbability: 0.5, + }, + mockService, + basicMockParams, + testEventLogger, + false + ); + + 
expect(actualEvents).toEqual(expectedEvents); + + const logs = mockService.readGenerationsLogs(); + const successLogsN = logs.filter( + (record) => record.responseStatus === "SUCCESS" + ).length; + const failureLogsN = logs.filter( + (record) => record.responseStatus === "FAILURE" + ).length; + expect(successLogsN).toEqual( + expectedEvents.successfulRequestEventsN + ); + expect(failureLogsN).toEqual( + expectedEvents.failedRequestEventsN + ); + } + ); + }).timeout(5000); +}); diff --git a/src/test/llm/llmServices/llmService/generatedProof.test.ts b/src/test/llm/llmServices/llmService/generatedProof.test.ts new file mode 100644 index 00000000..9e6774ad --- /dev/null +++ b/src/test/llm/llmServices/llmService/generatedProof.test.ts @@ -0,0 +1,264 @@ +import { expect } from "earl"; + +import { AnalyzedChatHistory } from "../../../../llm/llmServices/chat"; +import { ErrorsHandlingMode } from "../../../../llm/llmServices/llmService"; + +import { + mockChat, + mockProofGenerationContext, + proofsToGenerate, +} from "../../llmSpecificTestUtils/constants"; +import { + expectGeneratedProof, + toProofVersion, +} from "../../llmSpecificTestUtils/expectGeneratedProof"; +import { + MockLLMGeneratedProof, + MockLLMModelParams, + MockLLMService, +} from "../../llmSpecificTestUtils/mockLLMService"; +import { testFailedGenerationCompletely } from "../../llmSpecificTestUtils/testFailedGeneration"; +import { enhanceMockParams } from "../../llmSpecificTestUtils/transformParams"; +import { withMockLLMService } from "../../llmSpecificTestUtils/withMockLLMService"; + +/* + * Note: fitting context (theorems, diagnostics) into chats is tested in + * `chatFactory.test.ts` and `chatTokensFitter.test.ts`. + * Therefore, in this suite testing of context-fitting will be omitted. 
+ */ +suite("[LLMService] Test `GeneratedProof`", () => { + // the first initial proof and 3 new ones = at least 4 proofs to generate + expect(proofsToGenerate.length).toBeGreaterThanOrEqual(4); + + function transformChatToSkipProofs( + analyzedChat: AnalyzedChatHistory, + mockService: MockLLMService, + skipFirstNProofs: number + ): AnalyzedChatHistory { + return { + chat: mockService.transformChatToSkipFirstNProofs( + analyzedChat.chat, + skipFirstNProofs + ), + estimatedTokens: analyzedChat.estimatedTokens, + }; + } + + async function constructInitialGeneratedProof( + mockService: MockLLMService, + basicMockParams: MockLLMModelParams + ): Promise { + const unlimitedTokensWithFixesMockParams = enhanceMockParams( + basicMockParams, + { + maxRoundsNumber: 2, + + // will be overriden at calls + defaultProofFixChoices: 1, + + // makes MockLLMService generate `Fixed.` proofs if is found in sent chat + proofFixPrompt: MockLLMService.proofFixPrompt, + } + ); + const generatedProofs = await mockService.generateProof( + mockProofGenerationContext, + unlimitedTokensWithFixesMockParams, + 1, + ErrorsHandlingMode.RETHROW_ERRORS + ); + expect(generatedProofs).toHaveLength(1); + return generatedProofs[0] as MockLLMGeneratedProof; + } + + test("Build initial version", async () => { + await withMockLLMService( + async (mockService, basicMockParams, _testEventLogger) => { + const initialGeneratedProof = + await constructInitialGeneratedProof( + mockService, + basicMockParams + ); + expectGeneratedProof(initialGeneratedProof, { + proof: proofsToGenerate[0], + versionNumber: 1, + proofVersions: [toProofVersion(proofsToGenerate[0])], + nextVersionCanBeGenerated: true, + canBeFixed: true, + }); + } + ); + }); + + async function testExtractsProof( + dirtyProof: string, + expectedExtractedProof: string + ): Promise { + return withMockLLMService( + async (mockService, basicMockParams, _testEventLogger) => { + const generatedProof = await constructInitialGeneratedProof( + mockService, + { + ...basicMockParams, + proofsToGenerate: [dirtyProof], + } + ); + expectGeneratedProof(generatedProof, { + proof: expectedExtractedProof, + versionNumber: 1, + proofVersions: [toProofVersion(expectedExtractedProof)], + }); + } + ); + } + + test("Correctly extracts proof from dirty input (when created)", async () => { + await testExtractsProof("auto.", "auto."); + await testExtractsProof("Proof. auto.", "Proof. auto."); + await testExtractsProof("auto. Qed.", "auto. Qed."); + await testExtractsProof("some text", "some text"); + + await testExtractsProof("Proof.auto.Qed.", "auto."); + await testExtractsProof("Proof.Qed.", ""); + + await testExtractsProof("Proof. auto. Qed.", "auto."); + await testExtractsProof("Proof.\nauto.\nQed.", "auto."); + await testExtractsProof("Proof.\n\tauto.\nQed.", "auto."); + await testExtractsProof("\tProof.\n\t\tauto.\n\tQed.", "auto."); + + await testExtractsProof("PrefixProof.auto.Qed.Suffix", "auto."); + await testExtractsProof( + "The following proof should solve your theorem:\n```Proof.\n\tauto.\nQed.```\nAsk me more questions, if you want to!", + "auto." + ); + + await testExtractsProof("Proof.auto.Qed. 
Proof.intros.Qed.", "auto."); + }); + + test("Mock multiround: generate next version, happy path", async () => { + await withMockLLMService( + async (mockService, basicMockParams, _testEventLogger) => { + const initialGeneratedProof = + await constructInitialGeneratedProof( + mockService, + basicMockParams + ); + + const newVersionChoices = 3; + const secondVersionGeneratedProofs = + await initialGeneratedProof.generateNextVersion( + transformChatToSkipProofs(mockChat, mockService, 1), + newVersionChoices, + ErrorsHandlingMode.RETHROW_ERRORS + ); + expect(secondVersionGeneratedProofs).toHaveLength( + newVersionChoices + ); + + // test that `proofVersions` of the initial proof didn't change + expect(initialGeneratedProof.proofVersions).toEqual([ + toProofVersion(proofsToGenerate[0]), + ]); + + for (let i = 0; i < newVersionChoices; i++) { + const expectedNewProof = proofsToGenerate[i + 1]; + expectGeneratedProof(secondVersionGeneratedProofs[i], { + proof: expectedNewProof, + versionNumber: 2, + proofVersions: [ + toProofVersion(proofsToGenerate[0]), + toProofVersion(expectedNewProof), + ], + nextVersionCanBeGenerated: false, // `maxRoundsNumber`: 2 + }); + } + } + ); + }); + + test("Fix proof, happy path", async () => { + await withMockLLMService( + async (mockService, basicMockParams, _testEventLogger) => { + const initialGeneratedProof = + await constructInitialGeneratedProof( + mockService, + basicMockParams + ); + + const fixedVersionChoices = 3; + const initialProofDiagnostic = `Proof \`${initialGeneratedProof.proof()}\` was incorrect...`; + const fixedGeneratedProofs = + await initialGeneratedProof.fixProof( + initialProofDiagnostic, + fixedVersionChoices, + ErrorsHandlingMode.RETHROW_ERRORS + ); + expect(fixedGeneratedProofs).toHaveLength(fixedVersionChoices); + + // test that `proofVersions` of the initial proof was updated correctly + expect(initialGeneratedProof.proofVersions).toEqual([ + toProofVersion(proofsToGenerate[0], initialProofDiagnostic), + ]); + + const expectedFixedProof = MockLLMService.fixedProofString; + fixedGeneratedProofs.forEach((fixedGeneratedProof) => { + expectGeneratedProof(fixedGeneratedProof, { + proof: expectedFixedProof, + versionNumber: 2, + proofVersions: [ + toProofVersion( + proofsToGenerate[0], + initialProofDiagnostic + ), + toProofVersion(expectedFixedProof), + ], + canBeFixed: false, + }); + }); + } + ); + }); + + async function fixProof( + _mockService: MockLLMService, + _mockParams: MockLLMModelParams, + errorsHandlingMode: ErrorsHandlingMode, + preparedData?: MockLLMGeneratedProof + ): Promise { + const initialGeneratedProof = preparedData; + if (initialGeneratedProof === undefined) { + throw Error( + `test is configured incorrectly: \`fixProof\` got "undefined" as \`preparedData\` instead of \`MockLLMGeneratedProof\`` + ); + } + const fixedGeneratedProofs = await initialGeneratedProof.fixProof( + "Proof was incorrect", + 1, + errorsHandlingMode + ); + + return fixedGeneratedProofs.map((generatedProof) => + generatedProof.proof() + ); + } + + async function prepareData( + mockService: MockLLMService, + basicMockParams: MockLLMModelParams + ): Promise { + const initialGeneratedProof = await constructInitialGeneratedProof( + mockService, + basicMockParams + ); + mockService.clearGenerationLogs(); + return initialGeneratedProof; + } + + testFailedGenerationCompletely( + fixProof, + { + proofsToGenerate: [MockLLMService.fixedProofString], + testTargetName: "Fix proof, failed generation", + }, + prepareData + ); +}); diff --git 
a/src/test/llm/llmServices/llmService/modelParamsResolvers.test.ts b/src/test/llm/llmServices/llmService/modelParamsResolvers.test.ts new file mode 100644 index 00000000..0d40d8fc --- /dev/null +++ b/src/test/llm/llmServices/llmService/modelParamsResolvers.test.ts @@ -0,0 +1,147 @@ +import { expect } from "earl"; + +import { + ModelParams, + MultiroundProfile, + modelParamsSchema, +} from "../../../../llm/llmServices/modelParams"; +import { + BasicModelParamsResolver, + defaultMultiroundProfile, + defaultSystemMessageContent, +} from "../../../../llm/llmServices/utils/paramsResolvers/basicModelParamsResolvers"; +import { UserModelParams } from "../../../../llm/userModelParams"; + +import { withLLMService } from "../../../commonTestFunctions/withLLMService"; +import { testModelId } from "../../llmSpecificTestUtils/constants"; +import { + MockLLMModelParams, + MockLLMService, + MockLLMUserModelParams, +} from "../../llmSpecificTestUtils/mockLLMService"; +import { + ModelParamsAddOns, + UserModelParamsAddOns, +} from "../../llmSpecificTestUtils/modelParamsAddOns"; + +suite("[LLMService] Test model-params resolution", () => { + function testBasicResolverSucceeded( + testName: string, + inputParamsAddOns: UserModelParamsAddOns = {}, + expectedResolvedParamsAddOns: ModelParamsAddOns = {} + ) { + test(testName, () => { + const inputParams: UserModelParams = { + modelId: testModelId, + choices: 1, + // `systemPrompt` will be resolved with default + maxTokensToGenerate: 100, + tokensLimit: 1000, + multiroundProfile: { + proofFixChoices: 3, + // `maxRoundsNumber` and `proofFixPrompt` will be resolved with defaults + }, + ...inputParamsAddOns, + }; + const modelParamsResolver = new BasicModelParamsResolver( + modelParamsSchema, + "ModelParams" + ); + const resolutionResult = modelParamsResolver.resolve(inputParams); + + const expectedResolvedParams: ModelParams = { + modelId: testModelId, + systemPrompt: defaultSystemMessageContent, + maxTokensToGenerate: 100, + tokensLimit: 1000, + multiroundProfile: { + maxRoundsNumber: defaultMultiroundProfile.maxRoundsNumber, + defaultProofFixChoices: 3, + proofFixPrompt: defaultMultiroundProfile.proofFixPrompt, + } as MultiroundProfile, + defaultChoices: 1, + ...expectedResolvedParamsAddOns, + } as ModelParams; + expect(resolutionResult.resolved).toEqual(expectedResolvedParams); + }); + } + + testBasicResolverSucceeded( + "Test basic resolver: successfully resolves with defaults" + ); + + testBasicResolverSucceeded( + "Test basic resolver: resolves undefined `multiroundProfile`", + { + multiroundProfile: undefined, + }, + { + multiroundProfile: defaultMultiroundProfile, + } + ); + + test("Test basic resolver: reports failed parameters", () => { + const inputParams: UserModelParams = { + modelId: testModelId, + choices: undefined, // fail + systemPrompt: "Generate proof!", + maxTokensToGenerate: -1, // fail + tokensLimit: -1, // fail + multiroundProfile: { + maxRoundsNumber: -1, // fail + proofFixChoices: -1, // fail + proofFixPrompt: "Fix proof!", + }, + }; + const modelParamsResolver = new BasicModelParamsResolver( + modelParamsSchema, + "ModelParams" + ); + const resolutionResult = modelParamsResolver.resolve(inputParams); + + expect(resolutionResult.resolved).toBeNullish(); + const expectedNumberOfFailedParams = 5; + expect( + resolutionResult.resolutionLogs.filter( + (paramLog) => paramLog.isInvalidCause !== undefined + ) + ).toHaveLength(expectedNumberOfFailedParams); + }); + + test("Test resolution by LLMService", async () => { + await withLLMService(new 
MockLLMService(), async (mockService) => { + const unresolvedMockUserParams: MockLLMUserModelParams = { + modelId: testModelId, + systemPrompt: "This system prompt will be overriden by service", + maxTokensToGenerate: 100, + tokensLimit: 1000, + proofsToGenerate: ["auto.", "avto."], + }; + + /* + * `MockLLMService` parameters resolution does 4 changes to `inputParams`: + * - resolves undefined `workerId` to 0; + * - adds extra `resolvedWithMockLLMService: true` property; + * - overrides original `systemPrompt` with `this.systemPromptToOverrideWith`. + * - overrides original `choices` to `defaultChoices` with `proofsToGenerate.length`. + * Everything else should be resolved with defaults, if needed. + */ + const expectedResolvedMockParams = { + ...unresolvedMockUserParams, + multiroundProfile: defaultMultiroundProfile, + systemPrompt: MockLLMService.systemPromptToOverrideWith, + workerId: 0, + resolvedWithMockLLMService: true, + defaultChoices: 2, + } as MockLLMModelParams; + + const actualResolvedMockParams = mockService.resolveParameters( + unresolvedMockUserParams + ).resolved; + + expect(actualResolvedMockParams).toEqual( + expectedResolvedMockParams + ); + }); + }); +}); diff --git a/src/test/llm/llmServices/lmStudioService.test.ts b/src/test/llm/llmServices/lmStudioService.test.ts new file mode 100644 index 00000000..2fc0edaa --- /dev/null +++ b/src/test/llm/llmServices/lmStudioService.test.ts @@ -0,0 +1,117 @@ +import { expect } from "earl"; + +import { ConfigurationError } from "../../../llm/llmServiceErrors"; +import { ErrorsHandlingMode } from "../../../llm/llmServices/llmService"; +import { LMStudioService } from "../../../llm/llmServices/lmStudio/lmStudioService"; +import { LMStudioModelParams } from "../../../llm/llmServices/modelParams"; +import { defaultSystemMessageContent } from "../../../llm/llmServices/utils/paramsResolvers/basicModelParamsResolvers"; +import { LMStudioUserModelParams } from "../../../llm/userModelParams"; + +import { testIf } from "../../commonTestFunctions/conditionalTest"; +import { + withLLMService, + withLLMServiceAndParams, +} from "../../commonTestFunctions/withLLMService"; +import { + mockProofGenerationContext, + testModelId, +} from "../llmSpecificTestUtils/constants"; +import { testLLMServiceCompletesAdmitFromFile } from "../llmSpecificTestUtils/testAdmitCompletion"; +import { + defaultUserMultiroundProfile, + testResolveParametersFailsWithSingleCause, + testResolveValidCompleteParameters, +} from "../llmSpecificTestUtils/testResolveParameters"; + +suite("[LLMService] Test `LMStudioService`", function () { + const lmStudioPort = process.env.LMSTUDIO_PORT; + const choices = 15; + const inputFile = ["small_document.v"]; + + const requiredInputParamsTemplate = { + modelId: testModelId, + temperature: 1, + choices: choices, + maxTokensToGenerate: 2000, + tokensLimit: 4000, + }; + + testIf( + lmStudioPort !== undefined, + "`LMSTUDIO_PORT` is not specified", + this.title, + `Simple generation: 1 request, ${choices} choices`, + async () => { + const inputParams: LMStudioUserModelParams = { + ...requiredInputParamsTemplate, + port: parseInt(lmStudioPort!), + }; + const lmStudioService = new LMStudioService(); + await testLLMServiceCompletesAdmitFromFile( + lmStudioService, + inputParams, + inputFile, + choices + ); + } + )?.timeout(30000); + + test("Test `resolveParameters` reads & accepts valid params", async () => { + const inputParams: LMStudioUserModelParams = { + ...requiredInputParamsTemplate, + port: 1234, + }; + await withLLMService(new 
LMStudioService(), async (lmStudioService) => { + testResolveValidCompleteParameters(lmStudioService, inputParams); + testResolveValidCompleteParameters( + lmStudioService, + { + ...inputParams, + systemPrompt: defaultSystemMessageContent, + multiroundProfile: defaultUserMultiroundProfile, + }, + true + ); + }); + }); + + test("Test `resolveParameters` validates LMStudio-extended params (`port`)", async () => { + const inputParams: LMStudioUserModelParams = { + ...requiredInputParamsTemplate, + port: 1234, + }; + await withLLMService(new LMStudioService(), async (lmStudioService) => { + // port !in [0, 65535] + testResolveParametersFailsWithSingleCause( + lmStudioService, + { + ...inputParams, + port: 100000, + }, + "port" + ); + }); + }); + + test("Test `generateProof` throws on invalid `choices`", async () => { + const inputParams: LMStudioUserModelParams = { + ...requiredInputParamsTemplate, + port: 1234, + }; + await withLLMServiceAndParams( + new LMStudioService(), + inputParams, + async (lmStudioService, resolvedParams: LMStudioModelParams) => { + // non-positive choices + expect(async () => { + await lmStudioService.generateProof( + mockProofGenerationContext, + resolvedParams, + -1, + ErrorsHandlingMode.RETHROW_ERRORS + ); + }).toBeRejectedWith(ConfigurationError, "choices"); + } + ); + }); +}); diff --git a/src/test/llm/llmServices/openAiService.test.ts b/src/test/llm/llmServices/openAiService.test.ts new file mode 100644 index 00000000..d8b12e3d --- /dev/null +++ b/src/test/llm/llmServices/openAiService.test.ts @@ -0,0 +1,268 @@ +import { expect } from "earl"; + +import { ConfigurationError } from "../../../llm/llmServiceErrors"; +import { ErrorsHandlingMode } from "../../../llm/llmServices/llmService"; +import { OpenAiModelParams } from "../../../llm/llmServices/modelParams"; +import { OpenAiService } from "../../../llm/llmServices/openai/openAiService"; +import { defaultSystemMessageContent } from "../../../llm/llmServices/utils/paramsResolvers/basicModelParamsResolvers"; +import { OpenAiUserModelParams } from "../../../llm/userModelParams"; + +import { testIf } from "../../commonTestFunctions/conditionalTest"; +import { + withLLMService, + withLLMServiceAndParams, +} from "../../commonTestFunctions/withLLMService"; +import { + gptTurboModelName, + mockProofGenerationContext, + testModelId, +} from "../llmSpecificTestUtils/constants"; +import { testLLMServiceCompletesAdmitFromFile } from "../llmSpecificTestUtils/testAdmitCompletion"; +import { + defaultUserMultiroundProfile, + testResolveParametersFailsWithSingleCause, + testResolveValidCompleteParameters, +} from "../llmSpecificTestUtils/testResolveParameters"; + +suite("[LLMService] Test `OpenAiService`", function () { + const apiKey = process.env.OPENAI_API_KEY; + const choices = 15; + const inputFile = ["small_document.v"]; + + const requiredInputParamsTemplate = { + modelId: testModelId, + modelName: gptTurboModelName, + temperature: 1, + choices: choices, + }; + + testIf( + apiKey !== undefined, + "`OPENAI_API_KEY` is not specified", + this.title, + `Simple generation: 1 request, ${choices} choices`, + async () => { + const inputParams: OpenAiUserModelParams = { + ...requiredInputParamsTemplate, + apiKey: apiKey!, + }; + const openAiService = new OpenAiService(); + await testLLMServiceCompletesAdmitFromFile( + openAiService, + inputParams, + inputFile, + choices + ); + } + )?.timeout(5000); + + test("Test `resolveParameters` reads & accepts valid params", async () => { + const inputParams: OpenAiUserModelParams = { + 
...requiredInputParamsTemplate, + apiKey: "undefined", + }; + await withLLMService(new OpenAiService(), async (openAiService) => { + testResolveValidCompleteParameters(openAiService, inputParams); + testResolveValidCompleteParameters( + openAiService, + { + ...inputParams, + systemPrompt: defaultSystemMessageContent, + maxTokensToGenerate: 2000, + tokensLimit: 4000, + multiroundProfile: defaultUserMultiroundProfile, + }, + true + ); + }); + }); + + function testResolvesTokensWithDefault( + modelName: string, + inputTokensLimit: number | undefined, + expectedTokensLimit: number, + expectedMaxTokensToGenerate: number + ) { + const withDefinedTokensLimit = + inputTokensLimit === undefined + ? "" + : ", defined input `tokensLimit`"; + test(`Test \`resolveParameters\` resolves tokens with defaults: "${modelName}${withDefinedTokensLimit}"`, async () => { + const inputParams: OpenAiUserModelParams = { + ...requiredInputParamsTemplate, + apiKey: "undefined", + modelName: modelName, + tokensLimit: inputTokensLimit, + }; + await withLLMService(new OpenAiService(), async (openAiService) => { + const resolutionResult = + openAiService.resolveParameters(inputParams); + expect(resolutionResult.resolved).not.toBeNullish(); + expect(resolutionResult.resolved!.tokensLimit).toEqual( + expectedTokensLimit + ); + expect(resolutionResult.resolved!.maxTokensToGenerate).toEqual( + expectedMaxTokensToGenerate + ); + // check it was resolution with default indeed + expect( + resolutionResult.resolutionLogs.find( + (paramLog) => + paramLog.inputParamName === "maxTokensToGenerate" + )?.resolvedWithDefault.wasPerformed + ).toBeTruthy(); + if (inputTokensLimit === undefined) { + expect( + resolutionResult.resolutionLogs.find( + (paramLog) => + paramLog.inputParamName === "tokensLimit" + )?.resolvedWithDefault.wasPerformed + ).toBeTruthy(); + } + }); + }); + } + + ( + [ + ["gpt-3.5-turbo-0301", undefined, 4096, 2048], + ["gpt-3.5-turbo-0125", undefined, 16_385, 4096], + ["gpt-4-32k-0314", undefined, 32_768, 4096], + ["gpt-3.5-turbo-0301", 3000, 3000, 1500], + ] as [string, number | undefined, number, number][] + ).forEach( + ([ + modelName, + inputTokensLimit, + expectedTokensLimit, + expectedMaxTokensToGenerate, + ]) => { + testResolvesTokensWithDefault( + modelName, + inputTokensLimit, + expectedTokensLimit, + expectedMaxTokensToGenerate + ); + } + ); + + test("Test `resolveParameters` validates OpenAI-extended params (`temperature`) & tokens params", async () => { + const inputParams: OpenAiUserModelParams = { + ...requiredInputParamsTemplate, + apiKey: "undefined", + }; + await withLLMService(new OpenAiService(), async (openAiService) => { + // `temperature` !in [0, 2] + testResolveParametersFailsWithSingleCause( + openAiService, + { + ...inputParams, + temperature: 5, + }, + "temperature" + ); + + // `maxTokensToGenerate` > known `maxTokensToGenerate` for the "gpt-3.5-turbo-0301" model + testResolveParametersFailsWithSingleCause( + openAiService, + { + ...inputParams, + modelName: "gpt-3.5-turbo-0301", + maxTokensToGenerate: 5000, + }, + "maxTokensToGenerate" + ); + + // `tokensLimit` > known `tokensLimit` for the "gpt-3.5-turbo-0301" model + testResolveParametersFailsWithSingleCause( + openAiService, + { + ...inputParams, + modelName: "gpt-3.5-turbo-0301", + tokensLimit: 5000, + }, + "tokensLimit" + ); + }); + }); + + test("Test `generateProof` throws on invalid configurations, ", async () => { + const inputParams: OpenAiUserModelParams = { + ...requiredInputParamsTemplate, + apiKey: "undefined", + }; + await 
withLLMServiceAndParams( + new OpenAiService(), + inputParams, + async (openAiService, resolvedParams: OpenAiModelParams) => { + // non-positive choices + expect(async () => { + await openAiService.generateProof( + mockProofGenerationContext, + resolvedParams, + -1, + ErrorsHandlingMode.RETHROW_ERRORS + ); + }).toBeRejectedWith(ConfigurationError, "choices"); + + // incorrect api key + expect(async () => { + await openAiService.generateProof( + mockProofGenerationContext, + resolvedParams, + 1, + ErrorsHandlingMode.RETHROW_ERRORS + ); + }).toBeRejectedWith(ConfigurationError, "api key"); + } + ); + }); + + testIf( + apiKey !== undefined, + "`OPENAI_API_KEY` is not specified", + this.title, + "Test `generateProof` throws on invalid configurations, ", + async () => { + const inputParams: OpenAiUserModelParams = { + ...requiredInputParamsTemplate, + apiKey: apiKey!, + }; + await withLLMServiceAndParams( + new OpenAiService(), + inputParams, + async (openAiService, resolvedParams) => { + // unknown model name + expect(async () => { + await openAiService.generateProof( + mockProofGenerationContext, + { + ...resolvedParams, + modelName: "unknown", + } as OpenAiModelParams, + 1, + ErrorsHandlingMode.RETHROW_ERRORS + ); + }).toBeRejectedWith(ConfigurationError, "model name"); + + // context length exceeded (requested too many tokens for the completion) + expect(async () => { + await openAiService.generateProof( + mockProofGenerationContext, + { + ...resolvedParams, + maxTokensToGenerate: 500_000, + tokensLimit: 1_000_000, + } as OpenAiModelParams, + 1, + ErrorsHandlingMode.RETHROW_ERRORS + ); + }).toBeRejectedWith( + ConfigurationError, + "`tokensLimit` and `maxTokensToGenerate`" + ); + } + ); + } + ); +}); diff --git a/src/test/llm/llmServices/predefinedProofsService.test.ts b/src/test/llm/llmServices/predefinedProofsService.test.ts new file mode 100644 index 00000000..a87a41a1 --- /dev/null +++ b/src/test/llm/llmServices/predefinedProofsService.test.ts @@ -0,0 +1,304 @@ +import { expect } from "earl"; + +import { ConfigurationError } from "../../../llm/llmServiceErrors"; +import { ErrorsHandlingMode } from "../../../llm/llmServices/llmService"; +import { PredefinedProofsModelParams } from "../../../llm/llmServices/modelParams"; +import { PredefinedProofsService } from "../../../llm/llmServices/predefinedProofs/predefinedProofsService"; +import { timeZero } from "../../../llm/llmServices/utils/time"; +import { ProofGenerationContext } from "../../../llm/proofGenerationContext"; +import { PredefinedProofsUserModelParams } from "../../../llm/userModelParams"; + +import { EventLogger } from "../../../logging/eventLogger"; +import { delay } from "../../commonTestFunctions/delay"; +import { resolveParametersOrThrow } from "../../commonTestFunctions/resolveOrThrow"; +import { withLLMService } from "../../commonTestFunctions/withLLMService"; +import { testModelId } from "../llmSpecificTestUtils/constants"; +import { + EventsTracker, + subscribeToTrackEvents, +} from "../llmSpecificTestUtils/eventsTracker"; +import { expectLogs } from "../llmSpecificTestUtils/expectLogs"; +import { testLLMServiceCompletesAdmitFromFile } from "../llmSpecificTestUtils/testAdmitCompletion"; +import { + testResolveParametersFailsWithSingleCause, + testResolveValidCompleteParameters, +} from "../llmSpecificTestUtils/testResolveParameters"; + +suite("[LLMService] Test `PredefinedProofsService`", function () { + const simpleTactics = ["auto.", "intros.", "reflexivity."]; + const inputParams: PredefinedProofsUserModelParams = { 
+ modelId: testModelId, + tactics: simpleTactics, + }; + const proofGenerationContext: ProofGenerationContext = { + completionTarget: "could be anything", + contextTheorems: [], + }; + + async function withPredefinedProofsService( + block: ( + predefinedProofsService: PredefinedProofsService, + testEventLogger: EventLogger + ) => Promise + ) { + const testEventLogger = new EventLogger(); + return withLLMService( + new PredefinedProofsService(testEventLogger, true), + async (predefinedProofsService) => { + return block(predefinedProofsService, testEventLogger); + } + ); + } + + const choices = simpleTactics.length; + const inputFile = ["small_document.v"]; + + test("Simple generation: prove with `auto.`", async () => { + const predefinedProofsService = new PredefinedProofsService(); + await testLLMServiceCompletesAdmitFromFile( + predefinedProofsService, + inputParams, + inputFile, + choices + ); + }); + + [ + ErrorsHandlingMode.LOG_EVENTS_AND_SWALLOW_ERRORS, + ErrorsHandlingMode.RETHROW_ERRORS, + ].forEach((errorsHandlingMode) => { + test(`Test generation logging: ${errorsHandlingMode}`, async () => { + await withPredefinedProofsService( + async (predefinedProofsService, testEventLogger) => { + const eventsTracker = subscribeToTrackEvents( + testEventLogger, + predefinedProofsService, + inputParams.modelId + ); + const resolvedParams = resolveParametersOrThrow( + predefinedProofsService, + inputParams + ); + + // failed generation + try { + await predefinedProofsService.generateProof( + proofGenerationContext, + resolvedParams, + resolvedParams.tactics.length + 1, + errorsHandlingMode + ); + } catch (e) { + expect(errorsHandlingMode).toEqual( + ErrorsHandlingMode.RETHROW_ERRORS + ); + const error = e as ConfigurationError; + expect(error).toBeTruthy(); + } + + const expectedEvents: EventsTracker = { + successfulRequestEventsN: 0, + failedRequestEventsN: + errorsHandlingMode === + ErrorsHandlingMode.LOG_EVENTS_AND_SWALLOW_ERRORS + ? 1 + : 0, + }; + expect(eventsTracker).toEqual(expectedEvents); + + // ConfigurationError should not be logged! 
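+ // (i.e. the generations log is expected to stay empty here: the failed request above
+ // was rejected as a configuration error before any actual generation was attempted)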
+ expectLogs([], predefinedProofsService); + + // successful generation + const generatedProofs = + await predefinedProofsService.generateProof( + proofGenerationContext, + resolvedParams, + resolvedParams.tactics.length + ); + expect(generatedProofs).toHaveLength( + resolvedParams.tactics.length + ); + + expectedEvents.successfulRequestEventsN += 1; + expect(eventsTracker).toEqual(expectedEvents); + expectLogs( + [{ status: "SUCCESS" }], + predefinedProofsService + ); + } + ); + }); + }); + + test("Test `resolveParameters` reads & accepts valid params", async () => { + await withPredefinedProofsService(async (predefinedProofsService) => { + testResolveValidCompleteParameters( + predefinedProofsService, + inputParams + ); + }); + }); + + test("Test `resolveParameters` validates PredefinedProofs-extended params (`tactics`)", async () => { + await withPredefinedProofsService(async (predefinedProofsService) => { + testResolveParametersFailsWithSingleCause( + predefinedProofsService, + { + ...inputParams, + tactics: [], + }, + "tactics" + ); + }); + }); + + test("Test `resolveParameters` overrides params correctly", async () => { + await withPredefinedProofsService(async (predefinedProofsService) => { + const resolutionResult = predefinedProofsService.resolveParameters({ + ...inputParams, + choices: 1, + systemPrompt: "asking for something", + maxTokensToGenerate: 2000, + tokensLimit: 4000, + multiroundProfile: { + maxRoundsNumber: 10, + proofFixChoices: 5, + proofFixPrompt: "asking for more of something", + }, + }); + + // first, verify all params were read correctly + for (const paramLog of resolutionResult.resolutionLogs) { + expect(paramLog.isInvalidCause).toBeNullish(); + expect(paramLog.inputReadCorrectly.wasPerformed).toBeTruthy(); + // expect(paramLog.overriden).toBeTruthy(); // is not true for mock overrides + expect(paramLog.resolvedWithDefault.wasPerformed).toBeFalsy(); + } + + expect(resolutionResult.resolved).toEqual({ + modelId: testModelId, + tactics: simpleTactics, + systemPrompt: "", + maxTokensToGenerate: Math.max( + 0, + ...simpleTactics.map((tactic) => tactic.length) + ), + tokensLimit: Number.MAX_SAFE_INTEGER, + multiroundProfile: { + maxRoundsNumber: 1, + defaultProofFixChoices: 0, + proofFixPrompt: "", + }, + defaultChoices: simpleTactics.length, + }); + }); + }); + + test("Test `generateProof` throws on invalid `choices`", async () => { + await withPredefinedProofsService(async (predefinedProofsService) => { + const resolvedParams = resolveParametersOrThrow( + predefinedProofsService, + inputParams + ); + + // non-positive choices + expect(async () => { + await predefinedProofsService.generateProof( + proofGenerationContext, + resolvedParams, + -1, + ErrorsHandlingMode.RETHROW_ERRORS + ); + }).toBeRejectedWith(ConfigurationError, "choices"); + + // choices > tactics.length + expect(async () => { + await predefinedProofsService.generateProof( + proofGenerationContext, + resolvedParams, + resolvedParams.tactics.length + 1, + ErrorsHandlingMode.RETHROW_ERRORS + ); + }).toBeRejectedWith(ConfigurationError, "choices"); + }); + }); + + test("Test chat-related features throw", async () => { + await withPredefinedProofsService(async (predefinedProofsService) => { + const resolvedParams = resolveParametersOrThrow( + predefinedProofsService, + inputParams + ); + expect(async () => { + await predefinedProofsService.generateFromChat( + { + chat: [], + estimatedTokens: { + messagesTokens: 0, + maxTokensToGenerate: 0, + maxTokensInTotal: 0, + }, + }, + resolvedParams, + choices, + 
ErrorsHandlingMode.RETHROW_ERRORS + ); + }).toBeRejectedWith( + ConfigurationError, + "does not support generation from chat" + ); + + const [generatedProof] = + await predefinedProofsService.generateProof( + proofGenerationContext, + resolvedParams, + 1 + ); + expect(generatedProof.canBeFixed()).toBeFalsy(); + expect( + async () => + await generatedProof.fixProof( + "pretend to be diagnostic", + 3, + ErrorsHandlingMode.RETHROW_ERRORS + ) + ).toBeRejectedWith(ConfigurationError, "cannot be fixed"); + }); + }); + + test("Test time to become available is zero", async () => { + await withPredefinedProofsService(async (predefinedProofsService) => { + const resolvedParams = resolveParametersOrThrow( + predefinedProofsService, + inputParams + ); + const cursedParams: PredefinedProofsModelParams = { + ...resolvedParams, + tactics: [ + "auto.", + () => { + throw Error("a curse"); + }, + ] as any[], + }; + await predefinedProofsService.generateProof( + proofGenerationContext, + cursedParams, + cursedParams.tactics.length, + ErrorsHandlingMode.LOG_EVENTS_AND_SWALLOW_ERRORS + ); + await delay(4000); + await predefinedProofsService.generateProof( + proofGenerationContext, + cursedParams, + cursedParams.tactics.length, + ErrorsHandlingMode.LOG_EVENTS_AND_SWALLOW_ERRORS + ); + // despite 2 failures with >= 4 secs interval, should be available right now + expect( + predefinedProofsService.estimateTimeToBecomeAvailable() + ).toEqual(timeZero); + }); + }).timeout(6000); +}); diff --git a/src/test/llm/llmServices/utils/chatFactory.test.ts b/src/test/llm/llmServices/utils/chatFactory.test.ts new file mode 100644 index 00000000..2aab4c5e --- /dev/null +++ b/src/test/llm/llmServices/utils/chatFactory.test.ts @@ -0,0 +1,508 @@ +import { expect } from "earl"; + +import { ConfigurationError } from "../../../../llm/llmServiceErrors"; +import { ChatHistory, ChatMessage } from "../../../../llm/llmServices/chat"; +import { ProofVersion } from "../../../../llm/llmServices/llmService"; +import { ModelParams } from "../../../../llm/llmServices/modelParams"; +import { + buildChat, + buildPreviousProofVersionsChat, + buildProofFixChat, + buildProofGenerationChat, + buildTheoremsChat, + createProofFixMessage, + validateChat, +} from "../../../../llm/llmServices/utils/chatFactory"; +import { + ItemizedChat, + UserAssistantChatItem, + chatItemToContent, + itemizedChatToHistory, +} from "../../../../llm/llmServices/utils/chatUtils"; +import { ProofGenerationContext } from "../../../../llm/proofGenerationContext"; + +import { Theorem } from "../../../../coqParser/parsedTypes"; +import { parseTheoremsFromCoqFile } from "../../../commonTestFunctions/coqFileParser"; +import { + approxCalculateTokens, + calculateTokensViaTikToken, +} from "../../llmSpecificTestUtils/calculateTokens"; +import { + gptTurboModelName, + testModelId, +} from "../../llmSpecificTestUtils/constants"; + +/* + * Note: if in the future some of the tests will act against experiments with chats, + * feel free to make them simplier. So far, they just check the current specification. 
+ */ +suite("[LLMService-s utils] Building chats test", () => { + async function readTheorems(): Promise { + const theorems = await parseTheoremsFromCoqFile([ + "build_chat_theorems.v", + ]); + expect(theorems).toHaveLength(3); + return theorems; + } + + interface TestMessages { + systemMessage: ChatMessage; + + // user messages + plusTheoremStatement: ChatMessage; + plusAssocTheoremStatement: ChatMessage; + theoremToCompleteStatement: ChatMessage; + + // assistant messages + plusTheoremProof: ChatMessage; + plusAssocTheoremProof: ChatMessage; + + proofGenerationChat: ChatHistory; + } + + interface TestTheorems { + plusTheorem: Theorem; + plusAssocTheorem: Theorem; + theoremToComplete: Theorem; + } + + async function buildTestData(): Promise<[TestTheorems, TestMessages]> { + const [plusTheorem, plusAssocTheorem, theoremToComplete] = + await readTheorems(); + expect(plusTheorem.proof).toBeTruthy(); + expect(plusAssocTheorem.proof).toBeTruthy(); + + const messages = { + systemMessage: { + role: "system", + content: "Generate proofs in Coq!", + } as ChatMessage, + plusTheoremStatement: { + role: "user", + content: plusTheorem.statement, + } as ChatMessage, + plusAssocTheoremStatement: { + role: "user", + content: plusAssocTheorem.statement, + } as ChatMessage, + theoremToCompleteStatement: { + role: "user", + content: theoremToComplete.statement, + } as ChatMessage, + + plusTheoremProof: { + role: "assistant", + content: plusTheorem.proof!.onlyText(), + } as ChatMessage, + plusAssocTheoremProof: { + role: "assistant", + content: plusAssocTheorem.proof!.onlyText(), + } as ChatMessage, + }; + + return [ + { + plusTheorem: plusTheorem, + plusAssocTheorem: plusAssocTheorem, + theoremToComplete: theoremToComplete, + }, + { + ...messages, + proofGenerationChat: [ + messages.systemMessage, + messages.plusTheoremStatement, + messages.plusTheoremProof, + messages.plusAssocTheoremStatement, + messages.plusAssocTheoremProof, + messages.theoremToCompleteStatement, + ], + }, + ]; + } + + const misspelledProof: ProofVersion = { + proof: "something???", + diagnostic: "Bad input...", + }; + const incorrectProof: ProofVersion = { + proof: "auto.", + diagnostic: "It does not finish proof...", + }; + + test("Test `validateChat`: valid chats", async () => { + const [_, messages] = await buildTestData(); + + const onlySystemMessageChat: ChatHistory = [messages.systemMessage]; + expect(validateChat(onlySystemMessageChat)).toEqual([true, "ok"]); + + const oneUserRequestChat: ChatHistory = [ + messages.systemMessage, + messages.theoremToCompleteStatement, + ]; + expect(validateChat(oneUserRequestChat)).toEqual([true, "ok"]); + + expect(validateChat(messages.proofGenerationChat)).toEqual([ + true, + "ok", + ]); + }); + + test("Test `validateChat`: invalid chats", async () => { + const [_, messages] = await buildTestData(); + + expect(validateChat([])).toEqual([ + false, + "no system message at the chat start", + ]); + expect(validateChat([messages.theoremToCompleteStatement])).toEqual([ + false, + "no system message at the chat start", + ]); + expect( + validateChat([ + messages.systemMessage, + messages.plusTheoremStatement, + messages.systemMessage, + messages.theoremToCompleteStatement, + ]) + ).toEqual([false, "several system messages found"]); + expect( + validateChat([ + messages.systemMessage, + messages.plusTheoremStatement, + messages.plusTheoremStatement, + ]) + ).toEqual([false, "two identical roles in a row"]); + expect( + validateChat([ + messages.systemMessage, + messages.plusTheoremStatement, + 
messages.plusTheoremProof, + ]) + ).toEqual([ + false, + "last message in the chat should be authored either by `user` or by `system`", + ]); + }); + + test("Test `buildChat`", async () => { + const [_, messages] = await buildTestData(); + + expect(buildChat(messages.proofGenerationChat)).toEqual( + messages.proofGenerationChat + ); + expect( + buildChat( + messages.systemMessage, + messages.plusTheoremStatement, + messages.plusTheoremProof, + messages.plusAssocTheoremStatement, + messages.plusAssocTheoremProof, + messages.theoremToCompleteStatement + ) + ).toEqual(messages.proofGenerationChat); + expect( + buildChat( + buildChat(messages.systemMessage), + [messages.plusTheoremStatement, messages.plusTheoremProof], + messages.theoremToCompleteStatement + ) + ).toEqual([ + messages.systemMessage, + messages.plusTheoremStatement, + messages.plusTheoremProof, + messages.theoremToCompleteStatement, + ]); + + expect(() => + buildChat(messages.systemMessage, messages.systemMessage) + ).toThrow(ConfigurationError, "chat is invalid"); + }); + + test("Test chat-item wrappers", async () => { + const [_, messages] = await buildTestData(); + + const plusTheorem: UserAssistantChatItem = { + userMessage: messages.plusTheoremStatement.content, + assistantMessage: messages.plusTheoremProof.content, + }; + const plusAssocTheorem: UserAssistantChatItem = { + userMessage: messages.plusAssocTheoremStatement.content, + assistantMessage: messages.plusAssocTheoremProof.content, + }; + const itemizedHistory: ItemizedChat = [plusTheorem, plusAssocTheorem]; + + expect(chatItemToContent(plusTheorem)).toEqual([ + messages.plusTheoremStatement.content, + messages.plusTheoremProof.content, + ]); + expect(itemizedChatToHistory(itemizedHistory, true)).toEqual([ + messages.plusTheoremStatement, + messages.plusTheoremProof, + messages.plusAssocTheoremStatement, + messages.plusAssocTheoremProof, + ]); + expect(itemizedChatToHistory(itemizedHistory, false)).toEqual([ + messages.plusTheoremProof, + messages.plusTheoremStatement, + messages.plusAssocTheoremProof, + messages.plusAssocTheoremStatement, + ]); + }); + + test("Test theorems chat builder", async () => { + const [theorems, messages] = await buildTestData(); + const builtTheoremsChat = buildTheoremsChat([ + theorems.plusTheorem, + theorems.plusAssocTheorem, + ]); + expect(builtTheoremsChat).toEqual([ + messages.plusTheoremStatement, + messages.plusTheoremProof, + messages.plusAssocTheoremStatement, + messages.plusAssocTheoremProof, + ]); + }); + + function proofVersionToChat(proofVersion: ProofVersion): ChatHistory { + return [ + { role: "assistant", content: proofVersion.proof }, + { + role: "user", + content: `Proof is invalid, compiler diagnostic: ${proofVersion.diagnostic}`, + }, + ]; + } + + test("Test previous-proof-versions chat builder", () => { + const builtProofVersionsChat = buildPreviousProofVersionsChat([ + misspelledProof, + incorrectProof, + ]); + expect(builtProofVersionsChat).toEqual([ + ...proofVersionToChat(misspelledProof), + ...proofVersionToChat(incorrectProof), + ]); + }); + + function buildUnlimitedTokensModel( + messages: TestMessages, + modelName?: string + ): ModelParams { + const unlimitedTokensModelParams = { + modelId: testModelId, + systemPrompt: messages.systemMessage.content, + maxTokensToGenerate: 100, + tokensLimit: 100_000, // = super many, so all context will be used + multiroundProfile: { + maxRoundsNumber: 1, + defaultProofFixChoices: 3, + proofFixPrompt: "Fix proof, please", + }, + defaultChoices: 100, // any number will work, 
it's not used in the chat build + }; + if (modelName !== undefined) { + return { + ...unlimitedTokensModelParams, + modelName: modelName, + } as ModelParams; + } else { + return unlimitedTokensModelParams; + } + } + + async function prepareToChatBuilderTest( + modelName: string | undefined + ): Promise<[TestMessages, ProofGenerationContext, ModelParams]> { + const [theorems, messages] = await buildTestData(); + + const proofGenerationContext: ProofGenerationContext = { + completionTarget: theorems.theoremToComplete.statement, + contextTheorems: [theorems.plusTheorem, theorems.plusAssocTheorem], + }; + const unlimitedTokensModelParams = buildUnlimitedTokensModel( + messages, + modelName + ); + return [messages, proofGenerationContext, unlimitedTokensModelParams]; + } + + function buildLimitedTokensParams( + chat: ChatHistory, + tokens: (text: string) => number, + unlimitedTokensModelParams: ModelParams + ): ModelParams { + const estimatedTokens = chat.reduce( + (sum, chatMessage) => sum + tokens(chatMessage.content), + 0 + ); + const limitedTokensModelParams: ModelParams = { + ...unlimitedTokensModelParams, + maxTokensToGenerate: 100, + tokensLimit: 100 + estimatedTokens, + }; + return limitedTokensModelParams; + } + + ( + [ + [ + "TikToken tokens", + gptTurboModelName, + (text: string) => { + return calculateTokensViaTikToken(text, gptTurboModelName); + }, + ], + [ + "approx tokens", + undefined, + (text: string) => { + return approxCalculateTokens(text); + }, + ], + ] as [string, string | undefined, (text: string) => number][] + ).forEach(([tokensMethodName, modelName, tokens]) => { + test(`Test proof-generation-chat builder: complete, ${tokensMethodName}`, async () => { + const [ + messages, + proofGenerationContext, + unlimitedTokensModelParams, + ] = await prepareToChatBuilderTest(modelName); + + const twoTheoremsChat = buildProofGenerationChat( + proofGenerationContext, + unlimitedTokensModelParams + ).chat; + expect(twoTheoremsChat).toEqual(messages.proofGenerationChat); + }); + + test(`Test proof-generation-chat builder: only 1/2 theorem, ${tokensMethodName}`, async () => { + const [ + messages, + proofGenerationContext, + unlimitedTokensModelParams, + ] = await prepareToChatBuilderTest(modelName); + + const expectedChat = [ + messages.systemMessage, + messages.plusTheoremStatement, + messages.plusTheoremProof, + messages.theoremToCompleteStatement, + ]; + const limitedTokensModelParams = buildLimitedTokensParams( + expectedChat, + tokens, + unlimitedTokensModelParams + ); + + const oneTheoremChat = buildProofGenerationChat( + proofGenerationContext, + limitedTokensModelParams + ).chat; + expect(oneTheoremChat).toEqual(expectedChat); + }); + + function buildProofFixChatFromContext( + messages: TestMessages, + proofFixPrompt: string, + theoremsMessages: ChatMessage[], + proofVersionsMessages: ChatMessage[] + ): ChatHistory { + return [ + messages.systemMessage, + ...theoremsMessages, + messages.theoremToCompleteStatement, + ...proofVersionsMessages, + proofVersionToChat(incorrectProof)[0], + { + role: "user", + content: createProofFixMessage( + incorrectProof.diagnostic!, + proofFixPrompt + ), + }, + ]; + } + + test(`Test proof-fix-chat builder: complete, ${tokensMethodName}`, async () => { + const [ + messages, + proofGenerationContext, + unlimitedTokensModelParams, + ] = await prepareToChatBuilderTest(modelName); + + const expectedChat = buildProofFixChatFromContext( + messages, + unlimitedTokensModelParams.multiroundProfile.proofFixPrompt, + [ + messages.plusTheoremStatement, + 
messages.plusTheoremProof, + messages.plusAssocTheoremStatement, + messages.plusAssocTheoremProof, + ], + proofVersionToChat(misspelledProof) + ); + + const completeProofFixChat = buildProofFixChat( + proofGenerationContext, + [misspelledProof, incorrectProof], + unlimitedTokensModelParams + ).chat; + expect(completeProofFixChat).toEqual(expectedChat); + }); + + test(`Test proof-fix-chat builder: all diagnostics & only 1/2 theorem, ${tokensMethodName}`, async () => { + const [ + messages, + proofGenerationContext, + unlimitedTokensModelParams, + ] = await prepareToChatBuilderTest(modelName); + + const expectedChat = buildProofFixChatFromContext( + messages, + unlimitedTokensModelParams.multiroundProfile.proofFixPrompt, + [messages.plusTheoremStatement, messages.plusTheoremProof], + proofVersionToChat(misspelledProof) + ); + const limitedTokensModelParams = buildLimitedTokensParams( + expectedChat, + tokens, + unlimitedTokensModelParams + ); + + const allDiagnosticsOneTheoremChat = buildProofFixChat( + proofGenerationContext, + [misspelledProof, incorrectProof], + limitedTokensModelParams + ).chat; + expect(allDiagnosticsOneTheoremChat).toEqual(expectedChat); + }); + + test(`Test proof-fix-chat builder: no extra diagnostics & theorems, ${tokensMethodName}`, async () => { + const [ + messages, + proofGenerationContext, + unlimitedTokensModelParams, + ] = await prepareToChatBuilderTest(modelName); + + const expectedChat = buildProofFixChatFromContext( + messages, + unlimitedTokensModelParams.multiroundProfile.proofFixPrompt, + [], + [] + ); + const limitedTokensModelParams = buildLimitedTokensParams( + expectedChat, + tokens, + unlimitedTokensModelParams + ); + + const noExtraContextChat = buildProofFixChat( + proofGenerationContext, + [misspelledProof, incorrectProof], + limitedTokensModelParams + ).chat; + expect(noExtraContextChat).toEqual(expectedChat); + }); + }); +}); diff --git a/src/test/llm/llmServices/utils/chatTokensFitter.test.ts b/src/test/llm/llmServices/utils/chatTokensFitter.test.ts new file mode 100644 index 00000000..433f606c --- /dev/null +++ b/src/test/llm/llmServices/utils/chatTokensFitter.test.ts @@ -0,0 +1,204 @@ +import { expect } from "earl"; + +import { ConfigurationError } from "../../../../llm/llmServiceErrors"; +import { theoremToChatItem } from "../../../../llm/llmServices/utils/chatFactory"; +import { ChatTokensFitter } from "../../../../llm/llmServices/utils/chatTokensFitter"; +import { chatItemToContent } from "../../../../llm/llmServices/utils/chatUtils"; + +import { Theorem } from "../../../../coqParser/parsedTypes"; +import { parseTheoremsFromCoqFile } from "../../../commonTestFunctions/coqFileParser"; +import { + approxCalculateTokens, + calculateTokensViaTikToken, +} from "../../llmSpecificTestUtils/calculateTokens"; +import { gptTurboModelName } from "../../llmSpecificTestUtils/constants"; + +suite("[LLMService-s utils] ChatTokensFitter test", () => { + async function readTwoTheorems(): Promise { + const twoTheorems = await parseTheoremsFromCoqFile([ + "small_document.v", + ]); + expect(twoTheorems).toHaveLength(2); + return twoTheorems; + } + + interface TestParams { + modelName?: string; + maxTokensToGenerate: number; + tokensLimit: number; + systemMessage: string; + completionTarget: string; + contextTheorems: Theorem[]; + } + + function countTheoremsPickedFromContext(testParams: TestParams): number { + const fitter = new ChatTokensFitter( + testParams.maxTokensToGenerate, + testParams.tokensLimit, + testParams.modelName + ); + try { + 
fitter.fitRequiredMessage({ + role: "system", + content: testParams.systemMessage, + }); + fitter.fitRequiredMessage({ + role: "user", + content: testParams.completionTarget, + }); + const fittedTheorems = fitter.fitOptionalObjects( + testParams.contextTheorems, + (theorem) => chatItemToContent(theoremToChatItem(theorem)) + ); + return fittedTheorems.length; + } finally { + fitter.dispose(); + } + } + + test("No theorems", () => { + const fittedTheoremsNumber = countTheoremsPickedFromContext({ + maxTokensToGenerate: 100, + tokensLimit: 1000, + systemMessage: "You are a friendly assistant", + completionTarget: "doesn't matter", + contextTheorems: [], + }); + expect(fittedTheoremsNumber).toEqual(0); + }); + + test("Required messages do not fit", async () => { + const twoTheorems = await readTwoTheorems(); + expect(() => { + countTheoremsPickedFromContext({ + maxTokensToGenerate: 1000, + tokensLimit: 1000, + systemMessage: "You are a friendly assistant", + completionTarget: "doesn't matter", + contextTheorems: twoTheorems, + }); + }).toThrow(ConfigurationError, "required content cannot be fitted"); + }); + + test("Two theorems, no overflow", async () => { + const twoTheorems = await readTwoTheorems(); + const fittedTheoremsNumber = countTheoremsPickedFromContext({ + maxTokensToGenerate: 1000, + tokensLimit: 10000, + systemMessage: "You are a friendly assistant", + completionTarget: "doesn't matter", + contextTheorems: twoTheorems, + }); + expect(fittedTheoremsNumber).toEqual(2); + }); + + test("Two theorems, overflow after first", async () => { + const twoTheorems = await readTwoTheorems(); + const statementTokens = approxCalculateTokens(twoTheorems[0].statement); + const theoremProof = twoTheorems[0].proof?.onlyText() ?? ""; + const proofTokens = approxCalculateTokens(theoremProof); + const fittedTheoremsNumber = countTheoremsPickedFromContext({ + maxTokensToGenerate: 1000, + tokensLimit: 1000 + statementTokens + proofTokens, + systemMessage: "", + completionTarget: "", + contextTheorems: twoTheorems, + }); + expect(fittedTheoremsNumber).toEqual(1); + }); + + test("Two theorems, overflow almost before first", async () => { + const twoTheorems = await readTwoTheorems(); + const statementTokens = approxCalculateTokens(twoTheorems[0].statement); + const theoremProof = twoTheorems[0].proof?.onlyText() ?? ""; + const proofTokens = approxCalculateTokens(theoremProof); + const fittedTheoremsNumber = countTheoremsPickedFromContext({ + maxTokensToGenerate: 1000, + tokensLimit: 1000 + statementTokens + proofTokens - 1, + systemMessage: "", + completionTarget: "", + contextTheorems: twoTheorems, + }); + expect(fittedTheoremsNumber).toEqual(0); + }); + + test("Two theorems, overflow after first with tiktoken", async () => { + const twoTheorems = await readTwoTheorems(); + const statementTokens = calculateTokensViaTikToken( + twoTheorems[0].statement, + gptTurboModelName + ); + const theoremProof = twoTheorems[0].proof?.onlyText() ?? 
""; + const proofTokens = calculateTokensViaTikToken( + theoremProof, + gptTurboModelName + ); + const fittedTheoremsNumber = countTheoremsPickedFromContext({ + modelName: gptTurboModelName, + maxTokensToGenerate: 1000, + tokensLimit: 1000 + statementTokens + proofTokens, + systemMessage: "", + completionTarget: "", + contextTheorems: twoTheorems, + }); + expect(fittedTheoremsNumber).toEqual(1); + }); + + const shortText = "This is a test text"; + const longText = + "Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum."; + + test("Test if two tokenizers are similar: short text", () => { + const tiktokenTokens = calculateTokensViaTikToken( + shortText, + gptTurboModelName + ); + const approxTokens = approxCalculateTokens(shortText); + expect(tiktokenTokens).toBeCloseTo(approxTokens, 2); + }); + + test("Test if two tokenizers are similar: long text", () => { + const tiktokenTokens = calculateTokensViaTikToken( + longText, + gptTurboModelName + ); + const approxTokens = approxCalculateTokens(longText); + expect(tiktokenTokens).toBeCloseTo(approxTokens, 20); + }); + + function estimateTokensWithFitter( + modelName: string, + text: string, + maxTokensToGenerate: number + ): number { + const fitter = new ChatTokensFitter( + maxTokensToGenerate, + 1000000, + modelName + ); + try { + fitter.fitRequiredMessage({ + role: "user", // doesn't matter + content: text, + }); + return fitter.estimateTokens().maxTokensInTotal; + } finally { + fitter.dispose(); + } + } + + test("Test `estimateTokens`", () => { + const tiktokenTokens = calculateTokensViaTikToken( + longText, + gptTurboModelName + ); + const maxTokensToGenerate = 100; + expect( + estimateTokensWithFitter( + gptTurboModelName, + longText, + maxTokensToGenerate + ) + ).toEqual(tiktokenTokens + maxTokensToGenerate); + }); +}); diff --git a/src/test/llm/llmServices/utils/generationsLogger/generationsLogger.test.ts b/src/test/llm/llmServices/utils/generationsLogger/generationsLogger.test.ts new file mode 100644 index 00000000..cd3bc10b --- /dev/null +++ b/src/test/llm/llmServices/utils/generationsLogger/generationsLogger.test.ts @@ -0,0 +1,437 @@ +import { expect } from "earl"; +import * as tmp from "tmp"; + +import { + ConfigurationError, + GenerationFailedError, + LLMServiceError, +} from "../../../../../llm/llmServiceErrors"; +import { + ChatHistory, + EstimatedTokens, +} from "../../../../../llm/llmServices/chat"; +import { + LLMServiceRequest, + LLMServiceRequestFailed, + LLMServiceRequestSucceeded, +} from "../../../../../llm/llmServices/llmService"; +import { + ModelParams, + OpenAiModelParams, + PredefinedProofsModelParams, +} from "../../../../../llm/llmServices/modelParams"; +import { + GenerationsLogger, + GenerationsLoggerSettings, +} from "../../../../../llm/llmServices/utils/generationsLogger/generationsLogger"; +import { + DebugLoggerRecord, + LoggerRecord, +} from "../../../../../llm/llmServices/utils/generationsLogger/loggerRecord"; +import { SyncFile } from "../../../../../llm/llmServices/utils/generationsLogger/syncFile"; +import { nowTimestampMillis } from "../../../../../llm/llmServices/utils/time"; + +import { + gptTurboModelName, + 
testModelId, +} from "../../../llmSpecificTestUtils/constants"; +import { DummyLLMService } from "../../../llmSpecificTestUtils/dummyLLMService"; + +suite("[LLMService-s utils] GenerationsLogger test", () => { + const predefinedProofs = [ + "intros.", + "reflexivity.", + "auto.", + "auto.\nintro.", + ]; + const mockParamsBase: ModelParams = { + modelId: testModelId, + systemPrompt: "hi system", + maxTokensToGenerate: 10000, + tokensLimit: 1000000, + multiroundProfile: { + maxRoundsNumber: 1, + defaultProofFixChoices: 1, + proofFixPrompt: "fix it", + }, + defaultChoices: 1, + }; + const mockParams: PredefinedProofsModelParams = { + ...mockParamsBase, + tactics: predefinedProofs, + }; + const mockOpenAiParams: OpenAiModelParams = { + ...mockParamsBase, + modelName: gptTurboModelName, + apiKey: "very sensitive api key", + temperature: 1, + }; + // different from `defaultChoices`, it's a real-life case + const mockChoices = 2; + const mockEstimatedTokens: EstimatedTokens = { + messagesTokens: 100, + maxTokensToGenerate: 80, + maxTokensInTotal: 180, + }; + const mockChat: ChatHistory = [ + { + role: "system", + content: "hello from system!", + }, + { + role: "user", + content: "hello from user!\nI love multiline!", + }, + { + role: "assistant", + content: "hello from assistant!", + }, + ]; + const mockProofs = ["auto.\nintro.", "auto."]; + + async function withGenerationsLogger( + settings: GenerationsLoggerSettings, + block: (generationsLogger: GenerationsLogger) => Promise + ): Promise { + const generationsLogger = new GenerationsLogger( + tmp.fileSync().name, + settings + ); + try { + await block(generationsLogger); + } finally { + generationsLogger.dispose(); + } + } + + async function withTestGenerationsLogger( + loggerDebugMode: boolean, + block: (generationsLogger: GenerationsLogger) => Promise + ): Promise { + return withGenerationsLogger( + { + debug: loggerDebugMode, + paramsPropertiesToCensor: {}, + cleanLogsOnStart: true, + }, + block + ); + } + + function buildMockRequest( + generationsLogger: GenerationsLogger, + params: ModelParams = mockParams + ) { + const llmService = new DummyLLMService(generationsLogger); + const mockRequest: LLMServiceRequest = { + llmService: llmService, + params: params, + choices: mockChoices, + analyzedChat: { + chat: mockChat, + estimatedTokens: mockEstimatedTokens, + }, + }; + return mockRequest; + } + + function succeeded( + mockRequest: LLMServiceRequest + ): LLMServiceRequestSucceeded { + return { + ...mockRequest, + generatedRawProofs: mockProofs, + }; + } + + function failed( + mockRequest: LLMServiceRequest, + error: Error + ): LLMServiceRequestFailed { + return { + ...mockRequest, + llmServiceError: new GenerationFailedError(error), + }; + } + + async function writeLogs( + generationsLogger: GenerationsLogger + ): Promise { + const mockRequest = buildMockRequest(generationsLogger); + generationsLogger.logGenerationSucceeded(succeeded(mockRequest)); + generationsLogger.logGenerationFailed( + failed(mockRequest, Error("dns error")) + ); + generationsLogger.logGenerationSucceeded(succeeded(mockRequest)); + generationsLogger.logGenerationFailed( + failed(mockRequest, Error("network failed")) + ); + generationsLogger.logGenerationFailed( + failed( + mockRequest, + Error("tokens limit exceeded\nunfortunately, many times") + ) + ); + } + const logsSinceLastSuccessInclusiveCnt = 3; + const logsWrittenInTotalCnt = 5; + + function readAndCheckLogs( + expectedRecordsLength: number, + generationsLogger: GenerationsLogger + ) { + const records = 
generationsLogger.readLogs(); + expect(records).toHaveLength(expectedRecordsLength); + } + + [false, true].forEach((loggerDebugMode) => { + const testNamePostfix = loggerDebugMode + ? "[debug true]" + : "[debug false]"; + test(`Simple write-read ${testNamePostfix}`, async () => { + await withTestGenerationsLogger( + loggerDebugMode, + async (generationsLogger) => { + await writeLogs(generationsLogger); + readAndCheckLogs( + loggerDebugMode ? 5 : 3, + generationsLogger + ); + } + ); + }); + + test(`Test \`readLogsSinceLastSuccess\` ${testNamePostfix}`, async () => { + await withTestGenerationsLogger( + loggerDebugMode, + async (generationsLogger) => { + const noRecords = + generationsLogger.readLogsSinceLastSuccess(); + expect(noRecords).toHaveLength(0); + + await writeLogs(generationsLogger); + const records = + generationsLogger.readLogsSinceLastSuccess(); + expect(records).toHaveLength( + logsSinceLastSuccessInclusiveCnt - 1 + ); + } + ); + + test(`Test read no records ${testNamePostfix}`, async () => { + await withTestGenerationsLogger( + loggerDebugMode, + async (generationsLogger) => { + expect(generationsLogger.readLogs()).toHaveLength(0); + expect( + generationsLogger.readLogsSinceLastSuccess() + ).toHaveLength(0); + generationsLogger.logGenerationSucceeded( + succeeded(buildMockRequest(generationsLogger)) + ); + expect( + generationsLogger.readLogsSinceLastSuccess() + ).toHaveLength(0); + } + ); + }); + }); + + test(`Pseudo-concurrent write-read ${testNamePostfix}`, async () => { + await withTestGenerationsLogger( + loggerDebugMode, + async (generationsLogger) => { + const logsWriters = []; + const logsWritersN = 50; + for (let i = 0; i < logsWritersN; i++) { + logsWriters.push(writeLogs(generationsLogger)); + } + Promise.all(logsWriters); + readAndCheckLogs( + loggerDebugMode + ? 
logsWrittenInTotalCnt * logsWritersN + : logsSinceLastSuccessInclusiveCnt, + generationsLogger + ); + } + ); + }); + }); + + test("Throws on wrong error types", async () => { + await withTestGenerationsLogger(true, async (generationsLogger) => { + const mockRequest = buildMockRequest(generationsLogger); + + expect(() => + generationsLogger.logGenerationFailed( + failed( + mockRequest, + new ConfigurationError("invalid params") + ) + ) + ).toThrow(Error); + + class DummyLLMServiceError extends LLMServiceError {} + expect(() => + generationsLogger.logGenerationFailed( + failed(mockRequest, new DummyLLMServiceError()) + ) + ).toThrow(Error); + + expect(() => + generationsLogger.logGenerationFailed( + failed( + mockRequest, + new GenerationFailedError(Error("double-wrapped error")) + ) + ) + ).toThrow(Error); + }); + }); + + test("Test censor params properties", async () => { + const censorInt = -1; + await withGenerationsLogger( + { + debug: true, + paramsPropertiesToCensor: { + apiKey: GenerationsLogger.censorString, + tokensLimit: censorInt, + }, + cleanLogsOnStart: true, + }, + async (generationsLogger) => { + const mockRequest = buildMockRequest( + generationsLogger, + mockOpenAiParams + ); + generationsLogger.logGenerationSucceeded( + succeeded(mockRequest) + ); + + // test censorship via direct file read + const fileContent = new SyncFile( + generationsLogger.filePath + ).read(); + expect( + fileContent.includes(mockOpenAiParams.apiKey) + ).toBeFalsy(); + expect( + fileContent.includes(`${mockOpenAiParams.tokensLimit}`) + ).toBeFalsy(); + + // test censorship via readLogs + const records = generationsLogger.readLogs(); + expect(records).toHaveLength(1); + const record = records[0] as DebugLoggerRecord; + expect(record).not.toBeNullish(); + + expect(record.params.tokensLimit).toEqual(censorInt); + expect((record.params as OpenAiModelParams)?.apiKey).toEqual( + GenerationsLogger.censorString + ); + } + ); + }); + + test("Test record serialization-deserealization: `SUCCESS`", async () => { + const loggerRecord = new LoggerRecord( + nowTimestampMillis(), + mockParams.modelId, + "SUCCESS", + mockChoices, + mockEstimatedTokens + ); + expect( + LoggerRecord.deserealizeFromString(loggerRecord.serializeToString()) + ).toEqual([loggerRecord, ""]); + + const debugLoggerRecord = new DebugLoggerRecord( + loggerRecord, + mockChat, + mockParams, + mockProofs + ); + expect( + DebugLoggerRecord.deserealizeFromString( + debugLoggerRecord.serializeToString() + ) + ).toEqual([debugLoggerRecord, ""]); + }); + + test("Test record serialization-deserealization: `FAILED`", async () => { + const error = Error("bad things happen"); + const loggerRecord = new LoggerRecord( + nowTimestampMillis(), + mockParams.modelId, + "FAILURE", + mockChoices, + mockEstimatedTokens, + { + typeName: error.name, + message: error.message, + } + ); + expect( + LoggerRecord.deserealizeFromString(loggerRecord.serializeToString()) + ).toEqual([loggerRecord, ""]); + + const debugLoggerRecord = new DebugLoggerRecord( + loggerRecord, + mockChat, + mockParams + ); + expect( + DebugLoggerRecord.deserealizeFromString( + debugLoggerRecord.serializeToString() + ) + ).toEqual([debugLoggerRecord, ""]); + }); + + test("Test record serialization-deserealization: undefined-s", async () => { + const loggerRecord = new LoggerRecord( + nowTimestampMillis(), + mockParams.modelId, + "SUCCESS", + mockChoices, + undefined, + undefined + ); + expect( + LoggerRecord.deserealizeFromString(loggerRecord.serializeToString()) + ).toEqual([loggerRecord, ""]); + 
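+        // `DebugLoggerRecord` additionally stores the chat, the params, and the generated proofs;
+        // it should survive the serialize-deserialize round trip as well.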
+ const debugLoggerRecord = new DebugLoggerRecord( + loggerRecord, + undefined, + mockParams, + undefined + ); + expect( + DebugLoggerRecord.deserealizeFromString( + debugLoggerRecord.serializeToString() + ) + ).toEqual([debugLoggerRecord, ""]); + }); + + test("Test record serialization-deserealization: empty lists", async () => { + const debugLoggerRecord = new DebugLoggerRecord( + new LoggerRecord( + nowTimestampMillis(), + mockParams.modelId, + "SUCCESS", + mockChoices, + undefined, + undefined + ), + [], // empty chat list + mockParams, + [] // empty generated proofs list + ); + expect( + DebugLoggerRecord.deserealizeFromString( + debugLoggerRecord.serializeToString() + ) + ).toEqual([debugLoggerRecord, ""]); + }); +}); diff --git a/src/test/llm/llmServices/utils/generationsLogger/syncFile.test.ts b/src/test/llm/llmServices/utils/generationsLogger/syncFile.test.ts new file mode 100644 index 00000000..98143793 --- /dev/null +++ b/src/test/llm/llmServices/utils/generationsLogger/syncFile.test.ts @@ -0,0 +1,84 @@ +import { expect } from "earl"; +import * as tmp from "tmp"; + +import { SyncFile } from "../../../../../llm/llmServices/utils/generationsLogger/syncFile"; + +suite("[LLMService-s utils] SyncFile test", () => { + const filePath = tmp.fileSync().name; + const file = new SyncFile(filePath); + + test("Basic operations", () => { + if (file.exists()) { + file.delete(); + } + expect(file.exists()).toBeFalsy(); + + file.createReset(); + expect(file.exists()).toBeTruthy(); + + file.append("- hello?\n"); + file.append("- coq!"); + expect(file.read()).toEqual("- hello?\n- coq!"); + + file.write("only coq"); + expect(file.read()).toEqual("only coq"); + + file.createReset(); + expect(file.read()).toEqual(""); + + file.delete(); + expect(file.exists()).toBeFalsy(); + }); + + async function appendManyLongLines( + workerId: number, + linesN: number + ): Promise { + const lines = []; + const longSuffix = "coq\t".repeat(100); + for (let i = 0; i < linesN; i++) { + lines.push(`${workerId}: ${longSuffix}\n`); + } + file.append(lines.join("")); + return "done"; + } + + // Tests that `SyncFile` actually provides synchronization for async operations. 
+ test("Pseudo-concurrent operations", async () => { + file.createReset(); + const workers = []; + const workersN = 100; + const linesN = 100; + for (let i = 0; i < workersN; i++) { + workers.push(appendManyLongLines(i, linesN)); + } + const workersDone = await Promise.all(workers); + expect(workersDone).toEqual(new Array(workersN).fill("done")); + + const lines = file.read().split("\n").slice(0, -1); + expect(lines).toHaveLength(workersN * linesN); + + const workersLinesCnt: { [key: number]: number } = {}; + let lastLineWorkerId = -1; + for (const line of lines) { + const rawParts = line.split(":"); + expect(rawParts.length).toBeGreaterThan(1); + const workerId = parseInt(rawParts[0]); + + if (workerId in workersLinesCnt) { + expect(lastLineWorkerId).toEqual(workerId); + workersLinesCnt[workerId] += 1; + } else { + if (lastLineWorkerId !== -1) { + expect(workersLinesCnt[lastLineWorkerId]).toEqual(linesN); + } + workersLinesCnt[workerId] = 1; + } + lastLineWorkerId = workerId; + } + + for (let i = 0; i < workersN; i++) { + expect(workersLinesCnt[i]).toEqual(linesN); + } + }); +}); diff --git a/src/test/llm/llmServices/utils/paramsResolver.test.ts b/src/test/llm/llmServices/utils/paramsResolver.test.ts new file mode 100644 index 00000000..f42ccf1a --- /dev/null +++ b/src/test/llm/llmServices/utils/paramsResolver.test.ts @@ -0,0 +1,622 @@ +import { JSONSchemaType } from "ajv"; +import { expect } from "earl"; + +import { ModelParams } from "../../../../llm/llmServices/modelParams"; +import { ParamsResolver } from "../../../../llm/llmServices/utils/paramsResolvers/abstractResolvers"; +import { ValidationRules } from "../../../../llm/llmServices/utils/paramsResolvers/builders"; +import { + NoOptionalProperties, + ParamsResolverImpl, + ValidParamsResolverImpl, +} from "../../../../llm/llmServices/utils/paramsResolvers/paramsResolverImpl"; +import { UserModelParams } from "../../../../llm/userModelParams"; + +// import { UserModelParams } from "../../../../llm/userModelParams"; +import { expectParamResolutionResult } from "../../llmSpecificTestUtils/expectResolutionResult"; + +suite("[LLMService-s utils] Test `ParamsResolver`", () => { + const positiveInputValue = 5; + const negativeInputValue = -5; + const positiveDefaultValue = 6; + + function testSuccessfulSingleNumberResolution( + testName: string, + paramsResolver: ParamsResolver, + inputParams: InputType, + expectedResolvedParams: ResolveToType, + expectedParamNameInLogs: string + ) { + test(testName, () => { + const resolutionResult = paramsResolver.resolve(inputParams); + expect(resolutionResult.resolved).toEqual(expectedResolvedParams); + expect(resolutionResult.resolutionLogs).toHaveLength(1); + expectParamResolutionResult( + resolutionResult.resolutionLogs[0], + { + resultValue: positiveInputValue, + inputReadCorrectly: { + wasPerformed: true, + withValue: positiveInputValue, + }, + }, + expectedParamNameInLogs + ); + }); + } + + function testFailedSingleNumberResolution( + testName: string, + paramsResolver: ParamsResolver, + inputParams: InputType, + expectedParamNameInLogs: string + ) { + test(testName, () => { + const resolutionResult = paramsResolver.resolve(inputParams); + expect(resolutionResult.resolved).toBeNullish(); + expect(resolutionResult.resolutionLogs).toHaveLength(1); + expectParamResolutionResult( + resolutionResult.resolutionLogs[0], + { + isInvalidCause: "should be positive, but has value", + inputReadCorrectly: { + wasPerformed: true, + withValue: negativeInputValue, + }, + }, + expectedParamNameInLogs + ); + }); + 
} + + interface InputNumberParam { + input?: number; + } + + interface ResolvedNumberParam { + output: number; + } + + const resolvedNumberParamSchema: JSONSchemaType = { + type: "object", + properties: { + output: { type: "number" }, + }, + required: ["output"], + additionalProperties: false, + }; + + class NumberParamResolver extends ParamsResolverImpl< + InputNumberParam, + ResolvedNumberParam + > { + constructor() { + super(resolvedNumberParamSchema, "ResolvedNumberParam"); + } + + readonly output = this.resolveParam("input") + .default(() => positiveDefaultValue) + .validate(ValidationRules.bePositiveNumber); + } + + ( + [ + ["1-to-1 params", { input: positiveInputValue }], + [ + "with-extra-to-1 params", + { input: positiveInputValue, extra: true } as InputNumberParam, + ], + ] as [string, InputNumberParam][] + ).forEach(([testCase, inputParams]) => { + testSuccessfulSingleNumberResolution( + `\`ParamsResolver\` with single value: ${testCase}, success`, + new NumberParamResolver(), + inputParams, + { + output: positiveInputValue, + }, + "input" + ); + }); + + testFailedSingleNumberResolution( + "`ParamsResolver` with single value: 1-to-1 params, failure", + new NumberParamResolver(), + { + input: negativeInputValue, + }, + "input" + ); + + interface InputMixedParams { + input?: number; + complex?: ResolvedNumberParam; + extra: string; + } + + interface ResolvedMixedParams { + output: number; + complex: ResolvedNumberParam; + inserted: boolean; + } + + class MixedParamsResolver extends ParamsResolverImpl< + InputMixedParams, + ResolvedMixedParams + > { + constructor() { + super( + { + type: "object", + properties: { + output: { type: "number" }, + complex: { + type: "object", + oneOf: [resolvedNumberParamSchema], + }, + inserted: { type: "boolean" }, + }, + required: ["output", "complex", "inserted"], + }, + "ResolvedMixedParams" + ); + } + + readonly output = this.resolveParam("input") + .requiredToBeConfigured() + .validate(ValidationRules.bePositiveNumber); + + readonly complex = this.resolveParam("complex") + .default(() => { + return { output: positiveDefaultValue }; + }) + .validate([(value) => value.output! 
>= 0, "be non-negative"]); + + readonly inserted = this.insertParam( + () => true + ).noValidationNeeded(); + } + + test(`\`ParamsResolver\` with mixed value: success`, () => { + const paramsResolver = new MixedParamsResolver(); + const resolutionResult = paramsResolver.resolve({ + input: positiveInputValue, + complex: undefined, + extra: "will not be resolved", + }); + + expect(resolutionResult.resolved).not.toBeNullish(); + expect(resolutionResult.resolved).toEqual({ + output: positiveInputValue, + complex: { + output: positiveDefaultValue, + }, + inserted: true, + }); + + expect(resolutionResult.resolutionLogs).toHaveLength(3); + const [outputLog, complexLog, insertedLog] = + resolutionResult.resolutionLogs; + + expectParamResolutionResult( + outputLog, + { + resultValue: positiveInputValue, + inputReadCorrectly: { + wasPerformed: true, + withValue: positiveInputValue, + }, + }, + "input" + ); + expectParamResolutionResult( + complexLog, + { + resultValue: { + output: positiveDefaultValue, + }, + resolvedWithDefault: { + wasPerformed: true, + withValue: { + output: positiveDefaultValue, + }, + }, + }, + "complex" + ); + expectParamResolutionResult( + insertedLog, + { + resultValue: true, + overriden: { + wasPerformed: true, + withValue: true, + }, + }, + undefined + ); + }); + + test("`ParamsResolver` with mixed values: failure", () => { + const paramsResolver = new MixedParamsResolver(); + const resolutionResult = paramsResolver.resolve({ + input: undefined, + complex: { + output: negativeInputValue, + }, + extra: "will not be resolved", + }); + + expect(resolutionResult.resolved).toBeNullish(); + + expect(resolutionResult.resolutionLogs).toHaveLength(3); + const [outputLog, complexLog, insertedLog] = + resolutionResult.resolutionLogs; + + expectParamResolutionResult( + outputLog, + { + isInvalidCause: + "neither a user value nor a default one is specified", + }, + "input" + ); + expectParamResolutionResult( + complexLog, + { + isInvalidCause: "should be non-negative, but has value", + inputReadCorrectly: { + wasPerformed: true, + withValue: { + output: negativeInputValue, + }, + }, + }, + "complex" + ); + expectParamResolutionResult( + insertedLog, + { + resultValue: true, + overriden: { + wasPerformed: true, + withValue: true, + }, + }, + undefined + ); + }); + + class ParamsResolverWithNonResolverProperty extends NumberParamResolver { + readonly nonResolverProperty: string = "i'm not a resolver!"; + } + + class ParamsResolverWithUnfinishedBuilder extends NumberParamResolver { + readonly unfinishedBuilder = this.resolveParam( + "input" + ).override(() => 6); + } + + class ParamsResolverWithNonCertifiedResolverProperty extends NumberParamResolver { + readonly nonCertifiedResolver = { + resolve() {}, + }; + } + + test("`ParamsResolver` configured incorrectly: property of non-`ParamsResolver` type", () => { + expect(() => + new ParamsResolverWithNonResolverProperty().resolve({ + input: positiveInputValue, + }) + ).toThrow(Error, "configured incorrectly"); + expect(() => + new ParamsResolverWithUnfinishedBuilder().resolve({ + input: positiveInputValue, + }) + ).toThrow(Error, "configured incorrectly"); + expect(() => + new ParamsResolverWithNonCertifiedResolverProperty().resolve({ + input: positiveInputValue, + }) + ).toThrow(Error, "configured incorrectly"); + }); + + class EmptyParamsResolver extends ParamsResolverImpl< + InputNumberParam, + ResolvedNumberParam + > { + constructor() { + super(resolvedNumberParamSchema, "ResolvedNumberParam"); + } + } + + test("`ParamsResolver` 
configured incorrectly: not enough parameters", () => { + const paramsResolver = new EmptyParamsResolver(); + expect(() => + paramsResolver.resolve({ + input: positiveInputValue, + }) + ).toThrow(Error, "configured incorrectly"); + }); + + class WrongTypeParamsResolver extends ParamsResolverImpl< + InputNumberParam, + ResolvedNumberParam + > { + constructor() { + super(resolvedNumberParamSchema, "ResolvedNumberParam"); + } + + output = this.resolveParam("input") + .default(() => "string type is the wrong one" as any) + .noValidationNeeded(); + } + + test("`ParamsResolver` configured incorrectly: parameter of wrong type", () => { + const paramsResolver = new WrongTypeParamsResolver(); + expect(() => + paramsResolver.resolve({ + input: undefined, + }) + ).toThrow(Error, "configured incorrectly"); + }); + + class HiddenWrongTypeParamsResolver extends ParamsResolverImpl< + InputNumberParam, + ResolvedNumberParam + > { + constructor() { + super(resolvedNumberParamSchema, "ResolvedNumberParam"); + } + + output = this.resolveParam("input") + .default(() => "string type is the wrong one" as any) + .noValidationNeeded(); + } + + test("`ParamsResolver` configured incorrectly: parameter of hidden wrong type", () => { + const paramsResolver = new HiddenWrongTypeParamsResolver(); + expect(() => + paramsResolver.resolve({ + input: undefined, + }) + ).toThrow(Error, "configured incorrectly"); + }); + + type NumberParamResolverIsValid = + NumberParamResolver extends ValidParamsResolverImpl< + InputNumberParam, + ResolvedNumberParam + > + ? "true" + : "false"; + + type EmptyParamsResolverIsValid = + EmptyParamsResolver extends ValidParamsResolverImpl< + InputNumberParam, + ResolvedNumberParam + > + ? "true" + : "false"; + + type WrongTypeParamsResolverIsValid = + WrongTypeParamsResolver extends ValidParamsResolverImpl< + InputNumberParam, + ResolvedNumberParam + > + ? 
"true" + : "false"; + + test("Test `ValidParamsResolverImpl` constraint", () => { + // @ts-ignore variable is needed only for a type check + const _shouldBeTrue: NumberParamResolverIsValid = "true"; + // @ts-ignore variable is needed only for a type check + const _shouldBeFalse1: EmptyParamsResolverIsValid = "false"; + // @ts-ignore variable is needed only for a type check + const _shouldBeFalse2: WrongTypeParamsResolverIsValid = "false"; + }); + + interface InputParamsWithNestedParam { + nestedParam?: InputNumberParam; + } + + interface ResolvedParamsWithNestedParam { + resolvedNestedParam: ResolvedNumberParam; + } + + const resolvedParamsWithNestedParamSchema: JSONSchemaType = + { + type: "object", + properties: { + resolvedNestedParam: { + type: "object", + oneOf: [resolvedNumberParamSchema], + }, + }, + required: ["resolvedNestedParam"], + }; + + class ParamsResolverWithNestedResolver extends ParamsResolverImpl< + InputParamsWithNestedParam, + ResolvedParamsWithNestedParam + > { + constructor() { + super( + resolvedParamsWithNestedParamSchema, + "ResolvedParamsWithNestedParam" + ); + } + + readonly resolvedNestedParam = this.resolveNestedParams( + "nestedParam", + new NumberParamResolver() + ); + } + + testSuccessfulSingleNumberResolution( + "Test `resolveNestedParams`: basic, success", + new ParamsResolverWithNestedResolver(), + { + nestedParam: { + input: positiveInputValue, + }, + }, + { + resolvedNestedParam: { + output: positiveInputValue, + }, + }, + "nestedParam.input" + ); + + testFailedSingleNumberResolution( + "Test `resolveNestedParams`: basic, failure", + new ParamsResolverWithNestedResolver(), + { + nestedParam: { + input: negativeInputValue, + }, + }, + "nestedParam.input" + ); + + interface InputParamsWithDeepNestedParam { + deepNestedParam?: InputParamsWithNestedParam; + } + + interface ResolvedParamsWithDeepNestedParam { + resolvedDeepNestedParam: ResolvedParamsWithNestedParam; + } + + class ParamsResolverWithDeepNestedResolver extends ParamsResolverImpl< + InputParamsWithDeepNestedParam, + ResolvedParamsWithDeepNestedParam + > { + constructor() { + super( + { + type: "object", + properties: { + resolvedDeepNestedParam: { + type: "object", + oneOf: [resolvedParamsWithNestedParamSchema], + }, + }, + required: ["resolvedDeepNestedParam"], + }, + "ResolvedParamsWithDeepNestedParam" + ); + } + + readonly resolvedDeepNestedParam = this.resolveNestedParams( + "deepNestedParam", + new ParamsResolverWithNestedResolver() + ); + } + + testSuccessfulSingleNumberResolution( + "Test `resolveNestedParams`: deep nesting & all defined, success", + new ParamsResolverWithDeepNestedResolver(), + { + deepNestedParam: { + nestedParam: { + input: positiveInputValue, + }, + }, + }, + { + resolvedDeepNestedParam: { + resolvedNestedParam: { + output: positiveInputValue, + }, + }, + }, + "deepNestedParam.nestedParam.input" + ); + + test("Test `resolveNestedParams`: deep nesting & undefined in the middle, success", () => { + const resolutionResult = + new ParamsResolverWithDeepNestedResolver().resolve({ + deepNestedParam: { + nestedParam: undefined, + }, + }); + expect(resolutionResult.resolved).toEqual({ + resolvedDeepNestedParam: { + resolvedNestedParam: { + output: positiveDefaultValue, + }, + }, + }); + expect(resolutionResult.resolutionLogs).toHaveLength(1); + expectParamResolutionResult( + resolutionResult.resolutionLogs[0], + { + resultValue: positiveDefaultValue, + resolvedWithDefault: { + wasPerformed: true, + withValue: positiveDefaultValue, + }, + }, + 
"deepNestedParam.nestedParam.input" + ); + }); + + testFailedSingleNumberResolution( + "Test `resolveNestedParams`: deep nesting, failure", + new ParamsResolverWithDeepNestedResolver(), + { + deepNestedParam: { + nestedParam: { + input: negativeInputValue, + }, + }, + }, + "deepNestedParam.nestedParam.input" + ); + + type ModelParamsHasNoOptionalProperties = [ + NoOptionalProperties, + ] extends [never] + ? "false" + : "true"; + + type UserModelParamsHasNoOptionalProperties = [ + NoOptionalProperties, + ] extends [never] + ? "false" + : "true"; + + test("Test `NoOptionalProperties` constraint", () => { + // @ts-ignore variable is needed only for a type check + const _shouldBeTrue: ModelParamsHasNoOptionalProperties = "true"; + // @ts-ignore variable is needed only for a type check + const _shouldBeFalse: UserModelParamsHasNoOptionalProperties = "false"; + }); + + /* + * The following code snippets should not compile (after adding required imports). + * Uncomment them to test. + */ + + // class ResolveToTypeHasOptionalProperties extends ParamsResolverImpl< + // UserModelParams, + // UserModelParams + // > { + // constructor() { + // super(userModelParamsSchema, "UserModelParams"); + // } + // } + + // class AttemptToSpecifyNonExistingPropertyOfInputType extends ParamsResolverImpl< + // UserModelParams, + // ModelParams + // > { + // constructor() { + // super(modelParamsSchema, "ModelParams"); + // } + + // readonly unicorn = this.resolveParam("unicorn") + // .requiredToBeConfigured() + // .noValidationNeeded(); + // } +}); diff --git a/src/test/llm/llmServices/utils/singleParamResolver.test.ts b/src/test/llm/llmServices/utils/singleParamResolver.test.ts new file mode 100644 index 00000000..4d283b72 --- /dev/null +++ b/src/test/llm/llmServices/utils/singleParamResolver.test.ts @@ -0,0 +1,396 @@ +import { expect } from "earl"; + +import { + ParamsResolver, + ResolutionActionDetailedResult, + SingleParamResolver, +} from "../../../../llm/llmServices/utils/paramsResolvers/abstractResolvers"; +import { ResolutionActionResult } from "../../../../llm/llmServices/utils/paramsResolvers/abstractResolvers"; +import { + SingleParamResolverBuilder, + ValidationRules, + newParam, + resolveParam, +} from "../../../../llm/llmServices/utils/paramsResolvers/builders"; + +import { + ResolutionResultAddOns, + expectParamResolutionResult, +} from "../../llmSpecificTestUtils/expectResolutionResult"; + +suite("[LLMService-s utils] Test single parameter resolution", () => { + const inputParamName = "inputParam"; + + interface InputParams { + inputParam: T | undefined; + } + + function testSingleParamResolution( + testName: string, + inputValue: T | undefined, + buildResolver: ( + builder: SingleParamResolverBuilder, T> + ) => SingleParamResolver, T>, + expectedDefinedValues: ResolutionResultAddOns + ) { + test(testName, () => { + const paramResolverBuilder = resolveParam, T>( + inputParamName + ); + const paramResolver = buildResolver(paramResolverBuilder); + const resolutionResult = paramResolver.resolveParam({ + inputParam: inputValue, + }); + expectParamResolutionResult( + resolutionResult, + expectedDefinedValues, + inputParamName + ); + }); + } + + testSingleParamResolution( + "Test required param: value is specified", + false, + (builder) => builder.requiredToBeConfigured().noValidationNeeded(), + { + resultValue: false, + inputReadCorrectly: { + wasPerformed: true, + withValue: false, + }, + } + ); + + testSingleParamResolution( + "Test required param: no value specified", + undefined, + (builder) => 
builder.requiredToBeConfigured().noValidationNeeded(), + { + isInvalidCause: + "neither a user value nor a default one is specified", + } + ); + + testSingleParamResolution( + "Test required param: value of wrong type is specified, but there is nothing we can do", + "definitely not a number" as any, + (builder) => builder.requiredToBeConfigured().noValidationNeeded(), + { + resultValue: "definitely not a number" as any, + inputReadCorrectly: { + wasPerformed: true, + withValue: "definitely not a number" as any, + }, + } + ); + + testSingleParamResolution( + "Test resolution with default: value is already defined", + false, + (builder) => builder.default(() => true).noValidationNeeded(), + { + resultValue: false, + inputReadCorrectly: { + wasPerformed: true, + withValue: false, + }, + } + ); + + testSingleParamResolution( + "Test resolution with default: resolved with default", + undefined, + (builder) => builder.default(() => true).noValidationNeeded(), + { + resultValue: true, + resolvedWithDefault: { + wasPerformed: true, + withValue: true, + }, + } + ); + + testSingleParamResolution( + "Test resolution with default: failed", + undefined, + (builder) => + builder + .default( + (inputParams) => + inputParams.inputParam !== undefined ? true : undefined, + "Please configure the parameter with a value other than `undefined`." + ) + .noValidationNeeded(), + { + isInvalidCause: + "Please configure the parameter with a value other than `undefined`.", + resolvedWithDefault: { + wasPerformed: true, + withValue: undefined, + }, + } + ); + + ( + [ + [false, "specified value"], + [undefined, "no value specified"], + [true, 'already specified "true" value is not'], + ] as [boolean | undefined, string][] + ).forEach(([value, testCaseName]) => { + const inputReadCorrectly: ResolutionActionResult = + value === undefined + ? { wasPerformed: false } + : { wasPerformed: true, withValue: value }; + const overriden: ResolutionActionDetailedResult = + value === true + ? { wasPerformed: false } + : { + wasPerformed: true, + withValue: true, + message: "is always true", + }; + + testSingleParamResolution( + `Test override with "always true": ${testCaseName} overriden`, + value, + (builder) => + builder + .override(() => true, "is always true") + .requiredToBeConfigured() + .noValidationNeeded(), + { + resultValue: true, + inputReadCorrectly: inputReadCorrectly, + overriden: overriden, + } + ); + + testSingleParamResolution( + `Test override with mock: ${testCaseName} overriden`, + value, + (builder) => builder.overrideWithMock(() => true), + { resultValue: true, inputReadCorrectly: inputReadCorrectly } + ); + }); + + ( + [ + [5, "specified value overriden", 6], + [ + undefined, + "no value specified (is not overriden) resolved with default", + 1, + ], + [0, "specified value is forced to be resolved with default", 1], + ] as [number | undefined, string, number][] + ).forEach(([value, testCaseName, expectedResultValue]) => { + const overriderMessage = + "is 6 if value is defined and non-zero; otherwise should be resolved with default"; + const forceDefaultResolution = value === undefined || value === 0; + const overriden: ResolutionActionDetailedResult = + value === undefined + ? { wasPerformed: false } + : { + wasPerformed: true, + withValue: forceDefaultResolution ? 
undefined : 6, + message: overriderMessage, + }; + testSingleParamResolution( + `Test conditional override with default resolution: ${testCaseName}`, + value, + (builder) => + builder + .override(() => { + if (forceDefaultResolution) { + return undefined; + } + return 6; + }, overriderMessage) + .default(() => 1) + .noValidationNeeded(), + { + resultValue: expectedResultValue, + inputReadCorrectly: { + wasPerformed: value === undefined ? false : true, + withValue: value, + }, + overriden: overriden, + resolvedWithDefault: { + wasPerformed: forceDefaultResolution ? true : false, + withValue: forceDefaultResolution ? 1 : undefined, + }, + } + ); + }); + + testSingleParamResolution( + "Test param validation: 1 rule, success", + true, + (builder) => + builder + .requiredToBeConfigured() + .validate([(value) => value, "be true"]), + { + resultValue: true, + inputReadCorrectly: { + wasPerformed: true, + withValue: true, + }, + } + ); + + testSingleParamResolution( + "Test param validation: 1 rule, failure", + false, + (builder) => + builder + .requiredToBeConfigured() + .validate([(value) => value, "be true"]), + { + isInvalidCause: "should be true, but has value", + inputReadCorrectly: { + wasPerformed: true, + withValue: false, + }, + } + ); + + testSingleParamResolution( + "Test param validation: multiple rules, success", + 5, + (builder) => + builder + .requiredToBeConfigured() + .validate( + [(value) => value > 0, "be positive"], + [(value) => value < 100, "be less than 100"] + ), + { + resultValue: 5, + inputReadCorrectly: { + wasPerformed: true, + withValue: 5, + }, + } + ); + + testSingleParamResolution( + "Test param validation: multiple rules, first fails", + -5, + (builder) => + builder + .requiredToBeConfigured() + .validate( + [(value) => value > 0, "be positive"], + [(value) => value < 100, "be less than 100"] + ), + { + isInvalidCause: "should be positive, but has value", + inputReadCorrectly: { + wasPerformed: true, + withValue: -5, + }, + } + ); + + testSingleParamResolution( + "Test validate at runtime: pass", + true, + (builder) => builder.requiredToBeConfigured().validateAtRuntimeOnly(), + { + resultValue: true, + inputReadCorrectly: { + wasPerformed: true, + withValue: true, + }, + } + ); + + test("Test no property with such name: does not compile", () => { + /* + * The following code snippet should not compile. + * Uncomment it to test. 
+ */ + // const inputParams = { + // inputParam: true, + // }; + // const unknownParamName = "unknownParam"; + // resolveParam< + // { + // inputParam: boolean; + // }, + // boolean + // >(unknownParamName) + // .requiredToBeConfigured() + // .noValidationNeeded() + // .resolveParam(inputParams); + }); + + test("Test `newParam`: success", () => { + const paramResolver = newParam, number>( + () => 5 + ).noValidationNeeded(); + const inputParams = { + inputParam: false, + }; + const resolutionResult = paramResolver.resolveParam(inputParams); + expectParamResolutionResult( + resolutionResult, + { + resultValue: 5, + overriden: { + wasPerformed: true, + withValue: 5, + }, + }, + undefined + ); + }); + + test("Test builders return `SingleParamResolver` as valid `ParamsResolver`", () => { + const paramResolver = resolveParam, number>( + inputParamName + ) + .requiredToBeConfigured() + .validate(ValidationRules.bePositiveNumber) as ParamsResolver< + InputParams, + number + >; + + const successResult = paramResolver.resolve({ + inputParam: 5, + }); + expect(successResult.resolved).toEqual(5); + expect(successResult.resolutionLogs).toHaveLength(1); + expectParamResolutionResult( + successResult.resolutionLogs[0], + { + resultValue: 5, + inputReadCorrectly: { + wasPerformed: true, + withValue: 5, + }, + }, + inputParamName + ); + + const failureResult = paramResolver.resolve({ + inputParam: -5, + }); + expect(failureResult.resolved).toBeNullish(); + expect(failureResult.resolutionLogs).toHaveLength(1); + expectParamResolutionResult( + failureResult.resolutionLogs[0], + { + inputReadCorrectly: { + wasPerformed: true, + withValue: -5, + }, + isInvalidCause: "should be positive, but has value", + }, + inputParamName + ); + }); +}); diff --git a/src/test/llm/llmServices/utils/time.test.ts b/src/test/llm/llmServices/utils/time.test.ts new file mode 100644 index 00000000..6a4560ae --- /dev/null +++ b/src/test/llm/llmServices/utils/time.test.ts @@ -0,0 +1,118 @@ +import { expect } from "earl"; + +import { + Time, + millisToTime, + time, + timeToMillis, + timeToString, + timeZero, +} from "../../../../llm/llmServices/utils/time"; + +suite("[LLMService-s utils] Time utils test", () => { + const zero: Time = { + millis: 0, + seconds: 0, + minutes: 0, + hours: 0, + days: 0, + }; + + const fiveSeconds: Time = { + millis: 0, + seconds: 5, + minutes: 0, + hours: 0, + days: 0, + }; + + const fiveSecondsInMillis: Time = { + millis: 5000, + seconds: 0, + minutes: 0, + hours: 0, + days: 0, + }; + + const twoDays: Time = { + millis: 0, + seconds: 0, + minutes: 0, + hours: 0, + days: 2, + }; + + const manyDays: Time = { + millis: 0, + seconds: 0, + minutes: 0, + hours: 0, + days: 99999, + }; + + const mixedResolved: Time = { + millis: 100, + seconds: 40, + minutes: 30, + hours: 20, + days: 10, + }; + + const mixedUnresolved: Time = { + millis: 10100, + seconds: 210, + minutes: 147, + hours: 66, + days: 8, + }; + + const withBothEndZeros: Time = { + millis: 0, + seconds: 0, + minutes: 40, + hours: 2, + days: 0, + }; + + test("Test `timeToMillis`", () => { + expect(timeToMillis(zero)).toEqual(0); + expect(timeToMillis(fiveSeconds)).toEqual(5000); + expect(timeToMillis(fiveSecondsInMillis)).toEqual(5000); + expect(timeToMillis(twoDays)).toEqual(2 * 24 * 60 * 60 * 1000); + }); + + test("Test `millisToTime`", () => { + expect(millisToTime(0)).toEqual(zero); + expect(millisToTime(5000)).toEqual(fiveSeconds); + expect(millisToTime(2 * 24 * 60 * 60 * 1000)).toEqual(twoDays); + }); + + test("Test resolution through 
`millisToTime`", () => { + expect(millisToTime(timeToMillis(twoDays))).toEqual(twoDays); + expect(millisToTime(timeToMillis(manyDays))).toEqual(manyDays); + expect(millisToTime(timeToMillis(mixedResolved))).toEqual( + mixedResolved + ); + expect(millisToTime(timeToMillis(mixedUnresolved))).toEqual( + mixedResolved + ); + }); + + test("Test `time` factory", () => { + expect(time(5, "second")).toEqual(fiveSeconds); + expect(time(5000, "millisecond")).toEqual(fiveSeconds); + expect(time(2, "day")).toEqual(twoDays); + expect(time(2 * 24 * 60, "minute")).toEqual(twoDays); + }); + + test("Test `timeToString`", () => { + expect(timeToString(timeZero)).toEqual("0 ms"); + expect(timeToString(fiveSeconds)).toEqual("5 s"); + expect(timeToString(fiveSecondsInMillis)).toEqual("5 s"); + expect(timeToString(twoDays)).toEqual("2 d"); + expect(timeToString(mixedResolved)).toEqual( + "10 d, 20 h, 30 m, 40 s, 100 ms" + ); + expect(timeToString(withBothEndZeros)).toEqual("2 h, 40 m"); + }); +}); diff --git a/src/test/llm/llmSpecificTestUtils/calculateTokens.ts b/src/test/llm/llmSpecificTestUtils/calculateTokens.ts new file mode 100644 index 00000000..4b2346ae --- /dev/null +++ b/src/test/llm/llmSpecificTestUtils/calculateTokens.ts @@ -0,0 +1,15 @@ +import { TiktokenModel, encoding_for_model } from "tiktoken"; + +export function calculateTokensViaTikToken( + text: string, + model: TiktokenModel +): number { + const encoder = encoding_for_model(model); + const tokens = encoder.encode(text).length; + encoder.free(); + return tokens; +} + +export function approxCalculateTokens(text: string): number { + return (text.length / 4) >> 0; +} diff --git a/src/test/llm/llmSpecificTestUtils/constants.ts b/src/test/llm/llmSpecificTestUtils/constants.ts new file mode 100644 index 00000000..69518264 --- /dev/null +++ b/src/test/llm/llmSpecificTestUtils/constants.ts @@ -0,0 +1,33 @@ +import { AnalyzedChatHistory } from "../../../llm/llmServices/chat"; +import { ProofGenerationContext } from "../../../llm/proofGenerationContext"; + +export const proofsToGenerate = [ + "auto.", + "left. reflexivity.", + "right. auto.", + "intros.", + "assumption.", + "something.", + "", + "reflexivity.", + "auto.", + "auto.", +]; + +export const testModelId = "test model"; + +export const gptTurboModelName = "gpt-3.5-turbo-0301"; + +export const mockChat: AnalyzedChatHistory = { + chat: [{ role: "system", content: "Generate proofs." 
}], + estimatedTokens: { + messagesTokens: 10, + maxTokensToGenerate: 50, + maxTokensInTotal: 60, + }, +}; + +export const mockProofGenerationContext: ProofGenerationContext = { + completionTarget: "forall n : nat, 0 + n = n", + contextTheorems: [], +}; diff --git a/src/test/llm/llmSpecificTestUtils/dummyLLMService.ts b/src/test/llm/llmSpecificTestUtils/dummyLLMService.ts new file mode 100644 index 00000000..3fe942e6 --- /dev/null +++ b/src/test/llm/llmSpecificTestUtils/dummyLLMService.ts @@ -0,0 +1,123 @@ +import { + AnalyzedChatHistory, + ChatHistory, +} from "../../../llm/llmServices/chat"; +import { + ErrorsHandlingMode, + GeneratedProofImpl, + LLMServiceImpl, + ProofVersion, +} from "../../../llm/llmServices/llmService"; +import { LLMServiceInternal } from "../../../llm/llmServices/llmServiceInternal"; +import { + ModelParams, + modelParamsSchema, +} from "../../../llm/llmServices/modelParams"; +import { GenerationsLogger } from "../../../llm/llmServices/utils/generationsLogger/generationsLogger"; +import { BasicModelParamsResolver } from "../../../llm/llmServices/utils/paramsResolvers/basicModelParamsResolvers"; +import { ProofGenerationContext } from "../../../llm/proofGenerationContext"; +import { UserModelParams } from "../../../llm/userModelParams"; + +/** + * Mock implementation that always throws on any proof-generation call. + * Its only mission is to exist: for example, it can be useful to build mock `LLMServiceRequest`-s. + * + * Additionally, it accepts `GenerationsLogger` from outside, so no resources are needed to be cleaned with `dispose`. + */ +export class DummyLLMService extends LLMServiceImpl< + UserModelParams, + ModelParams, + DummyLLMService, + DummyGeneratedProof, + DummyLLMServiceInternal +> { + protected readonly internal: DummyLLMServiceInternal; + protected readonly modelParamsResolver = new BasicModelParamsResolver( + modelParamsSchema, + "ModelParams" + ); + + constructor(generationsLogger: GenerationsLogger) { + super("DummyLLMService", undefined, true, undefined); + this.internal = new DummyLLMServiceInternal( + this, + this.eventLoggerGetter, + () => generationsLogger + ); + } + + dispose(): void {} + + generateFromChat( + _analyzedChat: AnalyzedChatHistory, + _params: ModelParams, + _choices: number, + _errorsHandlingMode?: ErrorsHandlingMode + ): Promise { + throw Error("I'm a teapot"); + } + + generateProof( + _proofGenerationContext: ProofGenerationContext, + _params: ModelParams, + _choices: number, + _errorsHandlingMode?: ErrorsHandlingMode + ): Promise { + throw Error("I'm a teapot"); + } +} + +export class DummyGeneratedProof extends GeneratedProofImpl< + ModelParams, + DummyLLMService, + DummyGeneratedProof, + DummyLLMServiceInternal +> { + constructor( + proof: string, + proofGenerationContext: ProofGenerationContext, + modelParams: ModelParams, + llmServiceInternal: DummyLLMServiceInternal, + previousProofVersions?: ProofVersion[] + ) { + super( + proof, + proofGenerationContext, + modelParams, + llmServiceInternal, + previousProofVersions + ); + } + + fixProof( + _diagnostic: string, + _choices?: number, + _errorsHandlingMode?: ErrorsHandlingMode + ): Promise { + throw Error("I'm a teapot"); + } +} + +class DummyLLMServiceInternal extends LLMServiceInternal< + ModelParams, + DummyLLMService, + DummyGeneratedProof, + DummyLLMServiceInternal +> { + constructGeneratedProof( + _proof: string, + _proofGenerationContext: ProofGenerationContext, + _modelParams: ModelParams, + _previousProofVersions?: ProofVersion[] | undefined + ): 
DummyGeneratedProof { + throw Error("I'm a teapot"); + } + + async generateFromChatImpl( + _chat: ChatHistory, + _params: ModelParams, + _choices: number + ): Promise { + throw Error("I'm a teapot"); + } +} diff --git a/src/test/llm/llmSpecificTestUtils/eventsTracker.ts b/src/test/llm/llmSpecificTestUtils/eventsTracker.ts new file mode 100644 index 00000000..33a94356 --- /dev/null +++ b/src/test/llm/llmSpecificTestUtils/eventsTracker.ts @@ -0,0 +1,113 @@ +import { expect } from "earl"; + +import { LLMServiceError } from "../../../llm/llmServiceErrors"; +import { + AnalyzedChatHistory, + ChatHistory, +} from "../../../llm/llmServices/chat"; +import { + LLMService, + LLMServiceImpl, + LLMServiceRequestFailed, + LLMServiceRequestSucceeded, +} from "../../../llm/llmServices/llmService"; + +import { EventLogger } from "../../../logging/eventLogger"; + +import { MockLLMService } from "./mockLLMService"; + +export interface EventsTracker { + successfulRequestEventsN: number; + failedRequestEventsN: number; +} + +export function subscribeToTrackEvents< + LLMServiceType extends LLMService, +>( + testEventLogger: EventLogger, + expectedService: LLMServiceType, + expectedModelId: string, + expectedError?: LLMServiceError +): EventsTracker { + const eventsTracker: EventsTracker = { + successfulRequestEventsN: 0, + failedRequestEventsN: 0, + }; + subscribeToLogicEvents( + eventsTracker, + testEventLogger, + expectedService, + expectedModelId, + expectedError + ); + return eventsTracker; +} + +export interface MockEventsTracker extends EventsTracker { + mockEventsN: number; +} + +export function subscribeToTrackMockEvents( + testEventLogger: EventLogger, + expectedMockService: MockLLMService, + expectedModelId: string, + expectedMockChat?: AnalyzedChatHistory, + expectedError?: LLMServiceError +): MockEventsTracker { + const eventsTracker: MockEventsTracker = { + mockEventsN: 0, + successfulRequestEventsN: 0, + failedRequestEventsN: 0, + }; + testEventLogger.subscribeToLogicEvent( + MockLLMService.generationFromChatEvent, + (chatData) => { + if (expectedMockChat === undefined) { + expect((chatData as ChatHistory) !== null).toBeTruthy(); + } else { + expect(chatData as ChatHistory).toEqual(expectedMockChat.chat); + } + eventsTracker.mockEventsN += 1; + } + ); + subscribeToLogicEvents( + eventsTracker, + testEventLogger, + expectedMockService, + expectedModelId, + expectedError + ); + return eventsTracker; +} + +function subscribeToLogicEvents>( + eventsTracker: EventsTracker, + testEventLogger: EventLogger, + expectedService: LLMServiceType, + expectedModelId: string, + expectedError?: LLMServiceError +) { + testEventLogger.subscribeToLogicEvent( + LLMServiceImpl.requestSucceededEvent, + (data) => { + const requestSucceeded = data as LLMServiceRequestSucceeded; + expect(requestSucceeded).toBeTruthy(); + expect(requestSucceeded.llmService).toEqual(expectedService); + expect(requestSucceeded.params.modelId).toEqual(expectedModelId); + eventsTracker.successfulRequestEventsN += 1; + } + ); + testEventLogger.subscribeToLogicEvent( + LLMServiceImpl.requestFailedEvent, + (data) => { + const requestFailed = data as LLMServiceRequestFailed; + expect(requestFailed).toBeTruthy(); + expect(requestFailed.llmService).toEqual(expectedService); + expect(requestFailed.params.modelId).toEqual(expectedModelId); + if (expectedError !== undefined) { + expect(requestFailed.llmServiceError).toEqual(expectedError); + } + eventsTracker.failedRequestEventsN += 1; + } + ); +} diff --git 
a/src/test/llm/llmSpecificTestUtils/expectGeneratedProof.ts b/src/test/llm/llmSpecificTestUtils/expectGeneratedProof.ts new file mode 100644 index 00000000..5a4337ce --- /dev/null +++ b/src/test/llm/llmSpecificTestUtils/expectGeneratedProof.ts @@ -0,0 +1,37 @@ +import { expect } from "earl"; + +import { ProofVersion } from "../../../llm/llmServices/llmService"; + +import { MockLLMGeneratedProof } from "./mockLLMService"; + +export interface ExpectedGeneratedProof { + proof: string; + versionNumber: number; + proofVersions: ProofVersion[]; + nextVersionCanBeGenerated?: boolean; + canBeFixed?: Boolean; +} + +export function expectGeneratedProof( + actual: MockLLMGeneratedProof, + expected: ExpectedGeneratedProof +) { + expect(actual.proof()).toEqual(expected.proof); + expect(actual.versionNumber()).toEqual(expected.versionNumber); + expect(actual.proofVersions).toEqual(expected.proofVersions); + if (expected.nextVersionCanBeGenerated !== undefined) { + expect(actual.nextVersionCanBeGenerated()).toEqual( + expected.nextVersionCanBeGenerated + ); + } + if (expected.canBeFixed !== undefined) { + expect(actual.canBeFixed()).toEqual(expected.canBeFixed); + } +} + +export function toProofVersion( + proof: string, + diagnostic: string | undefined = undefined +): ProofVersion { + return { proof: proof, diagnostic: diagnostic }; +} diff --git a/src/test/llm/llmSpecificTestUtils/expectLogs.ts b/src/test/llm/llmSpecificTestUtils/expectLogs.ts new file mode 100644 index 00000000..f340374a --- /dev/null +++ b/src/test/llm/llmSpecificTestUtils/expectLogs.ts @@ -0,0 +1,50 @@ +import { expect } from "earl"; + +import { LLMService } from "../../../llm/llmServices/llmService"; +import { ResponseStatus } from "../../../llm/llmServices/utils/generationsLogger/loggerRecord"; + +export interface ExpectedRecord { + status: ResponseStatus; + error?: Error; +} + +export function expectLogs( + expectedRecords: ExpectedRecord[], + service: LLMService +) { + const actualRecordsUnwrapped = service + .readGenerationsLogs() + .map((record) => { + return { + status: record.responseStatus, + error: record.error, + }; + }); + const expectedRecordsUnwrapped = expectedRecords.map((record) => { + return { + status: record.status, + error: record.error + ? 
{ + typeName: record.error.name, + message: record.error.message, + } + : undefined, + }; + }); + expect(actualRecordsUnwrapped).toHaveLength( + expectedRecordsUnwrapped.length + ); + // if exact error is not expected, ignore it in the actual records + for (let i = 0; i < expectedRecordsUnwrapped.length; i++) { + const expected = expectedRecordsUnwrapped[i]; + const actual = actualRecordsUnwrapped[i]; + if ( + expected.status === "FAILURE" && + actual.status === "FAILURE" && + expected.error === undefined + ) { + actual.error = undefined; + } + } + expect(actualRecordsUnwrapped).toEqual(expectedRecordsUnwrapped); +} diff --git a/src/test/llm/llmSpecificTestUtils/expectResolutionResult.ts b/src/test/llm/llmSpecificTestUtils/expectResolutionResult.ts new file mode 100644 index 00000000..f7da39a2 --- /dev/null +++ b/src/test/llm/llmSpecificTestUtils/expectResolutionResult.ts @@ -0,0 +1,92 @@ +import { expect } from "earl"; + +import { + ResolutionActionDetailedResult, + ResolutionActionResult, + SingleParamResolutionResult, +} from "../../../llm/llmServices/utils/paramsResolvers/abstractResolvers"; + +export interface ResolutionResultAddOns { + inputParamName?: string; + resultValue?: T; + isInvalidCause?: string; + inputReadCorrectly?: ResolutionActionResult; + overriden?: ResolutionActionDetailedResult; + resolvedWithDefault?: ResolutionActionResult; +} + +/** + * All values of `actualResolutionResult` are checked for equality to + * the corresponding values of the `expectedResolutionResult`, + * except for the string ones: they are expected to include (!) + * the corresponding values of the `expectedResolutionResult`. + */ +export function expectParamResolutionResult( + actualResolutionResult: SingleParamResolutionResult, + expectedNonDefaultValues: ResolutionResultAddOns, + inputParamName: string | undefined +) { + const expectedResolutionResult = { + inputParamName: inputParamName, + resultValue: undefined, + isInvalidCause: undefined, + inputReadCorrectly: { + wasPerformed: false, + withValue: undefined, + }, + overriden: { + wasPerformed: false, + withValue: undefined, + message: undefined, + }, + resolvedWithDefault: { + wasPerformed: false, + withValue: undefined, + }, + ...expectedNonDefaultValues, + }; + expect(actualResolutionResult.inputParamName).toEqual( + expectedResolutionResult.inputParamName + ); + expect(actualResolutionResult.resultValue).toEqual( + expectedResolutionResult.resultValue + ); + expectMessageValue( + actualResolutionResult.isInvalidCause, + expectedResolutionResult.isInvalidCause + ); + expect(actualResolutionResult.inputReadCorrectly.wasPerformed).toEqual( + expectedResolutionResult.inputReadCorrectly.wasPerformed + ); + expect(actualResolutionResult.inputReadCorrectly.withValue).toEqual( + expectedResolutionResult.inputReadCorrectly.withValue + ); + expect(actualResolutionResult.overriden.wasPerformed).toEqual( + expectedResolutionResult.overriden.wasPerformed + ); + expect(actualResolutionResult.overriden.withValue).toEqual( + expectedResolutionResult.overriden.withValue + ); + expectMessageValue( + actualResolutionResult.overriden.message, + expectedResolutionResult.overriden.message + ); + expect(actualResolutionResult.resolvedWithDefault.wasPerformed).toEqual( + expectedResolutionResult.resolvedWithDefault.wasPerformed + ); + expect(actualResolutionResult.resolvedWithDefault.withValue).toEqual( + expectedResolutionResult.resolvedWithDefault.withValue + ); +} + +export function expectMessageValue( + actualMessage: string | undefined, + expectedMessage: 
string | undefined +) { + if (expectedMessage === undefined) { + expect(actualMessage).toBeNullish(); + } else { + expect(actualMessage).not.toBeNullish(); + expect(actualMessage!).toInclude(expectedMessage); + } +} diff --git a/src/test/llm/llmSpecificTestUtils/mockLLMService.ts b/src/test/llm/llmSpecificTestUtils/mockLLMService.ts new file mode 100644 index 00000000..e2f2f574 --- /dev/null +++ b/src/test/llm/llmSpecificTestUtils/mockLLMService.ts @@ -0,0 +1,332 @@ +import { JSONSchemaType } from "ajv"; +import { PropertiesSchema } from "ajv/dist/types/json-schema"; + +import { ConfigurationError } from "../../../llm/llmServiceErrors"; +import { + AnalyzedChatHistory, + ChatHistory, + ChatMessage, +} from "../../../llm/llmServices/chat"; +import { + ErrorsHandlingMode, + GeneratedProofImpl, + LLMServiceImpl, + ProofVersion, +} from "../../../llm/llmServices/llmService"; +import { LLMServiceInternal } from "../../../llm/llmServices/llmServiceInternal"; +import { + ModelParams, + modelParamsSchema, +} from "../../../llm/llmServices/modelParams"; +import { BasicModelParamsResolver } from "../../../llm/llmServices/utils/paramsResolvers/basicModelParamsResolvers"; +import { ValidationRules } from "../../../llm/llmServices/utils/paramsResolvers/builders"; +import { ValidParamsResolverImpl } from "../../../llm/llmServices/utils/paramsResolvers/paramsResolverImpl"; +import { ProofGenerationContext } from "../../../llm/proofGenerationContext"; +import { UserModelParams } from "../../../llm/userModelParams"; + +import { EventLogger } from "../../../logging/eventLogger"; + +export interface MockLLMUserModelParams extends UserModelParams { + proofsToGenerate: string[]; + workerId?: number; +} + +export interface MockLLMModelParams extends ModelParams { + proofsToGenerate: string[]; + workerId: number; + resolvedWithMockLLMService: boolean; +} + +export const mockLLMModelParamsSchema: JSONSchemaType = { + title: "MockLLMModelsParameters", + type: "object", + properties: { + proofsToGenerate: { + type: "array", + items: { type: "string" }, + }, + workerId: { type: "number" }, + resolvedWithMockLLMService: { type: "boolean" }, + ...(modelParamsSchema.properties as PropertiesSchema), + }, + required: [ + "proofsToGenerate", + "workerId", + "resolvedWithMockLLMService", + ...modelParamsSchema.required, + ], + additionalProperties: false, +}; + +/** + * `MockLLMService` parameters resolution does 4 changes to `inputParams`: + * - resolves undefined `workerId` to 0; + * - adds extra `resolvedWithMockLLMService: true` property; + * - overrides original `systemPrompt` with `this.systemPromptToOverrideWith`; + * - overrides original `choices` to `defaultChoices` with `proofsToGenerate.length`. 
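+ *
+ * Illustrative example (a sketch of the expected behaviour, not an exhaustive contract):
+ * for input params with `proofsToGenerate: ["auto.", "intros."]` and no `workerId` specified,
+ * resolution is expected to produce `workerId: 0`, add `resolvedWithMockLLMService: true`,
+ * replace `systemPrompt` with `MockLLMService.systemPromptToOverrideWith`,
+ * and set `defaultChoices` to 2 (i.e. `proofsToGenerate.length`).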
+ */ +export class MockLLMModelParamsResolver + extends BasicModelParamsResolver + implements + ValidParamsResolverImpl +{ + constructor() { + super(mockLLMModelParamsSchema, "MockLLMModelParams"); + } + + readonly proofsToGenerate = this.resolveParam("proofsToGenerate") + .requiredToBeConfigured() + .validate([(value) => value.length > 0, "be non-empty"]); + + readonly workerId = this.resolveParam("workerId") + .default(() => 0) + .validate([(value) => value >= 0, "be non-negative"]); + + readonly resolvedWithMockLLMService = this.insertParam( + () => true + ).validate([(value) => value, "be true"]); + + readonly systemPrompt = this.resolveParam("systemPrompt") + .override(() => MockLLMService.systemPromptToOverrideWith) + .requiredToBeConfigured() + .noValidationNeeded(); + + readonly defaultChoices = this.resolveParam("choices") + .override((inputParams) => inputParams.proofsToGenerate.length) + .requiredToBeConfigured() + .validate(ValidationRules.bePositiveNumber); +} + +/** + * This class implements `LLMService` the same way as most of the services do, + * so as to reuse the default implementations as much as possible. + * + * However, to make tests cover more corner cases, `MockLLMService` provides additional features. + * Check the documentation of its methods below. + */ +export class MockLLMService extends LLMServiceImpl< + MockLLMUserModelParams, + MockLLMModelParams, + MockLLMService, + MockLLMGeneratedProof, + MockLLMServiceInternal +> { + protected readonly internal: MockLLMServiceInternal; + protected readonly modelParamsResolver = new MockLLMModelParamsResolver(); + + constructor( + eventLogger?: EventLogger, + debugLogs: boolean = false, + generationsLogsFilePath?: string + ) { + super( + "MockLLMService", + eventLogger, + debugLogs, + generationsLogsFilePath + ); + this.internal = new MockLLMServiceInternal( + this, + this.eventLoggerGetter, + this.generationsLoggerBuilder + ); + } + + static readonly generationFromChatEvent = "mockllm-generation-from-chat"; + + static readonly systemPromptToOverrideWith = + "unique mock-llm system prompt"; + + static readonly proofFixPrompt = "Generate `Fixed.` instead of proof."; + static readonly fixedProofString = "Fixed."; + + /** + * Use this method to make 1 next generation (for the specified worker) throw the specified error. + * Workers are meant to be any external entities that would like to separate their behaviour. + */ + throwErrorOnNextGeneration(error: Error, workerId: number = 0) { + this.internal.throwErrorOnNextGenerationMap.set(workerId, error); + } + + /** + * Adds special control message to the chat, so it would make `MockLLMService` + * skip first `skipFirstNProofs` proofs at the generation stage. 
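+     *
+     * A minimal usage sketch (assuming `baseChat: ChatHistory` has been built elsewhere):
+     * `mockService.transformChatToSkipFirstNProofs(baseChat, 2)` returns a chat whose
+     * subsequent mock generation starts from the third proof of `proofsToGenerate`.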
+     */
+    transformChatToSkipFirstNProofs(
+        baseChat: ChatHistory,
+        skipFirstNProofs: number
+    ): ChatHistory {
+        const controlMessage: ChatMessage = {
+            role: "user",
+            content: `SKIP_FIRST_PROOFS: ${skipFirstNProofs}`,
+        };
+        return [...baseChat, controlMessage];
+    }
+
+    clearGenerationLogs() {
+        this.internal.generationsLogger.resetLogs();
+    }
+}
+
+export class MockLLMGeneratedProof extends GeneratedProofImpl<
+    MockLLMModelParams,
+    MockLLMService,
+    MockLLMGeneratedProof,
+    MockLLMServiceInternal
+> {
+    constructor(
+        proof: string,
+        proofGenerationContext: ProofGenerationContext,
+        modelParams: MockLLMModelParams,
+        llmServiceInternal: MockLLMServiceInternal,
+        previousProofVersions?: ProofVersion[]
+    ) {
+        super(
+            proof,
+            proofGenerationContext,
+            modelParams,
+            llmServiceInternal,
+            previousProofVersions
+        );
+    }
+
+    /**
+     * Mocks the process of the implementation of a new regeneration method.
+     * Namely, checks whether it is possible.
+     */
+    nextVersionCanBeGenerated(): Boolean {
+        return super.nextVersionCanBeGenerated();
+    }
+
+    /**
+     * Mocks the process of the implementation of a new regeneration method.
+     * Namely, performs the generation using `LLMServiceInternal.generateFromChatWrapped`.
+     */
+    async generateNextVersion(
+        analyzedChat: AnalyzedChatHistory,
+        choices: number,
+        errorsHandlingMode: ErrorsHandlingMode
+    ): Promise<MockLLMGeneratedProof[]> {
+        return this.llmServiceInternal.generateFromChatWrapped(
+            this.modelParams,
+            choices,
+            errorsHandlingMode,
+            () => {
+                if (!this.nextVersionCanBeGenerated()) {
+                    throw new ConfigurationError(
+                        `next version could not be generated: version ${this.versionNumber()} >= max rounds number ${this.maxRoundsNumber}`
+                    );
+                }
+                return analyzedChat;
+            },
+            (proof: string) =>
+                this.llmServiceInternal.constructGeneratedProof(
+                    proof,
+                    this.proofGenerationContext,
+                    this.modelParams,
+                    this.proofVersions
+                )
+        );
+    }
+}
+
+class MockLLMServiceInternal extends LLMServiceInternal<
+    MockLLMModelParams,
+    MockLLMService,
+    MockLLMGeneratedProof,
+    MockLLMServiceInternal
+> {
+    throwErrorOnNextGenerationMap: Map<number, Error | undefined> = new Map();
+
+    constructGeneratedProof(
+        proof: string,
+        proofGenerationContext: ProofGenerationContext,
+        modelParams: MockLLMModelParams,
+        previousProofVersions?: ProofVersion[] | undefined
+    ): MockLLMGeneratedProof {
+        return new MockLLMGeneratedProof(
+            proof,
+            proofGenerationContext,
+            modelParams as MockLLMModelParams,
+            this,
+            previousProofVersions
+        );
+    }
+
+    /**
+     * Generally, `generateFromChatImpl` simply returns the first `choices` proofs from `MockLLMModelParams.proofsToGenerate`.
+     * Each `generateFromChatImpl` call sends a logic `this.generationFromChatEvent` event to the `eventLogger`.
+     * Special behaviour:
+     * - If `throwErrorOnNextGeneration` was registered for `MockLLMModelParams.workerId`,
+     *   `generateFromChatImpl` throws this error and then resets this behaviour for the next call.
+     * - If `chat` contains the special control message (see `transformChatToSkipFirstNProofs`),
+     *   several proofs from the beginning of `MockLLMModelParams.proofsToGenerate` will be skipped.
+     *   Practically, it provides a way to generate different proofs depending on the `chat` (while `modelParams` stay the same).
+     * - If `chat` contains `this.proofFixPrompt` in any of its messages,
+     *   then all the generated proofs will be equal to `this.fixedProofString`.
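+     *
+     * Illustrative example (hypothetical values, consistent with the implementation below):
+     * with `params.proofsToGenerate = ["auto.", "left.", "right."]` and a chat whose last message
+     * is the `SKIP_FIRST_PROOFS: 1` control message, `generateFromChatImpl(chat, params, 2)`
+     * is expected to return `["left.", "right."]`.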
+ */ + async generateFromChatImpl( + chat: ChatHistory, + params: MockLLMModelParams, + choices: number + ): Promise { + this.eventLogger?.logLogicEvent( + MockLLMService.generationFromChatEvent, + chat + ); + + const throwError = this.throwErrorOnNextGenerationMap.get( + params.workerId + ); + if (throwError !== undefined) { + try { + throw throwError; + } finally { + this.throwErrorOnNextGenerationMap.set( + params.workerId, + undefined + ); + } + } + + const proofFixPromptInChat = chat.find( + (message) => message.content === MockLLMService.proofFixPrompt + ); + if (proofFixPromptInChat !== undefined) { + return Array(choices).fill(MockLLMService.fixedProofString); + } + + const lastChatMessage = chat[chat.length - 1]; + const skipFirstNProofsParsed = + this.parseSkipFirstNProofsIfMatches(lastChatMessage); + const skipFirstNProofs = + skipFirstNProofsParsed !== undefined ? skipFirstNProofsParsed : 0; + + const proofsLength = params.proofsToGenerate.length - skipFirstNProofs; + if (choices > proofsLength) { + throw Error( + `\`choices = ${choices}\` > \`available proofs length = ${proofsLength}\`` + ); + } + + return params.proofsToGenerate.slice( + skipFirstNProofs, + skipFirstNProofs + choices + ); + } + + private readonly skipFirstNProofsContentPattern = + /^SKIP_FIRST_PROOFS: (.*)$/; + + private parseSkipFirstNProofsIfMatches( + message: ChatMessage + ): number | undefined { + const match = message.content.match( + this.skipFirstNProofsContentPattern + ); + if (!match) { + return undefined; + } + return parseInt(match[1]); + } +} diff --git a/src/test/llm/llmSpecificTestUtils/modelParamsAddOns.ts b/src/test/llm/llmSpecificTestUtils/modelParamsAddOns.ts new file mode 100644 index 00000000..bd3825f5 --- /dev/null +++ b/src/test/llm/llmSpecificTestUtils/modelParamsAddOns.ts @@ -0,0 +1,47 @@ +import { MultiroundProfile } from "../../../llm/llmServices/modelParams"; +import { UserMultiroundProfile } from "../../../llm/userModelParams"; + +export interface UserModelParamsAddOns { + modelId?: string; + choices?: number; + systemPrompt?: string; + maxTokensToGenerate?: number; + tokensLimit?: number; + multiroundProfile?: UserMultiroundProfile; +} + +export interface PredefinedProofsUserModelParamsAddOns + extends UserModelParamsAddOns { + tactics?: string[]; +} + +export interface OpenAiUserModelParamsAddOns extends UserModelParamsAddOns { + modelName?: string; + temperature?: number; + apiKey?: string; +} + +export interface GrazieUserModelParamsAddOns extends UserModelParamsAddOns { + modelName?: string; + apiKey?: string; +} + +export interface LMStudioUserModelParamsAddOns extends UserModelParamsAddOns { + temperature?: number; + port?: number; +} + +export interface ModelParamsAddOns { + modelId?: string; + choices?: number; + systemPrompt?: string; + maxTokensToGenerate?: number; + tokensLimit?: number; + multiroundProfile?: MultiroundProfile; +} + +export interface MultiroundProfileAddOns { + maxRoundsNumber?: number; + defaultProofFixChoices?: number; + proofFixPrompt?: string; +} diff --git a/src/test/llm/llmSpecificTestUtils/testAdmitCompletion.ts b/src/test/llm/llmSpecificTestUtils/testAdmitCompletion.ts new file mode 100644 index 00000000..c19477a9 --- /dev/null +++ b/src/test/llm/llmSpecificTestUtils/testAdmitCompletion.ts @@ -0,0 +1,45 @@ +import { expect } from "earl"; + +import { + ErrorsHandlingMode, + LLMService, +} from "../../../llm/llmServices/llmService"; +import { ModelParams } from "../../../llm/llmServices/modelParams"; +import { UserModelParams } from 
"../../../llm/userModelParams"; + +import { checkTheoremProven } from "../../commonTestFunctions/checkProofs"; +import { prepareEnvironmentWithContexts } from "../../commonTestFunctions/prepareEnvironment"; +import { withLLMServiceAndParams } from "../../commonTestFunctions/withLLMService"; + +export async function testLLMServiceCompletesAdmitFromFile< + InputModelParams extends UserModelParams, + ResolvedModelParams extends ModelParams, +>( + service: LLMService, + inputParams: InputModelParams, + resourcePath: string[], + choices: number +) { + return withLLMServiceAndParams( + service, + inputParams, + async (service, resolvedParams: ResolvedModelParams) => { + const [environment, [[completionContext, proofGenerationContext]]] = + await prepareEnvironmentWithContexts(resourcePath); + const generatedProofs = await service.generateProof( + proofGenerationContext, + resolvedParams, + choices, + ErrorsHandlingMode.RETHROW_ERRORS + ); + expect(generatedProofs).toHaveLength(choices); + expect( + checkTheoremProven( + generatedProofs, + completionContext, + environment + ) + ).toBeTruthy(); + } + ); +} diff --git a/src/test/llm/llmSpecificTestUtils/testFailedGeneration.ts b/src/test/llm/llmSpecificTestUtils/testFailedGeneration.ts new file mode 100644 index 00000000..bf3ff553 --- /dev/null +++ b/src/test/llm/llmSpecificTestUtils/testFailedGeneration.ts @@ -0,0 +1,267 @@ +import { expect } from "earl"; + +import { + ConfigurationError, + GenerationFailedError, + LLMServiceError, +} from "../../../llm/llmServiceErrors"; +import { AnalyzedChatHistory } from "../../../llm/llmServices/chat"; +import { ErrorsHandlingMode } from "../../../llm/llmServices/llmService"; + +import { subscribeToTrackMockEvents } from "./eventsTracker"; +import { ExpectedRecord, expectLogs } from "./expectLogs"; +import { MockLLMModelParams, MockLLMService } from "./mockLLMService"; +import { withMockLLMService } from "./withMockLLMService"; + +export function testFailedGenerationCompletely( + generate: ( + mockService: MockLLMService, + mockParams: MockLLMModelParams, + errorsHandlingMode: ErrorsHandlingMode, + preparedData?: T + ) => Promise, + additionalTestParams: any = {}, + prepareDataBeforeTest?: ( + mockService: MockLLMService, + basicMockParams: MockLLMModelParams + ) => Promise +) { + buildErrorsWithExpectedHandling().forEach( + ({ error, expectedThrownError, expectedLogs }) => { + const commonTestParams = { + failureName: error.constructor.name, + expectedGenerationLogs: expectedLogs, + + errorToThrow: error, + expectedErrorOfFailedEvent: expectedThrownError, + + shouldFailBeforeGenerationIsStarted: false, + + ...additionalTestParams, + }; + testFailedGeneration( + { + ...commonTestParams, + errorsHandlingMode: + ErrorsHandlingMode.LOG_EVENTS_AND_SWALLOW_ERRORS, + expectedFailedRequestEventsN: 1, + }, + generate, + prepareDataBeforeTest + ); + + testFailedGeneration( + { + ...commonTestParams, + errorsHandlingMode: ErrorsHandlingMode.RETHROW_ERRORS, + // `ErrorsHandlingMode.RETHROW_ERRORS` doesn't use failed-generation events + expectedFailedRequestEventsN: 0, + + expectedThrownError: expectedThrownError, + }, + generate, + prepareDataBeforeTest + ); + } + ); +} + +export interface ErrorWithExpectedHandling { + error: Error; + expectedThrownError: LLMServiceError; + expectedLogs: ExpectedRecord[]; +} + +export function buildErrorsWithExpectedHandling(): ErrorWithExpectedHandling[] { + const internalError = Error("internal generation error"); + const configurationError = new ConfigurationError( + "something is 
wrong with params" + ); + const generationFailedError = new GenerationFailedError( + Error("implementation decided to throw wrapped error by itself") + ); + + return [ + { + error: internalError, + expectedThrownError: new GenerationFailedError(internalError), + expectedLogs: [{ status: "FAILURE", error: internalError }], + }, + { + error: configurationError, + expectedThrownError: configurationError, + expectedLogs: [], + }, + { + error: generationFailedError, + expectedThrownError: generationFailedError, + expectedLogs: [ + { status: "FAILURE", error: generationFailedError.cause }, + ], + }, + ]; +} + +export function testFailureAtChatBuilding( + generate: ( + mockService: MockLLMService, + mockParams: MockLLMModelParams, + errorsHandlingMode: ErrorsHandlingMode, + preparedData?: T + ) => Promise, + additionalTestParams: any = {}, + prepareDataBeforeTest?: ( + mockService: MockLLMService, + basicMockParams: MockLLMModelParams + ) => Promise +) { + testFailedGeneration( + { + errorsHandlingMode: + ErrorsHandlingMode.LOG_EVENTS_AND_SWALLOW_ERRORS, + failureName: "failure at chat building", + expectedGenerationLogs: [], + expectedFailedRequestEventsN: 1, + buildErroneousMockParams: (basicMockParams: MockLLMModelParams) => { + return { + ...basicMockParams, + maxTokensToGenerate: 100, + tokensLimit: 10, + }; + }, + shouldFailBeforeGenerationIsStarted: true, + ...additionalTestParams, + }, + generate, + prepareDataBeforeTest + ); +} + +export interface FailedGenerationTestParams { + errorsHandlingMode: ErrorsHandlingMode; + failureName: string; + expectedGenerationLogs: ExpectedRecord[]; + expectedFailedRequestEventsN: number; + + errorToThrow?: Error; + expectedErrorOfFailedEvent?: LLMServiceError; + expectedThrownError?: LLMServiceError; + + expectedChatOfMockEvent?: AnalyzedChatHistory; + + buildErroneousMockParams?: ( + basicMockParams: MockLLMModelParams + ) => MockLLMModelParams; + shouldFailBeforeGenerationIsStarted: boolean; + + proofsToGenerate?: string[]; + + testTargetName?: string; +} + +export function testFailedGeneration( + testParams: FailedGenerationTestParams, + generate: ( + mockService: MockLLMService, + mockParams: MockLLMModelParams, + errorsHandlingMode: ErrorsHandlingMode, + preparedData?: T + ) => Promise, + prepareDataBeforeTest?: ( + mockService: MockLLMService, + basicMockParams: MockLLMModelParams + ) => Promise +) { + const testNamePrefix = + testParams.testTargetName ?? "Test failed generation"; + test(`${testNamePrefix}: ${testParams.failureName}, ${testParams.errorsHandlingMode}`, async () => { + await withMockLLMService( + async (mockService, basicMockParams, testEventLogger) => { + const preparedData = + prepareDataBeforeTest !== undefined + ? await prepareDataBeforeTest( + mockService, + basicMockParams + ) + : undefined; + + const eventsTracker = subscribeToTrackMockEvents( + testEventLogger, + mockService, + basicMockParams.modelId, + testParams.expectedChatOfMockEvent, + testParams.expectedErrorOfFailedEvent + ); + + if (testParams.errorToThrow !== undefined) { + mockService.throwErrorOnNextGeneration( + testParams.errorToThrow + ); + } + const maybeErroneousMockParams = + testParams.buildErroneousMockParams !== undefined + ? 
testParams.buildErroneousMockParams(basicMockParams) + : basicMockParams; + + try { + const noGeneratedProofs = await generate( + mockService, + maybeErroneousMockParams, + testParams.errorsHandlingMode, + preparedData + ); + expect(testParams.errorsHandlingMode).toEqual( + ErrorsHandlingMode.LOG_EVENTS_AND_SWALLOW_ERRORS + ); + expect(noGeneratedProofs).toHaveLength(0); + } catch (e) { + expect(testParams.errorsHandlingMode).toEqual( + ErrorsHandlingMode.RETHROW_ERRORS + ); + expect(e as LLMServiceError).toBeTruthy(); + if (testParams.expectedThrownError !== undefined) { + expect(e).toEqual(testParams.expectedThrownError); + } + } + + const expectedMockEventsN = + testParams.shouldFailBeforeGenerationIsStarted ? 0 : 1; + expect(eventsTracker).toEqual({ + mockEventsN: expectedMockEventsN, + successfulRequestEventsN: 0, + failedRequestEventsN: + testParams.expectedFailedRequestEventsN, + }); + expectLogs(testParams.expectedGenerationLogs, mockService); + + // check if service stays available after an error occurred + const generatedProofs = await generate( + mockService, + basicMockParams, + testParams.errorsHandlingMode, + preparedData + ); + if (testParams.proofsToGenerate !== undefined) { + expect(generatedProofs).toEqual( + testParams.proofsToGenerate + ); + } + + expect(eventsTracker).toEqual({ + mockEventsN: expectedMockEventsN + 1, + successfulRequestEventsN: 1, + failedRequestEventsN: + testParams.expectedFailedRequestEventsN, + }); + // `mockLLM` was created with `debugLogs = true`, so logs are not cleaned on success + expectLogs( + [ + ...testParams.expectedGenerationLogs, + { status: "SUCCESS" }, + ], + mockService + ); + } + ); + }); +} diff --git a/src/test/llm/llmSpecificTestUtils/testResolveParameters.ts b/src/test/llm/llmSpecificTestUtils/testResolveParameters.ts new file mode 100644 index 00000000..7358ed6d --- /dev/null +++ b/src/test/llm/llmSpecificTestUtils/testResolveParameters.ts @@ -0,0 +1,59 @@ +import { expect } from "earl"; + +import { LLMService } from "../../../llm/llmServices/llmService"; +import { ModelParams } from "../../../llm/llmServices/modelParams"; +import { defaultMultiroundProfile } from "../../../llm/llmServices/utils/paramsResolvers/basicModelParamsResolvers"; +import { + UserModelParams, + UserMultiroundProfile, +} from "../../../llm/userModelParams"; + +/** + * "User" version of `defaultUserMultiroundProfile` having the same default values. + * This constant is needed because `proofFixChoices` and `defaultProofFixChoices` + * parameters have different names. 
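+ *
+ * Illustrative mapping (an assumption based on the parameter names): a user profile
+ * `{ proofFixChoices: 3 }` corresponds to a resolved `MultiroundProfile`
+ * with `defaultProofFixChoices: 3`, while `maxRoundsNumber` and `proofFixPrompt`
+ * keep the same names in both profiles.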
+ */ +export const defaultUserMultiroundProfile: UserMultiroundProfile = { + maxRoundsNumber: defaultMultiroundProfile.maxRoundsNumber, + proofFixChoices: defaultMultiroundProfile.defaultProofFixChoices, + proofFixPrompt: defaultMultiroundProfile.proofFixPrompt, +}; + +export function testResolveValidCompleteParameters< + InputModelParams extends UserModelParams, + ResolvedModelParams extends ModelParams, +>( + llmService: LLMService, + validInputParams: InputModelParams, + expectNoDefaultResolutions: boolean = false +) { + const resolutionResult = llmService.resolveParameters(validInputParams); + expect(resolutionResult.resolved).not.toBeNullish(); + // verify logs + for (const paramLog of resolutionResult.resolutionLogs) { + expect(paramLog.resultValue).not.toBeNullish(); + expect(paramLog.isInvalidCause).toBeNullish(); + if (expectNoDefaultResolutions) { + expect(paramLog.inputReadCorrectly.wasPerformed).toBeTruthy(); + expect(paramLog.resolvedWithDefault.wasPerformed).toBeFalsy(); + } + } +} + +export function testResolveParametersFailsWithSingleCause< + InputModelParams extends UserModelParams, + ResolvedModelParams extends ModelParams, +>( + llmService: LLMService, + invalidInputParams: InputModelParams, + invalidCauseSubstring: string +) { + const resolutionResult = llmService.resolveParameters(invalidInputParams); + expect(resolutionResult.resolved).toBeNullish(); + const failureLogs = resolutionResult.resolutionLogs.filter( + (paramLog) => paramLog.isInvalidCause !== undefined + ); + expect(failureLogs).toHaveLength(1); + const invalidCause = failureLogs[0].isInvalidCause; + expect(invalidCause?.includes(invalidCauseSubstring)).toBeTruthy(); +} diff --git a/src/test/llm/llmSpecificTestUtils/transformParams.ts b/src/test/llm/llmSpecificTestUtils/transformParams.ts new file mode 100644 index 00000000..2dbcaa6a --- /dev/null +++ b/src/test/llm/llmSpecificTestUtils/transformParams.ts @@ -0,0 +1,17 @@ +import { MockLLMModelParams } from "./mockLLMService"; +import { MultiroundProfileAddOns } from "./modelParamsAddOns"; + +export function enhanceMockParams( + basicMockParams: MockLLMModelParams, + multiroundProfile: MultiroundProfileAddOns = {}, + unlimitedTokens: boolean = true +): MockLLMModelParams { + return { + ...basicMockParams, + tokensLimit: unlimitedTokens ? 
100000 : basicMockParams.tokensLimit, + multiroundProfile: { + ...basicMockParams.multiroundProfile, + ...multiroundProfile, + }, + }; +} diff --git a/src/test/llm/llmSpecificTestUtils/withMockLLMService.ts b/src/test/llm/llmSpecificTestUtils/withMockLLMService.ts new file mode 100644 index 00000000..f745b35b --- /dev/null +++ b/src/test/llm/llmSpecificTestUtils/withMockLLMService.ts @@ -0,0 +1,36 @@ +import { EventLogger } from "../../../logging/eventLogger"; +import { withLLMService } from "../../commonTestFunctions/withLLMService"; + +import { proofsToGenerate, testModelId } from "./constants"; +import { MockLLMModelParams, MockLLMService } from "./mockLLMService"; + +export async function withMockLLMService( + block: ( + mockService: MockLLMService, + basicMockParams: MockLLMModelParams, + testEventLogger: EventLogger + ) => Promise +) { + const testEventLogger = new EventLogger(); + return withLLMService( + new MockLLMService(testEventLogger, true), + async (mockService) => { + const basicMockParams: MockLLMModelParams = { + modelId: testModelId, + systemPrompt: MockLLMService.systemPromptToOverrideWith, + maxTokensToGenerate: 100, + tokensLimit: 1000, + multiroundProfile: { + maxRoundsNumber: 1, + defaultProofFixChoices: 0, + proofFixPrompt: "Fix proof", + }, + defaultChoices: proofsToGenerate.length, + proofsToGenerate: proofsToGenerate, + workerId: 0, + resolvedWithMockLLMService: true, + }; + await block(mockService, basicMockParams, testEventLogger); + } + ); +} diff --git a/src/test/llm/parseUserModelParams.test.ts b/src/test/llm/parseUserModelParams.test.ts new file mode 100644 index 00000000..18cf93f1 --- /dev/null +++ b/src/test/llm/parseUserModelParams.test.ts @@ -0,0 +1,174 @@ +import { JSONSchemaType } from "ajv"; +import { expect } from "earl"; + +import { + grazieUserModelParamsSchema, + lmStudioUserModelParamsSchema, + openAiUserModelParamsSchema, + predefinedProofsUserModelParamsSchema, + userModelParamsSchema, + userMultiroundProfileSchema, +} from "../../llm/userModelParams"; + +import { AjvMode, buildAjv } from "../../utils/ajvErrorsHandling"; + +suite("Parse `UserModelParams` from JSON test", () => { + const jsonSchemaValidator = buildAjv(AjvMode.COLLECT_ALL_ERRORS); + + function validateJSON( + json: any, + targetClassSchema: JSONSchemaType, + expectedToBeValidJSON: boolean, + expectedErrorKeys?: string[] + ) { + const validate = jsonSchemaValidator.compile(targetClassSchema); + expect(validate(json as T)).toEqual(expectedToBeValidJSON); + if (expectedErrorKeys !== undefined) { + expect(validate.errors).not.toBeNullish(); + expect( + new Set(validate.errors!.map((error) => error.keyword)) + ).toEqual(new Set(expectedErrorKeys)); + } + } + + function isValidJSON(json: any, targetClassSchema: JSONSchemaType) { + return validateJSON(json, targetClassSchema, true); + } + + function isInvalidJSON( + json: any, + targetClassSchema: JSONSchemaType, + ...expectedErrorKeys: string[] + ) { + return validateJSON(json, targetClassSchema, false, expectedErrorKeys); + } + + const validMultiroundProfileComplete = { + maxRoundsNumber: 5, + proofFixChoices: 1, + proofFixPrompt: "fix me", + }; + + const validUserModelParamsCompelete = { + modelId: "unique model id", + choices: 30, + systemPrompt: "generate proof", + maxTokensToGenerate: 100, + tokensLimit: 2000, + multiroundProfile: validMultiroundProfileComplete, + }; + + const validPredefinedProofsUserModelParamsComplete = { + ...validUserModelParamsCompelete, + tactics: ["auto.", "auto. 
intro."], + }; + const validOpenAiUserModelParamsComplete = { + ...validUserModelParamsCompelete, + modelName: "gpt-model", + temperature: 36.6, + apiKey: "api-key", + }; + const validGrazieUserModelParamsComplete = { + ...validUserModelParamsCompelete, + modelName: "gpt-model", + apiKey: "api-key", + }; + const validLMStudioUserModelParamsComplete = { + ...validUserModelParamsCompelete, + temperature: 36.6, + port: 555, + }; + + test("Validate `UserMultiroundProfile`", () => { + isValidJSON( + validMultiroundProfileComplete, + userMultiroundProfileSchema + ); + const validUndefinedProp = { + maxRoundsNumber: 5, + proofFixChoices: 1, + proofFixPrompt: undefined, + }; + isValidJSON(validUndefinedProp, userMultiroundProfileSchema); + const validEmpty = {}; + isValidJSON(validEmpty, userMultiroundProfileSchema); + + const invalidWrongTypeProp = { + ...validMultiroundProfileComplete, + proofFixPrompt: 0, + }; + isInvalidJSON( + invalidWrongTypeProp, + userMultiroundProfileSchema, + "type" + ); + + const invalidAdditionalProp = { + ...validMultiroundProfileComplete, + something: "something", + }; + isInvalidJSON( + invalidAdditionalProp, + userMultiroundProfileSchema, + "additionalProperties" + ); + }); + + test("Validate `UserModelParams`", () => { + isValidJSON(validUserModelParamsCompelete, userModelParamsSchema); + const validOnlyModelId = { + modelId: "the only id", + }; + isValidJSON(validOnlyModelId, userModelParamsSchema); + + const invalidNoModelId = { + choices: 30, + systemPrompt: "let's generate", + }; + isInvalidJSON(invalidNoModelId, userModelParamsSchema, "required"); + + const invalidWrongTypeProp = { + ...validUserModelParamsCompelete, + tokensLimit: "no limits", + }; + isInvalidJSON(invalidWrongTypeProp, userModelParamsSchema, "type"); + + const invalidAdditionalProp = { + ...validUserModelParamsCompelete, + something: "something", + }; + isInvalidJSON( + invalidAdditionalProp, + userModelParamsSchema, + "additionalProperties" + ); + }); + + test("Validate `PredefinedProofsUserModelParams`", () => { + isValidJSON( + validPredefinedProofsUserModelParamsComplete, + predefinedProofsUserModelParamsSchema + ); + }); + + test("Validate `OpenAiUserModelParams`", () => { + isValidJSON( + validOpenAiUserModelParamsComplete, + openAiUserModelParamsSchema + ); + }); + + test("Validate `GrazieUserModelParams`", () => { + isValidJSON( + validGrazieUserModelParamsComplete, + grazieUserModelParamsSchema + ); + }); + + test("Validate `LMStudioUserModelParams`", () => { + isValidJSON( + validLMStudioUserModelParamsComplete, + lmStudioUserModelParamsSchema + ); + }); +}); diff --git a/src/test/resources/build_chat_theorems.v b/src/test/resources/build_chat_theorems.v new file mode 100644 index 00000000..4555c152 --- /dev/null +++ b/src/test/resources/build_chat_theorems.v @@ -0,0 +1,17 @@ +Theorem plus : forall n:nat, 1 + n = S n. +Proof. + auto. +Qed. + +Theorem plus_assoc : forall a b c, a + (b + c) = a + b + c. +Proof. + intros. + induction a. + reflexivity. + simpl. rewrite IHa. reflexivity. +Qed. + +Theorem test : forall (A : Type) (P : A -> Prop) (x : A), P x -> P x. +Proof. + admit. +Admitted. 
\ No newline at end of file diff --git a/src/utils/ajvErrorsHandling.ts b/src/utils/ajvErrorsHandling.ts new file mode 100644 index 00000000..86fe4bd7 --- /dev/null +++ b/src/utils/ajvErrorsHandling.ts @@ -0,0 +1,53 @@ +import Ajv, { DefinedError, Options } from "ajv"; + +import { stringifyDefinedValue } from "./printers"; + +export enum AjvMode { + RETURN_AFTER_FIRST_ERROR, + COLLECT_ALL_ERRORS, +} + +export function buildAjv(mode: AjvMode): Ajv { + const ajvOptions: Options = + mode === AjvMode.RETURN_AFTER_FIRST_ERROR ? {} : { allErrors: true }; + return new Ajv(ajvOptions); +} + +export function ajvErrorsAsString( + errorObjects: DefinedError[], + ignoreErrorsWithKeywords: string[] = [] +) { + const errorsToReport = errorObjects.filter( + (errorObject) => !ignoreErrorsWithKeywords.includes(errorObject.keyword) + ); + return errorsToReport.map((error) => stringifyAjvError(error)).join("; "); +} + +export function stringifyAjvError(error: DefinedError): string { + // To support more keywords, check https://ajv.js.org/api.html#error-parameters. + switch (error.keyword) { + case "required": + return `required property ${buildQualifiedPropertyName(error.instancePath, error.params.missingProperty)} is missing`; + case "additionalProperties": + return `unknown property ${buildQualifiedPropertyName(error.instancePath, error.params.additionalProperty)}`; + case "type": + return `${buildQualifiedPropertyName(error.instancePath)} property must be of type "${error.params.type}"`; + case "oneOf": + return `${buildQualifiedPropertyName(error.instancePath)} property must match exactly one of the specified schemas`; + default: + return stringifyDefinedValue(error); + } +} + +function buildQualifiedPropertyName( + instancePath: string, + propertySimpleName?: string +): string { + const qualifiedPath = ( + instancePath.startsWith("/") ? instancePath.substring(1) : instancePath + ).replace("/", "."); + const name = propertySimpleName === undefined ? "" : propertySimpleName; + return qualifiedPath !== "" && name !== "" + ? `\`${qualifiedPath}.${name}\`` + : `\`${qualifiedPath}${name}\``; +} diff --git a/src/utils/printers.ts b/src/utils/printers.ts new file mode 100644 index 00000000..674da725 --- /dev/null +++ b/src/utils/printers.ts @@ -0,0 +1,14 @@ +export function stringifyAnyValue(value: any): string { + const valueAsString = JSON.stringify(value); + if (typeof value === "number") { + return valueAsString; + } + return `"${valueAsString}"`; +} + +export function stringifyDefinedValue(value: any): string { + if (value === undefined) { + throw Error(`value to stringify is not defined`); + } + return stringifyAnyValue(value); +} diff --git a/src/utils/simpleSet.ts b/src/utils/simpleSet.ts new file mode 100644 index 00000000..b2f13169 --- /dev/null +++ b/src/utils/simpleSet.ts @@ -0,0 +1,15 @@ +export class SimpleSet { + private readonly entities: Map = new Map(); + + constructor( + private readonly keyExtractor: (entity: EntityType) => KeyType + ) {} + + add(entity: EntityType) { + this.entities.set(this.keyExtractor(entity), true); + } + + has(entity: EntityType): boolean { + return this.entities.has(this.keyExtractor(entity)); + } +}
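+
+/*
+ * Illustrative usage sketch (reviewer example, hypothetical; it assumes `SimpleSet`
+ * is generic over `<EntityType, KeyType>`, as its constructor and fields suggest):
+ *
+ *   interface Model { modelId: string; }
+ *   const seenModels = new SimpleSet<Model, string>((model) => model.modelId);
+ *   seenModels.add({ modelId: "test model" });
+ *   seenModels.has({ modelId: "test model" });  // true
+ *   seenModels.has({ modelId: "other model" }); // false
+ */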