Commit bbf232e
Add Span.kb_id/Span.id strings to Doc/DocBin serialization if set (#12493)

* Add Span.kb_id/Span.id strings to Doc/DocBin serialization if set

* Format
adrianeboyd committed Apr 3, 2023
1 parent 0ec4dc5 commit bbf232e
Showing 4 changed files with 24 additions and 2 deletions.
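The user-facing effect, as a minimal sketch (variable names and ID values are illustrative, and this assumes a spaCy build that includes this patch):

    from spacy.tokens import Doc
    from spacy.vocab import Vocab

    doc = Doc(Vocab(), words=["Douglas", "Adams", "wrote", "books"])
    span = doc[0:2]
    span.label_ = "PERSON"
    span.kb_id_ = "Q42"    # knowledge-base ID as a string
    span.id_ = "span-0"    # span ID as a string
    doc.spans["entities"] = [span]

    # Deserialize into a different Vocab: the label string already
    # round-tripped before this change, but the kb_id_/id_ strings could be
    # missing from the target string store; they are now serialized as well.
    new_doc = Doc(Vocab()).from_bytes(doc.to_bytes())
    restored = new_doc.spans["entities"][0]
    assert restored.kb_id_ == "Q42"
    assert restored.id_ == "span-0"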
9 changes: 8 additions & 1 deletion spacy/tests/serialize/test_serialize_doc.py
@@ -213,6 +213,13 @@ def test_serialize_doc_exclude(en_vocab):

 def test_serialize_doc_span_groups(en_vocab):
     doc = Doc(en_vocab, words=["hello", "world", "!"])
-    doc.spans["content"] = [doc[0:2]]
+    span = doc[0:2]
+    span.label_ = "test_serialize_doc_span_groups_label"
+    span.id_ = "test_serialize_doc_span_groups_id"
+    span.kb_id_ = "test_serialize_doc_span_groups_kb_id"
+    doc.spans["content"] = [span]
     new_doc = Doc(en_vocab).from_bytes(doc.to_bytes())
     assert len(new_doc.spans["content"]) == 1
+    assert new_doc.spans["content"][0].label_ == "test_serialize_doc_span_groups_label"
+    assert new_doc.spans["content"][0].id_ == "test_serialize_doc_span_groups_id"
+    assert new_doc.spans["content"][0].kb_id_ == "test_serialize_doc_span_groups_kb_id"
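In the hunk above, the underscore attributes are the string views of the hash-valued `label`, `kb_id` and `id` attributes; that pairing is why the serializer has to record the strings separately. A standalone sketch of the pairing (not part of the diff):

    from spacy.tokens import Doc
    from spacy.vocab import Vocab

    vocab = Vocab()
    doc = Doc(vocab, words=["hello", "world", "!"])
    span = doc[0:2]
    span.kb_id_ = "my_kb_id"                          # setting the string view...
    assert span.kb_id == vocab.strings["my_kb_id"]    # ...stores the 64-bit hash
    assert vocab.strings[span.kb_id] == "my_kb_id"    # the hash resolves back via the StringStore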
9 changes: 8 additions & 1 deletion spacy/tests/serialize/test_serialize_docbin.py
@@ -49,7 +49,11 @@ def test_serialize_doc_bin():
     nlp = English()
     for doc in nlp.pipe(texts):
         doc.cats = cats
-        doc.spans["start"] = [doc[0:2]]
+        span = doc[0:2]
+        span.label_ = "UNUSUAL_SPAN_LABEL"
+        span.id_ = "UNUSUAL_SPAN_ID"
+        span.kb_id_ = "UNUSUAL_SPAN_KB_ID"
+        doc.spans["start"] = [span]
         doc[0].norm_ = "UNUSUAL_TOKEN_NORM"
         doc[0].ent_id_ = "UNUSUAL_TOKEN_ENT_ID"
         doc_bin.add(doc)
@@ -63,6 +67,9 @@ def test_serialize_doc_bin():
         assert doc.text == texts[i]
         assert doc.cats == cats
         assert len(doc.spans) == 1
+        assert doc.spans["start"][0].label_ == "UNUSUAL_SPAN_LABEL"
+        assert doc.spans["start"][0].id_ == "UNUSUAL_SPAN_ID"
+        assert doc.spans["start"][0].kb_id_ == "UNUSUAL_SPAN_KB_ID"
         assert doc[0].norm_ == "UNUSUAL_TOKEN_NORM"
         assert doc[0].ent_id_ == "UNUSUAL_TOKEN_ENT_ID"

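The test above round-trips the DocBin within the same pipeline; the change also matters when DocBin output is loaded into a different pipeline's vocab, as sketched below (pipeline and ID names are illustrative):

    from spacy.lang.en import English
    from spacy.tokens import DocBin

    nlp_a = English()
    doc = nlp_a("hello world")
    span = doc[0:2]
    span.label_ = "GREETING"
    span.kb_id_ = "KB1"
    span.id_ = "ID1"
    doc.spans["start"] = [span]

    doc_bin = DocBin()
    doc_bin.add(doc)
    data = doc_bin.to_bytes()

    # A separate pipeline with its own string store: the kb_id_/id_ strings
    # now travel in the DocBin's string list, so they resolve here too.
    nlp_b = English()
    restored = list(DocBin().from_bytes(data).get_docs(nlp_b.vocab))[0]
    assert restored.spans["start"][0].kb_id_ == "KB1"
    assert restored.spans["start"][0].id_ == "ID1"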
4 changes: 4 additions & 0 deletions spacy/tokens/_serialize.py
@@ -124,6 +124,10 @@ def add(self, doc: Doc) -> None:
         for key, group in doc.spans.items():
             for span in group:
                 self.strings.add(span.label_)
+                if span.kb_id in span.doc.vocab.strings:
+                    self.strings.add(span.kb_id_)
+                if span.id in span.doc.vocab.strings:
+                    self.strings.add(span.id_)

     def get_docs(self, vocab: Vocab) -> Iterator[Doc]:
         """Recover Doc objects from the annotations, using the given vocab.
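The `span.kb_id in span.doc.vocab.strings` guard relies on StringStore accepting either a string or its hash for membership checks, so only IDs that actually resolve to a string get recorded; the same guard appears in the Doc.to_bytes hunk below. A standalone sketch of that StringStore behavior (not part of the diff):

    from spacy.strings import StringStore

    strings = StringStore()
    h = strings.add("Q42")        # add() returns the 64-bit hash
    assert "Q42" in strings       # membership by string
    assert h in strings           # membership by hash, as used in the guard
    assert strings[h] == "Q42"    # the hash resolves back to the string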
4 changes: 4 additions & 0 deletions spacy/tokens/doc.pyx
@@ -1346,6 +1346,10 @@ cdef class Doc:
         for group in self.spans.values():
             for span in group:
                 strings.add(span.label_)
+                if span.kb_id in span.doc.vocab.strings:
+                    strings.add(span.kb_id_)
+                if span.id in span.doc.vocab.strings:
+                    strings.add(span.id_)
         # Msgpack doesn't distinguish between lists and tuples, which is
         # vexing for user data. As a best guess, we *know* that within
         # keys, we must have tuples. In values we just have to hope
