Skip to content

Commit

Permalink
dev(narugo): add state recover
Browse files Browse the repository at this point in the history
  • Loading branch information
narugo1992 committed Aug 5, 2024
1 parent 3c0b943 commit 360c16a
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 32 deletions.
49 changes: 43 additions & 6 deletions felinewhisker/ui/annotate.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import json
import os.path
import pathlib
from typing import Iterable, Callable, Optional, Any

Expand All @@ -13,15 +15,50 @@


def create_annotation_tab(repo: DatasetRepository, demo: gr.Blocks,
datasource: Iterable[ImageItem], write_session: WriterSession,
fn_annotate_assist: Optional[Callable[[str], Any]] = None, **kwargs):
datasource: Iterable[ImageItem], write_session: WriterSession, state_file: str,
fn_annotate_assist: Optional[Callable[[str], Any]] = None,
**kwargs):
data_iterator = iter(datasource)

gr_state_output = gr.State(value=None)
gr_position_id = gr.State(value=-1)
gr_datasource_length = gr.State(value=None)
gr_max_length = gr.State(value=None)
gr_id_list = gr.State(value=[])

def _fn_state_save(position_id, max_length, id_list):
with open(state_file, 'w') as f:
json.dump({
'position_id': position_id,
'max_length': max_length,
'id_list': id_list,
}, f, indent=4, sort_keys=True)

gr_position_id.change(
fn=_fn_state_save,
inputs=[gr_position_id, gr_max_length, gr_id_list],
)
gr_max_length.change(
fn=_fn_state_save,
inputs=[gr_position_id, gr_max_length, gr_id_list],
)
gr_id_list.change(
fn=_fn_state_save,
inputs=[gr_position_id, gr_max_length, gr_id_list],
)

def _fn_load_state():
if os.path.exists(state_file):
with open(state_file, 'r') as f:
state = json.load(f)
return state['position_id'], state['max_length'], state['id_list']
else:
return -1, None, []

demo.load(
fn=_fn_load_state,
outputs=[gr_position_id, gr_max_length, gr_id_list],
)

with gr.Row():
gr_state_input = create_ui_for_annotator(
repo=repo,
Expand Down Expand Up @@ -87,8 +124,8 @@ def _fn_next(idx, ids, max_length):

gr_next.click(
fn=_fn_next,
inputs=[gr_position_id, gr_id_list, gr_datasource_length],
outputs=[gr_position_id, gr_id_list, gr_datasource_length],
inputs=[gr_position_id, gr_id_list, gr_max_length],
outputs=[gr_position_id, gr_id_list, gr_max_length],
)

def _ch_change(state):
Expand Down Expand Up @@ -121,7 +158,7 @@ def _fn_index_change(idx, ids, max_length):

gr_position_id.change(
fn=_fn_index_change,
inputs=[gr_position_id, gr_id_list, gr_datasource_length],
inputs=[gr_position_id, gr_id_list, gr_max_length],
outputs=[gr_state_input, gr_prev, gr_next, gr_save],
)

Expand Down
56 changes: 30 additions & 26 deletions felinewhisker/ui/dispatch.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import os
import pathlib
from contextlib import contextmanager
from typing import Optional, ContextManager, Callable, Any

import gradio as gr
from hbutils.string import titleize
from hbutils.system import TemporaryDirectory

from .annotate import create_annotation_tab
from ..datasource import BaseDataSource
Expand All @@ -18,29 +20,31 @@ def create_annotator_app(
fn_annotate_assist: Optional[Callable[[str], Any]] = None,
annotation_options: Optional[dict] = None
) -> ContextManager[gr.Blocks]:
with repo.write(author=author) as write_session:
with datasource as source:
source.set_fn_contains_id(write_session.is_id_duplicated)

with gr.Blocks(css=_GLOBAL_CSS_CODE) as demo:
with gr.Row(elem_id='annotation_title'):
gr_title = gr.HTML(
f'<p class="title">'
f'<u>{titleize(repo.meta_info["task"])}</u> - {repo.meta_info["name"]}'
f'</p>'
)
_ = gr_title

with gr.Row():
with gr.Tabs():
with gr.Tab('Annotation'):
create_annotation_tab(
repo=repo,
demo=demo,
datasource=source,
write_session=write_session,
fn_annotate_assist=fn_annotate_assist,
**(annotation_options or {}),
)

yield demo
with TemporaryDirectory(prefix='felinewhisker_') as td_state, \
repo.write(author=author) as write_session, datasource as source:
state_file = os.path.join(td_state, 'state.json')
source.set_fn_contains_id(write_session.is_id_duplicated)

with gr.Blocks(css=_GLOBAL_CSS_CODE) as demo:
with gr.Row(elem_id='annotation_title'):
gr_title = gr.HTML(
f'<p class="title">'
f'<u>{titleize(repo.meta_info["task"])}</u> - {repo.meta_info["name"]}'
f'</p>'
)
_ = gr_title

with gr.Row():
with gr.Tabs():
with gr.Tab('Annotation'):
create_annotation_tab(
repo=repo,
demo=demo,
datasource=source,
write_session=write_session,
state_file=state_file,
fn_annotate_assist=fn_annotate_assist,
**(annotation_options or {}),
)

yield demo

0 comments on commit 360c16a

Please sign in to comment.