Skip to content

Commit

Permalink
fix: support ocr specific models
Browse files Browse the repository at this point in the history
  • Loading branch information
zmh-program committed Apr 24, 2024
1 parent c019962 commit 1a1bce7
Show file tree
Hide file tree
Showing 3 changed files with 82 additions and 31 deletions.
102 changes: 76 additions & 26 deletions config.py
Original file line number Diff line number Diff line change
@@ -1,37 +1,87 @@
from os import environ

CORS_ALLOW_ORIGINS = ["*"] # CORS Allow Origins
if environ.get("CORS_ALLOW_ORIGINS") and len(environ.get("CORS_ALLOW_ORIGINS")) > 0:
CORS_ALLOW_ORIGINS = environ.get("CORS_ALLOW_ORIGINS").split(",")

PDF_MAX_IMAGES = int(environ.get("PDF_MAX_IMAGES", 0)) # The maximum number of images to extract from a PDF file
def to_str(key: str, default: str = "") -> str:
    """Read environment variable ``key`` as a whitespace-trimmed string.

    Args:
        key: Name of the environment variable.
        default: Value used when the variable is not set; it is trimmed
            in the same way as a value read from the environment.

    Returns:
        The value with leading and trailing whitespace removed.
    """
    raw = environ.get(key, default)
    trimmed = raw.strip()
    return trimmed

MAX_FILE_SIZE = float(environ.get("MAX_FILE_SIZE", -1)) # Max File Size (unit: MiB)

STORAGE_TYPE = environ.get("STORAGE_TYPE", "common").lower() # Storage Type
LOCAL_STORAGE_DOMAIN = environ.get("LOCAL_STORAGE_DOMAIN", "").rstrip("/") # Local Storage Domain
def to_none_str(key: str, default: str = None) -> str:
    """Read environment variable ``key`` as a trimmed string, or ``None``.

    Args:
        key: Name of the environment variable.
        default: Value used when the variable is not set (may be ``None``).

    Returns:
        The value with surrounding whitespace removed, or ``None`` when the
        variable is unset, empty, or contains only whitespace.
    """
    value = environ.get(key, default)
    if value is None:
        return None

    # Strip before testing emptiness so that a whitespace-only value is
    # reported as "not configured" (None) rather than as an empty string.
    return value.strip() or None

S3_API = S3_DOMAIN or f"https://{S3_BUCKET}.s3.{S3_REGION}.amazonaws.com" # S3 API
S3_SPACE = S3_DIRECT_URL_DOMAIN or S3_API # S3 Image URL Domain

TG_ENDPOINT = environ.get("TG_ENDPOINT", "").rstrip("/") # Telegram Endpoint
TG_PASSWORD = environ.get("TG_PASSWORD", "") # Telegram Password
def to_endpoint(key: str, default: str = "") -> str:
    """Read environment variable ``key`` as an endpoint string.

    The value is trimmed by ``to_str`` and any trailing ``/`` characters
    are removed, so a path can be appended with a single leading slash.
    """
    endpoint = to_str(key, default)
    return endpoint.rstrip("/")


def to_list(key: str, default: list) -> list:
    """Read environment variable ``key`` as a comma-separated list.

    Args:
        key: Name of the environment variable.
        default: Returned as-is when the variable is unset or blank.

    Returns:
        The non-empty items of the value, each with surrounding whitespace
        removed, or ``default`` when the variable is unset or blank.
    """
    raw = environ.get(key, "").strip()
    if not raw:
        return default

    # Trim every item: "a, b" must yield "b", not " b", otherwise later
    # equality/substring comparisons against the item silently fail.
    items = (part.strip() for part in raw.split(","))
    return [item for item in items if item]


def to_bool(key: str, default: bool) -> bool:
    """Read environment variable ``key`` as a boolean flag.

    Returns ``default`` when the variable is unset or blank. Otherwise the
    result is ``True`` only for ``"1"`` or for ``"true"`` in any letter
    case; every other value is ``False``.
    """
    raw = to_str(key, "")
    if raw:
        return raw == "1" or raw.lower() == "true"

    return default


TG_API = TG_ENDPOINT + "/api" + (f"?pass={TG_PASSWORD}" if TG_PASSWORD and len(TG_PASSWORD) > 0 else "")
def to_float(key: str, default: float) -> float:
    """Read environment variable ``key`` as a float.

    Returns ``default`` unchanged when the variable is unset or blank.
    A non-numeric value raises ``ValueError`` from ``float()``.
    """
    raw = to_str(key, "")
    return float(raw) if raw else default


def to_int(value: str, default: int) -> int:
    """Read an environment variable as an integer.

    Note that ``value`` is the *name* of the environment variable, not the
    text to convert. Returns ``default`` unchanged when the variable is
    unset or blank; a non-integer value raises ``ValueError`` from ``int()``.
    """
    raw = to_str(value, "")
    return int(raw) if raw else default


# General Config
CORS_ALLOW_ORIGINS = to_list("CORS_ALLOW_ORIGINS", ["*"])  # CORS Allow Origins
MAX_FILE_SIZE = to_float("MAX_FILE_SIZE", -1)  # Max File Size (unit: MiB)
PDF_MAX_IMAGES = to_int("PDF_MAX_IMAGES", 10)  # Maximum number of images to extract from a PDF file
AZURE_SPEECH_KEY = to_str("AZURE_SPEECH_KEY")  # Azure Speech Key
AZURE_SPEECH_REGION = to_str("AZURE_SPEECH_REGION")  # Azure Speech Region, e.g. "eastus"
# Azure Speech is enabled only when both key and region are configured.
# bool() keeps this a real flag instead of leaking the region string
# (or "") that a bare `and` expression would evaluate to.
ENABLE_AZURE_SPEECH = bool(AZURE_SPEECH_KEY and AZURE_SPEECH_REGION)  # Enable Azure Speech

# Storage Config
# Lower-cased so the configured storage type is case-insensitive
# (e.g. "S3" and "s3" are treated the same).
STORAGE_TYPE = to_str("STORAGE_TYPE", "common").lower()  # Storage Type
LOCAL_STORAGE_DOMAIN = to_endpoint("LOCAL_STORAGE_DOMAIN", "")  # Local Storage Domain
S3_BUCKET = to_str("S3_BUCKET", "")  # S3 Bucket
S3_ACCESS_KEY = to_str("S3_ACCESS_KEY", "")  # S3 Access Key
S3_SECRET_KEY = to_str("S3_SECRET_KEY", "")  # S3 Secret Key
S3_REGION = to_str("S3_REGION", "")  # S3 Region
S3_DOMAIN = to_endpoint("S3_DOMAIN", "")  # S3 Domain (Optional)
S3_DIRECT_URL_DOMAIN = to_endpoint("S3_DIRECT_URL_DOMAIN", "")  # S3 Direct/Proxy URL Domain (Optional)
S3_SIGN_VERSION = to_none_str("S3_SIGN_VERSION")  # S3 Sign Version
# Fall back to the bucket's regional amazonaws.com URL when no custom domain is set.
S3_API = S3_DOMAIN or f"https://{S3_BUCKET}.s3.{S3_REGION}.amazonaws.com"  # S3 API
S3_SPACE = S3_DIRECT_URL_DOMAIN or S3_API  # S3 Image URL Domain
TG_ENDPOINT = to_endpoint("TG_ENDPOINT", "")  # Telegram Endpoint
TG_PASSWORD = to_str("TG_PASSWORD", "")  # Telegram Password
# The password query parameter is appended only when a password is configured.
TG_API = TG_ENDPOINT + "/api" + (f"?pass={TG_PASSWORD}" if TG_PASSWORD else "")  # Telegram API

OCR_ENDPOINT = environ.get("OCR_ENDPOINT", "").rstrip("/") # OCR Endpoint
OCR_ENABLED = int(environ.get("OCR_ENABLED", 0)) == 1 # OCR Enabled
OCR_SKIP_MODELS = environ.get("OCR_SKIP_MODELS", "").split(",") # OCR Skip Models
OCR_SPEC_MODELS = environ.get("OCR_SPEC_MODELS", "").split(",") # OCR Specific Models
# OCR Config
OCR_ENDPOINT = to_endpoint("OCR_ENDPOINT", default="")  # OCR Endpoint (trailing "/" removed)
OCR_ENABLED = to_bool("OCR_ENABLED", default=False)  # OCR Enabled
OCR_SKIP_MODELS = to_list("OCR_SKIP_MODELS", default=[])  # OCR Skip Models (comma-separated)
OCR_SPEC_MODELS = to_list("OCR_SPEC_MODELS", default=[])  # OCR Specific Models (comma-separated)
9 changes: 5 additions & 4 deletions handlers/ocr.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,12 +48,13 @@ def could_enable_ocr(model: str = "") -> bool:
# if model is not defined
return True

if len(OCR_SPEC_MODELS) > 0:
if contains(model, OCR_SPEC_MODELS):
# if model is in specific list
return True

if contains(model, OCR_SKIP_MODELS):
# if model is in skip list
return False

if len(OCR_SPEC_MODELS) > 0 and not contains(model, OCR_SPEC_MODELS):
# if specific models are defined and model is not in the list
return False

return True
2 changes: 1 addition & 1 deletion util.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,5 +27,5 @@ def md5_encode(string) -> str:
def contains(value: str, items: List[str]) -> bool:
    """Return True if any non-empty entry of ``items`` is a substring of ``value``.

    Empty entries are skipped: ``"" in value`` is always true, so an empty
    entry (for example from splitting an empty configuration string) would
    otherwise make every value match.
    """
    return any(item in value for item in items if item)

0 comments on commit 1a1bce7

Please sign in to comment.