Gemini Enhancements (#428)
* chore: Bump google-generativeai and related dependencies

* feat: add support for --temperature option to gemini

* feat: add support for --interval option to gemini

* feat: add support for --model_list option to gemini

* feat: add support for --prompt option to gemini

* modify: model settings

* feat: add support for --use_context option to gemini

* feat: add support for rotate_key to gemini

* feat: add exponential backoff to gemini

* Update README.md

* fix: typos and apply black formatting

* Update make_test_ebook.yaml

* fix: cli

* fix: interval option implementation

* fix: interval for geminipro

* fix: recreate convo after rotating key
risin42 authored Oct 21, 2024
1 parent 6912206 commit 9261d92
Showing 6 changed files with 172 additions and 53 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/make_test_ebook.yaml
@@ -71,7 +71,7 @@ jobs:
       - name: Rename and Upload ePub
         if: env.OPENAI_API_KEY != null
-        uses: actions/upload-artifact@v2
+        uses: actions/upload-artifact@v3
         with:
           name: epub_output
           path: "test_books/lemo_bilingual.epub"
10 changes: 7 additions & 3 deletions README.md
@@ -30,7 +30,8 @@ Find more info here for using liteLLM: https://github.com/BerriAI/litellm/blob/m
 - If using chatgptapi, you can add `--use_context` to add a context paragraph to each passage sent to the model for translation (see below).
 - Support DeepL model [DeepL Translator](https://rapidapi.com/splintPRO/api/dpl-translator), which requires a paid token: `--model deepl --deepl_key ${deepl_key}`
 - Support DeepL free model `--model deeplfree`
-- Support Google [Gemini](https://makersuite.google.com/app/apikey) model `--model gemini --gemini_key ${gemini_key}`
+- Support Google [Gemini](https://aistudio.google.com/app/apikey) model, use `--model gemini` for Gemini Flash or `--model geminipro` for Gemini Pro, together with `--gemini_key ${gemini_key}`
+- If you want to use a specific model alias with Gemini (e.g. `gemini-1.5-flash-002` or `gemini-1.5-flash-8b-exp-0924`), you can use `--model gemini --model_list gemini-1.5-flash-002,gemini-1.5-flash-8b-exp-0924`. `--model_list` takes a comma-separated list of model aliases.
 - Support [Claude](https://console.anthropic.com/docs) model, use `--model claude --claude_key ${claude_key}`
 - Support [Tencent TranSmart](https://transmart.qq.com) model (Free), use `--model tencentransmart`
 - Support [Ollama](https://github.com/ollama/ollama) self-host models, use `--ollama_model ${ollama_model_name}`
@@ -57,7 +58,7 @@ Find more info here for using liteLLM: https://github.com/BerriAI/litellm/blob/m
 - `--accumulated_num` Wait for how many tokens have been accumulated before starting the translation. gpt3.5 limits total_token to 4090. For example, if you use --accumulated_num 1600, OpenAI may
 output around 2200 tokens, plus roughly 200 tokens for the system and user messages: 1600+2200+200=4000, so you are close to the limit. You have to choose your own
 value; there is no way to know whether the limit will be reached before sending.
-- `--use_context` prompts the model to create a three-paragraph summary. If it's the beginning of the translation, it will summarize the entire passage sent (the size depending on `--accumulated_num`). For subsequent passages, it will amend the summary to include details from the most recent passage, creating a running one-paragraph context payload of the important details of the entire translated work. This improves consistency of flow and tone throughout the translation. This option is available for all ChatGPT-compatible models.
+- `--use_context` prompts the model to create a three-paragraph summary. If it's the beginning of the translation, it will summarize the entire passage sent (the size depending on `--accumulated_num`). For subsequent passages, it will amend the summary to include details from the most recent passage, creating a running one-paragraph context payload of the important details of the entire translated work. This improves consistency of flow and tone throughout the translation. This option is available for all ChatGPT-compatible models and Gemini models.
 - Use `--context_paragraph_limit` to set a limit on the number of context paragraphs when using the `--use_context` option.
 - Use `--temperature` to set the temperature parameter for `chatgptapi`/`gpt4`/`claude` models. For example: `--temperature 0.7`.
 - Use `--block_size` to merge multiple paragraphs into one block. This may increase accuracy and speed up the process but can disturb the original format. Must be used with `--single_translate`. For example: `--block_size 5`.
@@ -82,9 +83,12 @@ python3 make_book.py --book_name test_books/Lex_Fridman_episode_322.srt --openai
 # Or translate the whole book
 python3 make_book.py --book_name test_books/animal_farm.epub --openai_key ${openai_key} --language zh-hans
 
-# Or translate the whole book using Gemini
+# Or translate the whole book using Gemini Flash
 python3 make_book.py --book_name test_books/animal_farm.epub --gemini_key ${gemini_key} --model gemini
 
+# Use a specific list of Gemini model aliases
+python3 make_book.py --book_name test_books/animal_farm.epub --gemini_key ${gemini_key} --model gemini --model_list gemini-1.5-flash-002,gemini-1.5-flash-8b-exp-0924
+
 # Set env OPENAI_API_KEY to ignore option --openai_key
 export OPENAI_API_KEY=${your_api_key}
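The token arithmetic in the `--accumulated_num` note above is worth sanity-checking explicitly. A minimal sketch, using only the rough numbers quoted in the README (the 4090 ceiling and the 2200/200 estimates are illustrative, not measured):

```python
# Rough token budget for --accumulated_num, with the README's example numbers.
TOKEN_LIMIT = 4090       # approximate gpt-3.5 total_token ceiling
accumulated_num = 1600   # tokens accumulated before a translation request
estimated_output = 2200  # tokens the model might return
message_overhead = 200   # tokens spent on system/user message framing

total = accumulated_num + estimated_output + message_overhead
print(f"{total}/{TOKEN_LIMIT} tokens used")  # 4000/4090: close to the limit
```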
20 changes: 18 additions & 2 deletions book_maker/cli.py
@@ -290,7 +290,7 @@ def main():
         "--temperature",
         type=float,
         default=1.0,
-        help="temperature parameter for `chatgptapi`/`gpt4`/`claude`",
+        help="temperature parameter for `chatgptapi`/`gpt4`/`claude`/`gemini`",
     )
     parser.add_argument(
         "--block_size",
@@ -316,6 +316,12 @@
         action="store_true",
         help="Use pre-generated batch translations to create files. Run with --batch first before using this option",
     )
+    parser.add_argument(
+        "--interval",
+        type=float,
+        default=0.01,
+        help="Request interval in seconds (e.g., 0.1 for 100ms). Currently only supported for Gemini models. Default: 0.01",
+    )
 
     options = parser.parse_args()
 
@@ -366,7 +372,7 @@ def main():
         API_KEY = options.custom_api or env.get("BBM_CUSTOM_API")
         if not API_KEY:
             raise Exception("Please provide custom translate api")
-    elif options.model == "gemini":
+    elif options.model in ["gemini", "geminipro"]:
         API_KEY = options.gemini_key or env.get("BBM_GOOGLE_GEMINI_KEY")
     elif options.model == "groq":
         API_KEY = options.groq_key or env.get("BBM_GROQ_API_KEY")
@@ -481,6 +487,16 @@ def main():
     if options.batch_use_flag:
         e.batch_use_flag = options.batch_use_flag
 
+    if options.model in ("gemini", "geminipro"):
+        e.translate_model.set_interval(options.interval)
+    if options.model == "gemini":
+        if options.model_list:
+            e.translate_model.set_model_list(options.model_list.split(","))
+        else:
+            e.translate_model.set_geminiflash_models()
+    if options.model == "geminipro":
+        e.translate_model.set_geminipro_models()
+
     e.make_bilingual_book()
 
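Since `--interval` is a plain sleep between Gemini requests, it maps directly onto a requests-per-minute budget. A small sketch of that relationship; the 15 RPM quota below is a hypothetical example, not a figure this PR asserts:

```python
def interval_for_rpm(rpm: float) -> float:
    """Return the per-request sleep in seconds that keeps the rate under `rpm`."""
    return 60.0 / rpm

# To stay under a hypothetical 15 requests/minute quota:
print(interval_for_rpm(15))  # 4.0 -> run with --interval 4.0
```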
1 change: 1 addition & 0 deletions book_maker/translator/__init__.py
@@ -21,6 +21,7 @@
     "deeplfree": DeepLFree,
     "claude": Claude,
     "gemini": Gemini,
+    "geminipro": Gemini,
     "groq": GroqClient,
     "tencentransmart": TencentTranSmart,
     "customapi": CustomAPI,
180 changes: 139 additions & 41 deletions book_maker/translator/gemini_translator.py
@@ -1,5 +1,7 @@
 import re
 import time
+from os import environ
+from itertools import cycle
 
 import google.generativeai as genai
 from google.generativeai.types.generation_types import (
@@ -11,23 +13,36 @@
 from .base_translator import Base
 
 generation_config = {
-    "temperature": 0.7,
+    "temperature": 1.0,
     "top_p": 1,
     "top_k": 1,
-    "max_output_tokens": 2048,
+    "max_output_tokens": 8192,
 }
 
-safety_settings = [
-    {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_MEDIUM_AND_ABOVE"},
-    {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_MEDIUM_AND_ABOVE"},
-    {
-        "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
-        "threshold": "BLOCK_MEDIUM_AND_ABOVE",
-    },
-    {
-        "category": "HARM_CATEGORY_DANGEROUS_CONTENT",
-        "threshold": "BLOCK_MEDIUM_AND_ABOVE",
-    },
-]
+safety_settings = {
+    "HATE": "BLOCK_NONE",
+    "HARASSMENT": "BLOCK_NONE",
+    "SEXUAL": "BLOCK_NONE",
+    "DANGEROUS": "BLOCK_NONE",
+}
+
+PROMPT_ENV_MAP = {
+    "user": "BBM_GEMINIAPI_USER_MSG_TEMPLATE",
+    "system": "BBM_GEMINIAPI_SYS_MSG",
+}
+
+GEMINIPRO_MODEL_LIST = [
+    "gemini-1.5-pro",
+    "gemini-1.5-pro-latest",
+    "gemini-1.5-pro-001",
+    "gemini-1.5-pro-002",
+]
+
+GEMINIFLASH_MODEL_LIST = [
+    "gemini-1.5-flash",
+    "gemini-1.5-flash-latest",
+    "gemini-1.5-flash-001",
+    "gemini-1.5-flash-002",
+]

@@ -38,20 +53,57 @@ class Gemini(Base):
 
     DEFAULT_PROMPT = "Please help me to translate,`{text}` to {language}, please return only translated content not include the origin text"
 
-    def __init__(self, key, language, **kwargs) -> None:
-        genai.configure(api_key=key)
+    def __init__(
+        self,
+        key,
+        language,
+        prompt_template=None,
+        prompt_sys_msg=None,
+        context_flag=False,
+        temperature=1.0,
+        **kwargs,
+    ) -> None:
         super().__init__(key, language)
+        self.context_flag = context_flag
+        self.prompt = (
+            prompt_template
+            or environ.get(PROMPT_ENV_MAP["user"])
+            or self.DEFAULT_PROMPT
+        )
+        self.prompt_sys_msg = (
+            prompt_sys_msg
+            or environ.get(PROMPT_ENV_MAP["system"])
+            or None  # Allow None, but not empty string
+        )
+
+        genai.configure(api_key=next(self.keys))
+        generation_config["temperature"] = temperature
 
     def create_convo(self):
         model = genai.GenerativeModel(
-            model_name="gemini-pro",
+            model_name=self.model,
             generation_config=generation_config,
             safety_settings=safety_settings,
+            system_instruction=self.prompt_sys_msg,
         )
         self.convo = model.start_chat()
+        # print(model)  # Uncomment to debug and inspect the model details.
+
+    def rotate_model(self):
+        self.model = next(self.model_list)
+        self.create_convo()
+        print(f"Using model {self.model}")
 
     def rotate_key(self):
-        pass
+        genai.configure(api_key=next(self.keys))
+        self.create_convo()
 
     def translate(self, text):
+        delay = 1
+        exponential_base = 2
+        attempt_count = 0
+        max_attempts = 7
+
         t_text = ""
         print(text)
         # same for caiyun translate src issue #279 gemini for #374
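The new `rotate_key` body implies that `self.keys` is an endless iterator over the comma-separated keys passed on the command line; presumably the `Base` class wraps them in `itertools.cycle`, which is what the `next(self.keys)` calls here require. A standalone sketch of that round-robin pattern:

```python
from itertools import cycle

# Hypothetical comma-separated keys, as they might arrive via --gemini_key.
keys = cycle("key-a,key-b,key-c".split(","))

for _ in range(4):
    print(next(keys))  # key-a, key-b, key-c, then wraps back to key-a
```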
@@ -60,32 +112,78 @@ def translate(self, text):
         if len(text_list) > 1:
             if text_list[0].isdigit():
                 num = text_list[0]
-        try:
-            self.convo.send_message(
-                self.DEFAULT_PROMPT.format(text=text, language=self.language)
-            )
-            print(text)
-            t_text = self.convo.last.text.strip()
-        except StopCandidateException as e:
-            match = re.search(r'content\s*{\s*parts\s*{\s*text:\s*"([^"]+)"', str(e))
-            if match:
-                t_text = match.group(1)
-                t_text = re.sub(r"\\n", "\n", t_text)
-            else:
-                t_text = "Can not translate"
-        except BlockedPromptException as e:
-            print(str(e))
-            t_text = "Can not translate by SAFETY reason.(因安全问题不能翻译)"
-        except Exception as e:
-            print(str(e))
-            t_text = "Can not translate by other reason.(因安全问题不能翻译)"
-
-        if len(self.convo.history) > 10:
-            self.convo.history = self.convo.history[2:]
-
+        while attempt_count < max_attempts:
+            try:
+                self.convo.send_message(
+                    self.prompt.format(text=text, language=self.language)
+                )
+                t_text = self.convo.last.text.strip()
+                break
+            except StopCandidateException as e:
+                print(
+                    f"Translation failed due to StopCandidateException: {e} Attempting to switch model..."
+                )
+                self.rotate_model()
+            except BlockedPromptException as e:
+                print(
+                    f"Translation failed due to BlockedPromptException: {e} Attempting to switch model..."
+                )
+                self.rotate_model()
+            except Exception as e:
+                print(
+                    f"Translation failed due to {type(e).__name__}: {e} Will sleep {delay} seconds"
+                )
+                time.sleep(delay)
+                delay *= exponential_base
+
+                self.rotate_key()
+                if attempt_count >= 1:
+                    self.rotate_model()
+
+            attempt_count += 1
+
+        if attempt_count == max_attempts:
+            print(f"Translation failed after {max_attempts} attempts.")
+            return
+
+        if self.context_flag:
+            if len(self.convo.history) > 10:
+                self.convo.history = self.convo.history[2:]
+        else:
+            self.convo.history = []
 
         print("[bold green]" + re.sub("\n{3,}", "\n\n", t_text) + "[/bold green]")
-        # for limit
-        time.sleep(0.5)
+        # for rate limit(RPM)
+        time.sleep(self.interval)
         if num:
             t_text = str(num) + "\n" + t_text
         return t_text
+
+    def set_interval(self, interval):
+        self.interval = interval
+
+    def set_geminipro_models(self):
+        self.set_models(GEMINIPRO_MODEL_LIST)
+
+    def set_geminiflash_models(self):
+        self.set_models(GEMINIFLASH_MODEL_LIST)
+
+    def set_models(self, allowed_models):
+        available_models = [
+            re.sub(r"^models/", "", i.name) for i in genai.list_models()
+        ]
+        model_list = sorted(
+            list(set(available_models) & set(allowed_models)),
+            key=allowed_models.index,
+        )
+        print(f"Using model list {model_list}")
+        self.model_list = cycle(model_list)
+        self.rotate_model()
+
+    def set_model_list(self, model_list):
+        # keep the order of input
+        model_list = sorted(list(set(model_list)), key=model_list.index)
+        print(f"Using model list {model_list}")
+        self.model_list = cycle(model_list)
+        self.rotate_model()
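The retry loop in `translate` combines exponential backoff with rotation: on a generic exception it sleeps, doubles the delay, rotates the API key, and from the second failed attempt onward also rotates the model, while safety-related exceptions switch models immediately without sleeping. A quick sketch of the worst-case sleep schedule those constants produce:

```python
# Sleep schedule for delay=1, exponential_base=2, max_attempts=7.
delay, exponential_base, max_attempts = 1, 2, 7

schedule = []
for _ in range(max_attempts):
    schedule.append(delay)
    delay *= exponential_base

print(schedule)       # [1, 2, 4, 8, 16, 32, 64]
print(sum(schedule))  # 127 seconds, roughly two minutes of backoff at worst
```

Note also that `set_models` intersects the allow-list with what `genai.list_models()` actually reports, preserving the allow-list's order, so the `cycle` only ever rotates through models the account can reach.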
12 changes: 6 additions & 6 deletions requirements.txt
@@ -25,13 +25,13 @@ exceptiongroup==1.2.1; python_version < "3.11"
 filelock==3.14.0
 frozenlist==1.4.1
 fsspec==2024.3.1
-google-ai-generativelanguage==0.6.4
-google-api-core==2.19.0
-google-api-python-client==2.127.0
-google-auth==2.29.0
+google-ai-generativelanguage==0.6.10
+google-api-core==2.21.0
+google-api-python-client==2.149.0
+google-auth==2.35.0
 google-auth-httplib2==0.2.0
-google-generativeai==0.5.4
-googleapis-common-protos==1.63.0
+google-generativeai==0.8.3
+googleapis-common-protos==1.65.0
 groq==0.8.0
 grpcio==1.63.0
 grpcio-status==1.62.2
