Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Summarizer #38

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ services:
extends: env_django
command: >
bash -c "
# python manage.py collectstatic --no-input &&
python manage.py makemigrations &&
python manage.py migrate &&
gunicorn obstracts.wsgi:application --reload --bind 0.0.0.0:8001
Expand Down
32 changes: 21 additions & 11 deletions obstracts/cjob/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import typing

from dogesec_commons.stixifier.stixifier import StixifyProcessor, ReportProperties
from dogesec_commons.stixifier.summarizer import parse_summarizer_model
from ..server.models import Job, FeedProfile
from ..server import models

Expand Down Expand Up @@ -78,27 +79,27 @@ def poll_job(job_id):
current_task.retry(max_retries=200)


def new_task(feed_dict, profile_id, summary_provider):
    """Create/update the feed record, create a Job, and queue its processing chain.

    Args:
        feed_dict: upstream response dict; must contain `feed_id` and `job_id`,
            may contain `title`.
        profile_id: id of the Profile the posts will be processed with.
        summary_provider: optional `provider:model` string forwarded to the
            summarizer step (falsy disables summarization).

    Returns:
        The created Job.
    """
    kwargs = dict(id=feed_dict["feed_id"], profile_id=profile_id)
    if title := feed_dict.get("title"):
        kwargs.update(title=title)
    feed, _ = FeedProfile.objects.update_or_create(defaults=kwargs, id=feed_dict["feed_id"])
    job = Job.objects.create(id=feed_dict["job_id"], feed=feed, profile_id=profile_id)
    # Poll the upstream job first, then process its posts. job.id doubles as the
    # celery root/task id so the whole chain can be tracked by the Job record.
    (poll_job.s(job.id) | start_processing.s(job.id, summary_provider)).apply_async(
        countdown=5, root_id=job.id, task_id=job.id
    )
    return job

def new_post_patch_task(input_dict, profile_id, summary_provider):
    """Create a Job for re-processing an existing post and queue its chain.

    Args:
        input_dict: upstream response dict; must contain `job_id` and `feed_id`.
        profile_id: id of the Profile to process with.
        summary_provider: optional `provider:model` string forwarded to the
            summarizer step (falsy disables summarization).

    Returns:
        The created Job.
    """
    job = Job.objects.create(id=input_dict["job_id"], feed_id=input_dict["feed_id"], profile_id=profile_id)
    # Same chain shape as new_task: poll upstream, then process; job.id doubles
    # as the celery root/task id so the chain is trackable by the Job record.
    (poll_job.s(job.id) | start_processing.s(job.id, summary_provider)).apply_async(
        countdown=5, root_id=job.id, task_id=job.id
    )
    return job


@shared_task
def start_processing(h4f_job, job_id):
def start_processing(h4f_job, job_id, summary_provider):
job = Job.objects.get(id=job_id)
logging.info(
f"processing {job_id=}, {job.feed_id=}, {current_task.request.root_id=}"
Expand Down Expand Up @@ -128,7 +129,7 @@ def start_processing(h4f_job, job_id):
)
break
logging.info("processing %d posts for job %s", len(posts), job_id)
tasks = [process_post.si(job_id, post) for post in posts]
tasks = [process_post.si(job_id, post, summary_provider) for post in posts]
tasks.append(job_completed_with_error.si(job_id))
return chain(tasks).apply_async()

Expand All @@ -142,13 +143,13 @@ def set_job_completed(job_id):


@shared_task
def process_post(job_id, post, *args):
def process_post(job_id, post, summary_provider, *args):
job = Job.objects.get(id=job_id)
post_id = str(post['id'])
try:
file = io.BytesIO(post['description'].encode())
file.name = f"post-{post_id}.html"
processor = StixifyProcessor(file, job.profile, job_id=job.id, file2txt_mode="html_article", report_id=post_id, base_url=post['link'])
stream = io.BytesIO(post['description'].encode())
stream.name = f"post-{post_id}.html"
processor = StixifyProcessor(stream, job.profile, job_id=job.id, file2txt_mode="html_article", report_id=post_id, base_url=post['link'])
processor.collection_name = job.feed.collection_name
properties = ReportProperties(
name=post['title'],
Expand All @@ -165,8 +166,17 @@ def process_post(job_id, post, *args):
)
processor.setup(properties, dict(_obstracts_feed_id=str(job.feed.id), _obstracts_post_id=post_id))
processor.process()

file, _ = models.File.objects.get_or_create(post_id=post_id)
if summary_provider:
logging.info(f"summarizing report {processor.report_id} using `{summary_provider}`")
try:
summary_provider = parse_summarizer_model(summary_provider)
file.summary = summary_provider.summarize(processor.output_md)
except BaseException as e:
print(f"got err {e}")
logging.info(f"got err {e}", exc_info=True)



file.markdown_file.save('markdown.md', processor.md_file.open(), save=True)
models.FileImage.objects.filter(report=file).delete() # remove old references
Expand Down
1 change: 1 addition & 0 deletions obstracts/server/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ def upload_to_func(instance: 'File', filename):
class File(models.Model):
    # Stores the processed output of a single post, keyed by the post's UUID.
    post_id = models.UUIDField(primary_key=True)
    # Rendered markdown of the post; null until processing has produced it.
    markdown_file = models.FileField(upload_to=upload_to_func, null=True)
    # AI-generated summary of the markdown; null unless the job ran with a
    # summary provider. NOTE(review): CharField(max_length=65535) — TextField
    # is the usual Django choice for long free text; confirm before relying on
    # this field type in a migration.
    summary = models.CharField(max_length=65535, null=True)


class FileImage(models.Model):
Expand Down
15 changes: 7 additions & 8 deletions obstracts/server/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from .models import Profile, Job, FileImage
from drf_spectacular.utils import extend_schema_serializer, extend_schema_field
from django.utils.translation import gettext_lazy as _
from dogesec_commons.stixifier.summarizer import parse_summarizer_model


class JobSerializer(serializers.ModelSerializer):
Expand All @@ -12,25 +13,23 @@ class Meta:
# fields = "__all__"
exclude = ["feed", "profile"]


class CreateTaskSerializer(serializers.Serializer):
    """Base payload for endpoints that start a processing job.

    Validates the Profile to process with and an optional AI summary provider.
    """
    profile_id = serializers.PrimaryKeyRelatedField(queryset=Profile.objects, error_messages={
        'required': _('This field is required.'),
        'does_not_exist': _('Invalid profile with id "{pk_value}" - object does not exist.'),
        'incorrect_type': _('Incorrect type. Expected profile id (uuid), received {data_type}.'),
    })
    # Validated eagerly with parse_summarizer_model so a bad provider string
    # fails at request time rather than inside the celery task. Fixed typo in
    # help_text: "int the format" -> "in the format".
    ai_summary_provider = serializers.CharField(allow_blank=True, allow_null=True, validators=[parse_summarizer_model], default=None, write_only=True, help_text="AI Summary provider in the format provider:model e.g `openai:gpt-3.5-turbo`")

class FeedSerializer(CreateTaskSerializer):
    """Payload for creating a feed: job fields plus the feed URL itself."""
    url = serializers.URLField(help_text="The URL of the RSS or ATOM feed")
    # NOTE(review): help_text is empty — presumably controls whether posts
    # hosted on other domains are also ingested; confirm and document.
    include_remote_blogs = serializers.BooleanField(help_text="", default=False, required=False)

class PatchFeedSerializer(FeedSerializer):
    """Feed PATCH payload: same as FeedSerializer minus `url` (not editable)."""
    url = None  # assigning None removes the inherited serializer field

class PatchPostSerializer(serializers.Serializer):
profile_id = serializers.PrimaryKeyRelatedField(queryset=Profile.objects, error_messages={
'required': _('This field is required.'),
'does_not_exist': _('Invalid profile with id "{pk_value}" - object does not exist.'),
'incorrect_type': _('Incorrect type. Expected profile id (uuid), received {data_type}.'),
})
class PatchPostSerializer(CreateTaskSerializer):
    """Post re-process payload: only the base job fields are accepted."""
    pass

class PostCreateSerializer(PatchPostSerializer):
title = serializers.CharField()
Expand Down
44 changes: 34 additions & 10 deletions obstracts/server/views.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import io
import json
import logging
from urllib.parse import urljoin
from django.http import HttpResponse, FileResponse
from django.shortcuts import get_object_or_404
from rest_framework import viewsets, decorators, exceptions, status
from drf_spectacular.utils import OpenApiParameter
from rest_framework import viewsets, decorators, exceptions, status, renderers
from drf_spectacular.utils import OpenApiParameter, OpenApiResponse
from drf_spectacular.types import OpenApiTypes
from .import autoschema as api_schema
from dogesec_commons.objects.helpers import OBJECT_TYPES
Expand Down Expand Up @@ -37,6 +38,11 @@
import mistune
from mistune.renderers.markdown import MarkdownRenderer
from mistune.util import unescape

class PlainMarkdownRenderer(renderers.BaseRenderer):
    """Pass-through DRF renderer for `text/markdown` responses.

    `BaseRenderer.render()` raises NotImplementedError, so a concrete
    render() is required for content negotiation to use this class.
    """
    media_type = "text/markdown"
    # NOTE(review): DRF's `format` is normally a short suffix like "md";
    # using the full media type means `?format=text/markdown` — confirm intended.
    format = "text/markdown"

    def render(self, data, accepted_media_type=None, renderer_context=None):
        # Markdown payloads are already plain text/bytes; emit them unchanged,
        # encoding str to bytes as the response layer expects.
        if isinstance(data, str):
            return data.encode(self.charset or "utf-8")
        return data

class MarkdownImageReplacer(MarkdownRenderer):
def __init__(self, request, queryset):
self.request = request
Expand All @@ -55,7 +61,10 @@ def codespan(self, token: dict[str, dict], state: mistune.BlockState) -> str:
token['raw'] = unescape(token['raw'])
return super().codespan(token, state)


@classmethod
def get_markdown(cls, request, md_text, images_qs: 'models.models.BaseManager[models.FileImage]'):
    # Re-render `md_text` through this renderer — presumably to rewrite image
    # links against `images_qs` (class name suggests so; the relevant override
    # is outside this view — confirm). escape=False keeps raw HTML intact.
    modify_links = mistune.create_markdown(escape=False, renderer=cls(request, images_qs))
    return modify_links(md_text)

@extend_schema_view(
list=extend_schema(
Expand Down Expand Up @@ -201,12 +210,14 @@ def make_request(cls, request, path, request_body=None):
)

def create(self, request, *args, **kwargs):
    """Validate the payload, proxy feed creation upstream, then start a local job.

    Validation runs before the proxy call so a bad profile_id or
    ai_summary_provider never creates an upstream feed.
    """
    s = serializers.FeedSerializer(data=request.data)
    s.is_valid(raise_exception=True)
    resp = self.make_request(request, "/api/v1/feeds/")
    if resp.status_code == 201:
        out = json.loads(resp.content)
        # Upstream returns its identifier as `id`; the task layer expects `feed_id`.
        out['feed_id'] = out['id']
        job = tasks.new_task(out, s.data['profile_id'], s.data['ai_summary_provider'])
        return Response(JobSerializer(job).data, status=status.HTTP_201_CREATED)
    # Non-201: surface the upstream response unchanged.
    return resp

Expand Down Expand Up @@ -260,7 +271,7 @@ def partial_update(self, request, *args, **kwargs):
if resp.status_code == 201:
out = json.loads(resp.content)
out['feed_id'] = out['id']
job = tasks.new_task(out, s.data.get("profile_id", feed.profile.id))
job = tasks.new_task(out, s.data.get("profile_id", feed.profile.id), s.data['ai_summary_provider'])
return Response(JobSerializer(job).data, status=status.HTTP_201_CREATED)
return resp

Expand Down Expand Up @@ -362,7 +373,7 @@ def partial_update(self, request, *args, **kwargs):
self.remove_report(post_id, feed.collection_name)
out = json.loads(resp.content)
out['job_id'] = out['id']
job = tasks.new_post_patch_task(out, s.data.get("profile_id", feed.profile.id))
job = tasks.new_post_patch_task(out, s.data.get("profile_id", feed.profile.id), s.data['ai_summary_provider'])
return Response(JobSerializer(job).data, status=status.HTTP_201_CREATED)
return resp

Expand All @@ -379,7 +390,7 @@ def create(self, request, *args, **kwargs):
if resp.status_code == 201:
out = json.loads(resp.content)
out['job_id'] = out['id']
job = tasks.new_post_patch_task(out, s.data.get("profile_id", feed.profile.id))
job = tasks.new_post_patch_task(out, s.data.get("profile_id", feed.profile.id), s.data['ai_summary_provider'])
return Response(JobSerializer(job).data, status=status.HTTP_201_CREATED)
return resp

Expand Down Expand Up @@ -446,8 +457,8 @@ def get_post_objects(self, post_id, feed_id):
@decorators.action(detail=True, methods=["GET"])
def markdown(self, request, feed_id=None, post_id=None):
    """Download the post's markdown, with image links rewritten for this request."""
    obj = get_object_or_404(models.File, post_id=post_id)
    # Rewriting is delegated to MarkdownImageReplacer.get_markdown so the same
    # logic is reusable outside this view.
    images = models.FileImage.objects.filter(report__post_id=post_id)
    resp_text = MarkdownImageReplacer.get_markdown(request, obj.markdown_file.read().decode(), images)
    return FileResponse(streaming_content=resp_text, content_type='text/markdown', filename='markdown.md')

@extend_schema(
responses={200: serializers.ImageSerializer(many=True), 404: api_schema.DEFAULT_404_ERROR, 400: api_schema.DEFAULT_400_ERROR},
Expand Down Expand Up @@ -485,6 +496,19 @@ def remove_report(self, post_id, collection):
helper.execute_query(query, bind_vars={"@collection": f"{collection}_{c}", 'post_id': post_id}, paginate=False)


@extend_schema(
    responses=None,
    description="Get the summary of the Post",
    summary="Get the summary of the post if `ai_summary_provider` was enabled.",
)
@decorators.action(methods=["GET"], detail=True)
def summary(self, request, feed_id=None, post_id=None):
    """Return the stored AI summary for a post as a markdown download."""
    obj = get_object_or_404(models.File, post_id=post_id)
    if not obj.summary:
        # 404 rather than an empty body: a summary only exists when the job ran
        # with a summary provider. Fixed F541 (f-string without placeholder) by
        # including the post id in the message.
        raise exceptions.NotFound(f"No summary for post {post_id}")
    return FileResponse(streaming_content=io.BytesIO(obj.summary.encode()), content_type='text/markdown', filename='summary.md')


@extend_schema_view(
list=extend_schema(
summary="Search Jobs",
Expand Down
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -109,4 +109,5 @@ django-storages[s3]==1.14.4
stix2arango @ https://github.com/muchdogesec/stix2arango/archive/main.zip
file2txt @ https://github.com/muchdogesec/file2txt/archive/main.zip
txt2stix @ https://github.com/muchdogesec/txt2stix/releases/download/main-2024-11-13/txt2stix-0.0.1b5-py3-none-any.whl
dogesec_commons[stixifier] @ https://github.com/muchdogesec/dogesec_commons/releases/download/main-2024-11-13/dogesec_commons-0.0.1b2-py3-none-any.whl
dogesec_commons[stixifier] @ https://github.com/muchdogesec/dogesec_commons/releases/download/summarizer-2024-11-13/dogesec_commons-0.0.1b2-py3-none-any.whl