Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/encoder decoder dq restructure #766

Closed
wants to merge 15 commits into from

Conversation

elboy3
Copy link
Contributor

@elboy3 elboy3 commented Oct 4, 2023

Please do not create a pull request without creating an issue first.

Changes need to be discussed before proceeding, pull requests submitted without linked issues may be rejected.

Please provide enough information so that others can review your pull request. You can skip this if you're fixing a typo – it happens.

  • I have added tests to tests to cover my changes.
  • I have updated docs/, if necessary.
  • I have updated the README.md, if necessary.

What existing issue does this pull request close?

Put closes #issue-number in this pull request's description to auto-close the issue that this fixes.

How are these changes tested?

This pull request includes automated tests for the code it touches and those tests are described below. If no tests are included, reasons why must be provided below.

These changes are tested with [...]

Demonstration

Demonstrate your contribution.

For example, what are the exact commands you ran and their output, related screenshots, screen-recordings, test runs, anything that can showcase.

Provide additional context.

Provide as much relevant context as you like.

@@ -14,7 +14,8 @@ class TaskType(str, Enum):
object_detection = "object_detection"
semantic_segmentation = "semantic_segmentation"
prompt_evaluation = "prompt_evaluation"
seq2seq = "seq2seq"
seq2seq = "seq2seq" # TODO Remove
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we'll need to check with rodrigo if he uses "seq2seq" at all on the UI as well. and we'll have to make some api / runners / rungalileo changes too.

depending on how much seq2seq is hard coded elsewhere, we should either just rename it to encoder decoder OR use a new task type 10 as encoder decoder and just kinda deprecate 8 seq2seq

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay! Yeah would love some guidance here.

@@ -58,8 +58,12 @@ def token_map_key(self) -> str:
return self.inference_name
return str(self.split)

@abstractmethod
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i don't think you need this decorator since it's mainly for the base class (definition in BaseGalileoModelLogger)

class EncoderDecoderModelLogger(Seq2SeqModelLogger):
# TODO Add in API so we can use encoder_decoder
# __logger_name__ = "encoder_decoder"
__logger_name__ = "seq2seq"
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yup, we'll have to change this to encoder_decoder

tabular_classification,
text_classification,
text_multi_label,
text_ner,
)
from dataquality.loggers.model_logger.base_model_logger import BaseGalileoModelLogger
from dataquality.loggers.model_logger.seq2seq import encoder_decoder, seq2seq
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we should rename the seq2seq file (not folder) to seq2seq_base

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay love that idea!

@@ -295,32 +284,31 @@ def separate_dataframe(
return BaseLoggerDataFrames(prob=prob, emb=emb, data=data_df)

@classmethod
@abstractmethod
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i'm pretty sure, the abstract method should only go in the base and should not be implemented, that's how i've always seen it used. lmk if you think otherwise we can search a bit deeper into recommended usage!

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pasting in some thoughts from GPT-4

Copy link
Contributor

@jonathangomesselman jonathangomesselman Oct 13, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

On the topic of can you have partial implementations within an abstractmethod here are its thoughts:

GPT-4 talking here

Yes, you can provide a partial implementation in an abstract method and still require subclasses to override it. This can be useful when there's some shared logic that you want to include in the base class, but you still want to ensure that subclasses provide their own specific implementations.

Here's how you can do this:

  1. Define the abstract method in the base class with the @AbstractMethod decorator.
  2. Within the abstract method, provide the shared logic.
  3. In the subclasses, you can use the super() function to call the base class's implementation before or after adding the subclass-specific logic.
    Here's an example to illustrate:
from abc import ABC, abstractmethod

class Animal(ABC):

    @abstractmethod
    def speak(self):
        print("The animal makes a sound...")

class Dog(Animal):

    def speak(self):
        super().speak()  # Call the base class's implementation
        print("Woof!")

class Cat(Animal):

    def speak(self):
        super().speak()  # Call the base class's implementation
        print("Meow!")

In this example:

  • Animal is an abstract base class with an abstract method speak that has a partial implementation.
  • Both Dog and Cat subclasses override the speak method. They use super().speak() to call the base class's implementation and then add their own specific logic.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

On the topic of should abstractmethod just be in the base-class here were its thoughts:

GPT-4 talking here

It's perfectly acceptable and sometimes even desirable to have abstract methods in subclasses (i.e., in intermediate layers of a class hierarchy). This practice can be useful in various scenarios:

  1. Refined Abstraction: Sometimes, the base class is a very general abstraction, and intermediate subclasses refine this abstraction but are still too general to provide a concrete implementation for certain methods. These intermediate classes can introduce new abstract methods that are relevant to their refined abstraction.

  2. Hierarchical Design: As you develop a class hierarchy, you may realize that certain behaviors are shared among a subset of subclasses, but not all. In such cases, it's beneficial to introduce an intermediate abstract class that encapsulates these shared behaviors as abstract methods.

  3. Flexibility: Introducing abstract methods in intermediate classes provides flexibility. As the design evolves, you can ensure that specific subclasses adhere to certain interfaces without affecting unrelated subclasses.

Here's an example to illustrate:

from abc import ABC, abstractmethod

class Vehicle(ABC):
    @abstractmethod
    def move(self):
        pass

class LandVehicle(Vehicle):
    @abstractmethod
    def start_engine(self):
        pass

class Car(LandVehicle):
    def move(self):
        print("The car drives.")

    def start_engine(self):
        print("Car engine started.")

class Boat(Vehicle):
    def move(self):
        print("The boat sails.")

In this design, the intermediate abstract class LandVehicle provides a refined abstraction for vehicles that move on land. By introducing the start_engine abstract method, we ensure that any concrete subclass of LandVehicle implements this behavior.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that what GPT-4 says makes sense to me. I particularly like the points about hierarchical design and Refined Abstraction. I think here for example, calculate_cutoffs is very specific to just Seq2Seq classes!

Comment on lines 30 to 36
if task_type == task_type.seq2seq: # TODO Change to encoder_decoder
return encoder_decoder_logger_config

# TODO Change to encoder_decoder
raise GalileoException(
"Galileo's seq2seq watch method is only supported for seq2seq"
)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i think we can just use the get current task type helpers, since they will have already initialized the project with dq.init and we will have the task type stored in the config file

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Look for other instances of where we do get_data_logger().logger_config

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay 👌 yes this seems helpful!

Copy link
Contributor Author

@elboy3 elboy3 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Generally looks great! We'll just need to discuss task type switch and how to best handle that without breaking things, then we're good to go

@codecov-commenter
Copy link

codecov-commenter commented Oct 13, 2023

Codecov Report

Merging #766 (43e4cab) into main (8f0f7c3) will decrease coverage by 0.01%.
Report is 1 commits behind head on main.
The diff coverage is 99.03%.

@@            Coverage Diff             @@
##             main     #766      +/-   ##
==========================================
- Coverage   87.72%   87.72%   -0.01%     
==========================================
  Files         184      187       +3     
  Lines       15097    15139      +42     
==========================================
+ Hits        13244    13280      +36     
- Misses       1853     1859       +6     
Files Coverage Δ
...ity/loggers/data_logger/seq2seq/encoder_decoder.py 100.00% <100.00%> (ø)
...uality/loggers/data_logger/seq2seq/seq2seq_base.py 68.21% <100.00%> (ø)
...y/loggers/logger_config/seq2seq/encoder_decoder.py 100.00% <100.00%> (ø)
...lity/loggers/logger_config/seq2seq/seq2seq_base.py 100.00% <ø> (ø)
...ty/loggers/model_logger/seq2seq/encoder_decoder.py 100.00% <100.00%> (ø)
...ality/loggers/model_logger/seq2seq/seq2seq_base.py 92.64% <100.00%> (ø)
dataquality/schemas/task_type.py 100.00% <100.00%> (ø)
tests/loggers/test_seq2seq.py 100.00% <100.00%> (ø)
tests/utils/test_seq2seq_offset.py 100.00% <100.00%> (ø)
tests/utils/test_seq2seq_utils.py 100.00% <100.00%> (ø)
... and 1 more

... and 2 files with indirect coverage changes

📣 We’re building smart automated test selection to slash your CI/CD build times. Learn more



@check_noop
def set_tokenizer(
tokenizer: PreTrainedTokenizerFast,
logger_config: Union[EncoderDecoderLoggerConfig],
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i wouldn't have logger config as a param for this cause this is technically a user facing fn and we wouldn't expect them to pass in a logger config

assert isinstance(
tokenizer, PreTrainedTokenizerFast
), "Tokenizer must be an instance of PreTrainedTokenizerFast"
assert getattr(tokenizer, "is_fast", False), "Tokenizer must be a fast tokenizer"
for attr in ["encode", "decode", "encode_plus", "padding_side"]:
assert hasattr(tokenizer, attr), f"Tokenizer must support `{attr}`"
seq2seq_logger_config.tokenizer = tokenizer
logger_config.tokenizer = tokenizer
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

to get config we could call the get data logger config helper in this fn

assert isinstance(
model, PreTrainedModel
), "model must be an instance of transformers PreTrainedModel"
assert model.can_generate(), "model must contain a `generate` method for seq2seq"

set_tokenizer(tokenizer, max_input_tokens, max_target_tokens)
set_tokenizer(tokenizer, logger_config, max_input_tokens, max_target_tokens)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

see above!

"""

# TODO Change to encoder_decoder after updating API
__logger_name__ = "seq2seq" # encoder_decoder
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

agreed it should be encoder_decoder

common data type validation.
"""
super().validate_and_format()
# TODO: question type checking does not work in super()
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what do you mean? we can look into this together

@@ -96,7 +83,13 @@ def token_map_key(self) -> str:
return self.inference_name
return str(self.split)

@abstractmethod
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i think we keep the fn here but just remove the abstractmethod decorator

@@ -295,32 +277,31 @@ def separate_dataframe(
return BaseLoggerDataFrames(prob=prob, emb=emb, data=data_df)

@classmethod
@abstractmethod
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same, i think let's just have a base classmethod that does some logic, and if the parent's want to call super() and do extra or want to override they can, but let's not mandate that the parent's have to override this fn

Comment on lines 5 to 7
# TODO Add comment
# This currently is purely a wrapper!
pass
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
# TODO Add comment
# This currently is purely a wrapper!
pass
"""Logger config for Encoder Decoder
This logger currently has same fields as the base class
"""

something like this is fine ^ also you don't need the "pass"

Comment on lines +53 to +58
logprobs = self.convert_logits_to_logprobs(self.logits)
(
self.token_logprobs,
self.top_logprobs,
) = self.process_logprobs(
self.ids, logprobs # type: ignore
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is stuff that won't happen in docoder only?

@@ -58,8 +58,12 @@ def token_map_key(self) -> str:
return self.inference_name
return str(self.split)

@abstractmethod
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👋 remove!

Comment on lines 42 to 43
8: TaskType.seq2seq, # TODO Remove
# 8: TaskType.encoder_decoder, # TODO add on API side
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i'm personally now thinking we should actually keep seq2seq and just call it deprecated, and then add 10 as encoder_decoder and 11 as decoder_only ..... it seems too hard to move everything from seq2seq to encoder_decoder across all repos

Copy link
Contributor Author

@elboy3 elboy3 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

generally looking great! let's continue to pair

@elboy3
Copy link
Contributor Author

elboy3 commented Oct 17, 2023

Closing so @jonathangomesselman can finish cleaning comments and then create his own PR!

@elboy3 elboy3 closed this Oct 17, 2023
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants