Good strategies for hierarchical classification with many classes #552

Open
miguelwon opened this issue Sep 15, 2024 · 5 comments

Comments

@miguelwon

miguelwon commented Sep 15, 2024

I'm working on a hierarchical multi-class problem. If I flatten the labels (flat approach), I have about 1193 classes, which can perhaps already be considered an extreme multi-class classification problem. Furthermore, I have fewer than 10 examples per class.

With so many classes, I can't build pairs for all combinations, because that would result in a huge number of pairs, and I'm somewhat limited in hardware and time.
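For scale: pairing every text with every other text gives N(N-1)/2 pairs for N texts, while pairing only within each parent group gives the sum of n(n-1)/2 over the group sizes n. A minimal sketch; the upper bound of 9 examples for each of the 1193 classes follows from the numbers above, but the split into 40 equal parent groups is a made-up assumption for illustration only:

```python
def count_pairs(group_sizes):
    """Number of unordered pairs when pairing happens only within each group."""
    return sum(n * (n - 1) // 2 for n in group_sizes)

n_texts = 1193 * 9  # upper bound: fewer than 10 examples per class

# Flat approach: one big group, every text paired with every other text
print(count_pairs([n_texts]))  # 57636216

# Hypothetical: (roughly) the same texts split into 40 equal parent groups
print(count_pairs([n_texts // 40] * 40))  # 1431120
```

Restricting pairs to a shared parent cuts the count by roughly the number of parent groups, which is why the within-"father" strategy discussed below is so much cheaper.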

Also, since the problem is hierarchical, I think it would work better if I favored pairs of examples with the same "father", because I want good discrimination even between examples within the same "father" category.

Do you know any good strategy for this kind of problem? Perhaps first train on pairs drawn from randomly picked high-level categories, and then continue training with pairs that share the same root?

@haukelicht

I have a similar use case and was thinking about implementing the method proposed in "A Multi-task Approach to Neural Multi-label Hierarchical Patent Classification Using Transformers" (doi).

The paper's authors provide an implementation using Keras: https://github.com/boschresearch/hierarchical_patent_classification_ecir2021/blob/main/text_classification/model/THMM.py

You could port their code to PyTorch and subclass the classification head of `SetFitModel` as described here: https://huggingface.co/docs/setfit/en/how_to/classification_heads#custom-differentiable-head
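A multi-task head of this kind needs one target per hierarchy level instead of a single flat label. A minimal sketch of that label preparation, assuming titles formatted as "chapter|section|form_title" (the format used later in this thread); `encode_levels` is a hypothetical helper, not part of setfit:

```python
def encode_levels(titles, sep="|"):
    """Map each hierarchical title to one integer label per level.

    Returns (labels, vocabs): labels[i] is a tuple with one id per level, and
    vocabs[level] maps a path prefix (tuple of strings) to its id at that level.
    """
    vocabs = []
    labels = []
    for title in titles:
        parts = tuple(p.strip() for p in title.split(sep))
        ids = []
        for level in range(len(parts)):
            if level == len(vocabs):
                vocabs.append({})
            # Key on the full path prefix so that identically named sections
            # under different chapters remain distinct classes.
            key = parts[: level + 1]
            ids.append(vocabs[level].setdefault(key, len(vocabs[level])))
        labels.append(tuple(ids))
    return labels, vocabs
```

For example, `["A | A1 | x", "A | A1 | y", "B | B1 | x"]` yields `[(0, 0, 0), (0, 0, 1), (1, 1, 2)]`. Each level's ids could then feed its own classification layer, with the per-level losses summed; this shows the general multi-task setup, not a port of the THMM code.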

@miguelwon
Author

Thanks @haukelicht for the reference. I'll have a look.

But this is what I have done: I built pairs from data points sharing the same hierarchical "father". I did this to generate pairs that are somewhat related (they share the same high-level class) but that I know should be classified differently. These pairs act like hard negatives and make the task of distinguishing them harder. Since I built the pairs only from combinations of examples with the same high-level class, the total number of pairs is significantly reduced.

Then I fine-tuned a retrieval model (I worked with gte-multilingual-base), followed by training a head with a simple NN.

With this approach I was able to achieve good evaluation results.

@haukelicht

Sounds great, @miguelwon! Can you maybe point me to the class or method you changed or subclassed to modify how setfit constructs the pairwise data?

@miguelwon
Author

miguelwon commented Oct 11, 2024

I didn't use setfit. Since I wanted such a custom setup, I coded it myself. It's a bit of a mess, but I'll copy it here so you can get an idea.

Suppose you have a list of dicts in `main_train`, where the value of `"title"` contains the full hierarchy:

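```python
from collections import defaultdict

# chapter -> section -> form_title -> list of texts
main_train_dict = defaultdict(lambda: defaultdict(lambda: defaultdict(list)))
for row in main_train:
    # "title" holds the full hierarchy as "chapter|section|form_title"
    full_title = row['title']

    chapter, section, form_title = full_title.split("|")
    chapter = chapter.strip()
    section = section.strip()
    form_title = form_title.strip()
    text = row['text']

    main_train_dict[chapter][section][form_title].append(text)
```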

Then, to build the pairs, I use the following code:

```python
import random
from itertools import combinations

def create_pairs(main_train_dict, same_section_ratio=1.0, other_chapter_ratio=0.3):
    positive_pairs = []
    negative_pairs = []

    for chapter in main_train_dict:
        for section in main_train_dict[chapter]:
            section_texts = []
            for form_title in main_train_dict[chapter][section]:
                texts = main_train_dict[chapter][section][form_title]

                # Create positive pairs within the same form_title
                for text1, text2 in combinations(texts, 2):
                    positive_pairs.append((text1, text2, 1))

                section_texts.extend((text, form_title) for text in texts)

            # Create hard negative pairs within the same section (different form_title)
            for (text1, title1), (text2, title2) in combinations(section_texts, 2):
                if title1 != title2:
                    negative_pairs.append((text1, text2, 0))

    # Create negative pairs from other chapters
    all_texts = [(text, chapter, section, form_title)
                 for chapter in main_train_dict
                 for section in main_train_dict[chapter]
                 for form_title in main_train_dict[chapter][section]
                 for text in main_train_dict[chapter][section][form_title]]

    # Note: this enumerates every pair of texts (quadratic in their number)
    other_chapter_negatives = []
    for (text1, ch1, _, _), (text2, ch2, _, _) in combinations(all_texts, 2):
        if ch1 != ch2:
            other_chapter_negatives.append((text1, text2, 0))

    # Balance the dataset relative to the number of positive pairs;
    # cap the sample sizes, since random.sample raises ValueError otherwise
    total_pairs = len(positive_pairs)
    num_same_section = min(int(total_pairs * same_section_ratio), len(negative_pairs))
    num_other_chapter = min(int(total_pairs * other_chapter_ratio), len(other_chapter_negatives))

    negative_pairs = random.sample(negative_pairs, num_same_section)
    other_chapter_negatives = random.sample(other_chapter_negatives, num_other_chapter)

    all_pairs = positive_pairs + negative_pairs + other_chapter_negatives
    random.shuffle(all_pairs)

    return all_pairs

# Create the pairs
pairs = create_pairs(main_train_dict)

# Print some statistics
positive_count = sum(1 for _, _, label in pairs if label == 1)
negative_count = sum(1 for _, _, label in pairs if label == 0)

print(f"Total pairs: {len(pairs)}")
print(f"Positive pairs: {positive_count}")
print(f"Negative pairs: {negative_count}")
```
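Building `other_chapter_negatives` with `combinations(all_texts, 2)` materializes every cross-chapter pair before keeping only a small sample, which is quadratic in the number of texts and can exhaust memory at this scale. A minimal alternative sketch that draws random cross-chapter pairs directly; `sample_cross_chapter_negatives` and its `seed` parameter are hypothetical names, not part of the code above:

```python
import random

def sample_cross_chapter_negatives(all_texts, k, seed=None):
    """Draw up to k distinct (text1, text2, 0) pairs whose chapters differ.

    all_texts is a list of (text, chapter, section, form_title) tuples, as
    built inside create_pairs. Pairs are drawn at random instead of being
    enumerated, so memory grows with k rather than with len(all_texts) ** 2.
    """
    n = len(all_texts)
    if n < 2:
        return []
    rng = random.Random(seed)
    seen = set()
    negatives = []
    max_attempts = 20 * k + 100  # stop eventually if few cross-chapter pairs exist
    for _ in range(max_attempts):
        if len(negatives) >= k:
            break
        i, j = sorted(rng.sample(range(n), 2))
        # Skip duplicates and pairs from the same chapter
        if (i, j) in seen or all_texts[i][1] == all_texts[j][1]:
            continue
        seen.add((i, j))
        negatives.append((all_texts[i][0], all_texts[j][0], 0))
    return negatives
```

Inside `create_pairs`, a call such as `sample_cross_chapter_negatives(all_texts, num_other_chapter)` could stand in for both the enumeration loop and the later `random.sample` on `other_chapter_negatives`, with `num_other_chapter` computed from the positive count alone.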


Do the same for the test set (producing `test_pairs`), and then:

```python
from sentence_transformers import InputExample, evaluation
from torch.utils.data import DataLoader

# Prepare train data
train_examples = [InputExample(texts=[text1, text2], label=float(label)) for text1, text2, label in pairs]
train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=16)

# Prepare test data (test_pairs built from the test split with create_pairs)
test_examples = [InputExample(texts=[text1, text2], label=float(label)) for text1, text2, label in test_pairs]
test_evaluator = evaluation.EmbeddingSimilarityEvaluator.from_input_examples(test_examples, name='test_evaluation')
```

And train with:

```python
from sentence_transformers import SentenceTransformer, losses

# Initialize the model
model = SentenceTransformer('Alibaba-NLP/gte-multilingual-base', trust_remote_code=True)

# Define the loss
train_loss = losses.CosineSimilarityLoss(model)

# Train the model; warm up the learning rate over the first 10% of steps
num_epochs = 3
warmup_steps = int(len(train_dataloader) * num_epochs * 0.1)

model.fit(train_objectives=[(train_dataloader, train_loss)],
          evaluator=test_evaluator,
          epochs=num_epochs,
          evaluation_steps=1000,
          warmup_steps=warmup_steps,
          output_path='./results')
```

So, after this you have a gte model fine-tuned for your classes. Then you can easily use sklearn, for example, to train a NN or a logistic regression on the gte embeddings.
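Before training a sklearn head, a quick dependency-free check of whether the fine-tuned embeddings separate the classes is a nearest-centroid classifier. This is a swapped-in baseline, not the author's head, and it assumes the embeddings were already computed (for example with `model.encode(texts)`) and converted to lists of floats; the function names are hypothetical:

```python
import math
from collections import defaultdict

def _normalize(vec):
    norm = math.sqrt(sum(x * x for x in vec)) or 1.0
    return [x / norm for x in vec]

def fit_centroids(embeddings, labels):
    """Average the L2-normalized embeddings of each class, then re-normalize."""
    sums = {}
    counts = defaultdict(int)
    for emb, label in zip(embeddings, labels):
        emb = _normalize(emb)
        if label not in sums:
            sums[label] = [0.0] * len(emb)
        sums[label] = [s + x for s, x in zip(sums[label], emb)]
        counts[label] += 1
    return {label: _normalize([s / counts[label] for s in vec]) for label, vec in sums.items()}

def predict(centroids, embedding):
    """Return the class whose centroid has the highest cosine similarity."""
    emb = _normalize(embedding)
    return max(centroids, key=lambda label: sum(a * b for a, b in zip(emb, centroids[label])))
```

With fewer than 10 examples per class, one centroid per class is cheap to fit; sklearn's `LogisticRegression` or `MLPClassifier` on the same embeddings is the natural next step if the baseline looks reasonable.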

@philmas

philmas commented Jan 4, 2025

Hi @miguelwon,

Thanks for this post. It is super relevant.

I am trying something similar; however, I seem unable to successfully fine-tune gte on my classes (5k classes with about 32 examples each). Could you perhaps share some more details on training time, performance and such?

I have the issue that with that many classes, training becomes too resource-intensive and difficult. Also, my evaluation loss increases or my training loss doesn't decrease. Is there a reason you use a batch size of 16?

I have little experience and can't seem to find good sources on this issue, including why cosine similarity is an appropriate loss here (relative to the alternatives).
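For reference, `losses.CosineSimilarityLoss` with its default settings computes the cosine similarity of the two embeddings and regresses it onto the pair label with mean squared error, so positive pairs (label 1.0) are pushed towards similarity 1 and negative pairs (label 0.0) towards 0. A minimal pure-Python sketch of that computation on toy precomputed vectors (no model involved; the function names are illustrative):

```python
import math

def cosine(u, v):
    dot = sum(a * b for a, b in zip(u, v))
    return dot / (math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v)))

def cosine_similarity_loss(pairs):
    """Mean squared error between cos(u, v) and the label, averaged over the batch."""
    return sum((cosine(u, v) - label) ** 2 for u, v, label in pairs) / len(pairs)

batch = [
    ([1.0, 0.0], [1.0, 0.0], 1.0),  # identical vectors, positive pair: error 0
    ([1.0, 0.0], [0.0, 1.0], 0.0),  # orthogonal vectors, negative pair: error 0
    ([1.0, 0.0], [1.0, 1.0], 0.0),  # cos is about 0.707, negative pair: error about 0.5
]
print(cosine_similarity_loss(batch))  # approximately 1/6
```

Because each pair is scored independently, batch size does not change what this loss optimizes; it mainly affects gradient noise and memory. Losses that use in-batch negatives, such as `losses.MultipleNegativesRankingLoss`, behave differently: there, larger batches supply more negatives per anchor.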


3 participants