Can MatchFinder/InferenceModel return distances for all classes? #718

Open
vemchance opened this issue Oct 13, 2024 · 3 comments

Comments

@vemchance

vemchance commented Oct 13, 2024

As always, thanks again for the library and apologies for asking another question!

I was wondering whether the MatchFinder or InferenceModel can return the distances to all classes, or whether it is expected to return distances only for 'matches'. I am using a late fusion technique that takes the distance from a query image to every class and uses dynamic weighting to produce a more accurate prediction. I train two separate models, one on the text only and one on the images only, and use the distances from both to dynamically weight each modality for a given query image.
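To make the requirement concrete, here is a minimal sketch of one possible late-fusion step. The weighting rule is an assumption, not the author's method: each modality is weighted by the gap between its two smallest class distances (a rough confidence proxy), and `late_fusion` / `confidence` are hypothetical helpers. The point is that both inputs must contain a distance for every class, which is why a top-k search over samples is not enough.

```python
import numpy as np

def confidence(class_dists):
    """Assumed confidence proxy: gap between the two smallest class distances.
    Requires at least two classes."""
    smallest_two = np.sort(class_dists)[:2]
    return float(smallest_two[1] - smallest_two[0])

def late_fusion(text_dists, image_dists):
    """Fuse per-class distance vectors (one entry per class, smaller = closer).

    Weights are recomputed for every query from each modality's confidence.
    Returns the predicted class index and the fused distance vector."""
    text_dists = np.asarray(text_dists, dtype=float)
    image_dists = np.asarray(image_dists, dtype=float)
    c_text, c_image = confidence(text_dists), confidence(image_dists)
    total = c_text + c_image
    if total == 0:
        # Neither modality separates its top two classes: weight them equally
        w_text = w_image = 0.5
    else:
        w_text, w_image = c_text / total, c_image / total
    fused = w_text * text_dists + w_image * image_dists
    return int(np.argmin(fused)), fused
```

This assumes both models produce distances on comparable scales; if they do not, normalizing each vector first would be a reasonable design choice.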

Using class centers as in this discussion, the dynamic weighting works extremely well, and I can use the InferenceModel to return all the neighbours (i.e., a distance from the query image to every class in the dataset). However, the accuracy of each individual modality differs from what I got in training (on the same dataset), though I expect this because I did not use the class-center method in training. The code below returns distances for all classes:

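```python
import torch
from pytorch_metric_learning.utils.inference import FaissKNN, InferenceModel

# One center (mean embedding) per class, in sorted label order
class_centers = []
for L in torch.unique(labels, sorted=True):
    mask = (labels == L).reshape(-1)  # works whether labels has shape (N,) or (N, 1)
    class_embeddings = embeddings[mask]
    center = torch.mean(class_embeddings, dim=0)
    class_centers.append(center)

class_centers = torch.stack(class_centers, dim=0)
num_classes = class_centers.shape[0]

# Index the class centers, so each "neighbour" is a class rather than a sample
knn = FaissKNN(reset_before=False, reset_after=False)
knn.train(class_centers)
inference_model = InferenceModel(model, knn_func=knn, data_device=device)

dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=1)

for data, query_labels in dataloader:
    data, query_labels = move_data_to_device(data, query_labels, device=device)
    # k = number of classes, so every class center gets a distance
    distances, indices = inference_model.get_nearest_neighbors(data, k=num_classes)
```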
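The class-center approach can be sketched without FAISS to show why it always yields one distance per class: the index contains exactly one vector per class, so a search with k equal to the number of classes covers all of them. This NumPy version (hypothetical helpers `class_centers` and `distances_to_all_centers`) uses plain Euclidean distance; as far as I know, FaissKNN's default flat L2 index reports squared L2 distances, which rank classes identically.

```python
import numpy as np

def class_centers(embeddings, labels):
    """Mean embedding per class; rows follow the sorted unique labels."""
    embeddings = np.asarray(embeddings, dtype=float)
    labels = np.asarray(labels).reshape(-1)
    classes = np.unique(labels)  # sorted, like torch.unique(..., sorted=True)
    centers = np.stack([embeddings[labels == c].mean(axis=0) for c in classes])
    return classes, centers

def distances_to_all_centers(query, centers):
    """Euclidean distance from each query row to every class center.

    Returns shape (num_queries, num_classes): every class gets a distance."""
    query = np.atleast_2d(np.asarray(query, dtype=float))
    diff = query[:, None, :] - centers[None, :, :]
    return np.sqrt((diff ** 2).sum(axis=-1))
```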

The MatchFinder approach below gives me the same accuracy as in training. However, it only returns a certain number of classes and their distances, rather than the distances between the query and all classes, even if I set the number of neighbours to the total number of classes, change the MatchFinder threshold, or remove the MatchFinder entirely. I'm not sure whether this is expected behaviour or whether I'm misunderstanding the InferenceModel.

For example:

```python
import torch
from pytorch_metric_learning.distances import CosineSimilarity
from pytorch_metric_learning.utils.inference import FaissKNN, InferenceModel, MatchFinder

knn = FaissKNN(reset_before=False, reset_after=False)
match_finder = MatchFinder(distance=CosineSimilarity(), threshold=0.8)  # modifying the threshold doesn't return all classes
inference_model = InferenceModel(model, data_device=device, knn_func=knn, match_finder=match_finder)  # removing the MatchFinder doesn't return all classes either
inference_model.train_knn(train_dataset)

dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=1)

for data, labels in dataloader:
    data, labels = move_data_to_device(data, labels, device=device)

    # Get distances and indices for the nearest classes
    distances, indices = inference_model.get_nearest_neighbors(data, k=classes)  # k is set to the total number of classes
```

The above code returns only a few distances (32 rather than 213, the total number of classes), even when I adjust or remove the MatchFinder. Usually I get between 20 and 32 distances rather than 213.

I would like to use the InferenceModel as above, instead of calculating class centers, so that I can compare my fusion method to my original results. The InferenceModel gives results comparable to training, as I expected, but the fusion method requires a distance for every class, which I can't seem to retrieve even when increasing the number of neighbours.

@vemchance
Author

Apologies: after some more work, I've realised I misunderstood the InferenceModel. It returns the indices of the nearest images, not classes, so of course most classes aren't covered by the returned indices. In that case, is there a method to retrieve the nearest classes instead of image indices, aside from the one mentioned here?
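A small illustration of this point: mapping the k returned sample indices back to their class labels usually gives far fewer than k distinct classes, because several neighbours can belong to the same class. `classes_covered` is a hypothetical helper, and the labels below are made-up toy data.

```python
import numpy as np

def classes_covered(neighbor_indices, ref_labels):
    """Map sample indices from a k-NN search (nearest first) to class labels.

    Returns the distinct classes in order of first appearance; its length
    can be much smaller than k."""
    ref_labels = np.asarray(ref_labels).reshape(-1)
    seen = []
    for idx in neighbor_indices:
        label = ref_labels[idx].item()
        if label not in seen:
            seen.append(label)
    return seen
```

For example, with five neighbours drawn from a reference set whose first three samples share class 0, only two classes appear.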

@KevinMusgrave
Owner

KevinMusgrave commented Oct 15, 2024

MatchFinder is only used in the get_matches and is_match functions of InferenceModel. That's why it's having no effect on get_nearest_neighbors. This should be improved, or explained more clearly somewhere. I think it has caused confusion for someone else before too. Sorry about that.
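The distinction between threshold matching and nearest-neighbour search can be shown with a simplified NumPy analogue. This is not the library's code: `threshold_match` and `top_k` are hypothetical helpers, and the inclusive `>=` comparison is an assumption. A threshold decides which pairs count as matches, while a k-NN search always returns k results regardless of any threshold.

```python
import numpy as np

def cosine_similarity_matrix(a, b):
    """Pairwise cosine similarity between rows of a and rows of b."""
    a = a / np.linalg.norm(a, axis=1, keepdims=True)
    b = b / np.linalg.norm(b, axis=1, keepdims=True)
    return a @ b.T

def threshold_match(a, b, threshold):
    """True where similarity >= threshold (assumed inclusive)."""
    return cosine_similarity_matrix(a, b) >= threshold

def top_k(query, refs, k):
    """Indices of the k most similar references, most similar first."""
    sims = cosine_similarity_matrix(query, refs)
    return np.argsort(-sims, axis=1)[:, :k]
```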

If I understand correctly, your first snippet of code enables you to find the k nearest class centers, but you want to find the nearest classes based on individual samples rather than class centers. I don't think there is a built-in way to achieve this. The only way I can think of is to loop through each class of the dataset and get the distance of the nearest sample for that subset. Something like this:

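```python
from pytorch_metric_learning.utils.inference import FaissKNN, InferenceModel

class_distances = {}
for label in unique_labels:
    curr_subset = dataset[labels == label]
    subset_embeddings = ...  # embeddings of curr_subset
    knn = FaissKNN(reset_before=False, reset_after=False)
    knn.train(subset_embeddings)
    # The InferenceModel must use this per-class knn, so it only searches this class's samples
    inference_model = InferenceModel(model, knn_func=knn)
    distances, _ = inference_model.get_nearest_neighbors(data, k=1)
    class_distances[label] = distances[0][0].item()

# closest class is the key in class_distances with the smallest distance
closest_class = min(class_distances, key=class_distances.get)
```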
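The loop above can be mirrored in plain NumPy for a single query, replacing the per-class FAISS index with a brute-force Euclidean search over each class subset. `nearest_distance_per_class` is a hypothetical helper for illustration only.

```python
import numpy as np

def nearest_distance_per_class(query, ref_embeddings, ref_labels):
    """For one query, the distance to the nearest reference sample of each class."""
    query = np.asarray(query, dtype=float).reshape(-1)
    ref_embeddings = np.asarray(ref_embeddings, dtype=float)
    ref_labels = np.asarray(ref_labels).reshape(-1)
    class_distances = {}
    for label in np.unique(ref_labels):
        # 1-NN search restricted to this class's samples
        subset = ref_embeddings[ref_labels == label]
        dists = np.linalg.norm(subset - query, axis=1)
        class_distances[label.item()] = float(dists.min())
    return class_distances
```

The closest class is then `min(d, key=d.get)` for the returned dictionary `d`.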

As for this issue:

> The above code will return only a few distances (32 as opposed to 213, which is the total number of classes)

I'm not sure what would cause it to return fewer than k results. If I recall correctly, the returned distances are 2D (one row per query), so I'm wondering if you're using len(distances), which would be the batch size.
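A quick check of the shape point, using a NumPy array as a stand-in for the returned distances (the shape, one row per query and one column per neighbour, is the assumption here). `len()` on a 2D PyTorch tensor likewise returns the size of the first dimension.

```python
import numpy as np

# Stand-in for k-NN output with batch_size=1 and k=213
batch_size, k = 1, 213
distances = np.zeros((batch_size, k))

print(len(distances))      # number of queries (rows), i.e. the batch size
print(distances.shape[1])  # number of neighbours returned per query
```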

@vemchance
Author

Thanks for the reply! I'll give your code a shot. In the end, after realising I was misunderstanding the MatchFinder, I wrote something that compares a sample to all training embeddings and keeps only the smallest distance per class:

```python
from pytorch_metric_learning.testers.base_tester import BaseTester
from pytorch_metric_learning.utils.accuracy_calculator import AccuracyCalculator
from pytorch_metric_learning.utils.inference import FaissKNN

tester = BaseTester(normalize_embeddings=True, dataloader_num_workers=0, batch_size=32)
train_embeddings, train_labels = tester.get_all_embeddings(train_dataset, model, eval=True)
val_embeddings, val_labels = tester.get_all_embeddings(val_dataset, model, eval=True)

knn_func = FaissKNN(reset_before=True, reset_after=True)
knn_func.train(train_embeddings)
accuracy_calculator = AccuracyCalculator(knn_func=knn_func)

for i, query_embedding in enumerate(val_embeddings):
    query_embedding = query_embedding.unsqueeze(0)  # shape (1, dim)

    # Search against every training sample (k = all of them), which is what makes this slow
    distances, indices = accuracy_calculator.knn_func(
        query=query_embedding,
        k=len(train_embeddings),
        reference=train_embeddings,
        ref_includes_query=False,
    )

    distances = distances[0].tolist()
    indices = indices[0].tolist()

    # Keep the smallest distance seen for each class
    class_distances = {}

    for idx, distance in zip(indices, distances):
        class_label = train_dataset.label_decode[train_labels[idx].item()]

        if class_label not in class_distances:
            class_distances[class_label] = distance
        else:
            class_distances[class_label] = min(class_distances[class_label], distance)
```

The big drawback, of course, is that this is very slow for large datasets, but it gives me the same accuracy as in training for the individual modalities, plus a distance for every class. I'll see whether your example is faster and whether its accuracy is comparable.
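One way to avoid a k = all-samples search per query is to compute the per-class minimum for many queries at once. This is a brute-force NumPy sketch, not a library feature: `per_class_min_distances` is hypothetical, it uses Euclidean distance, and it processes queries in chunks so the full query-by-reference distance matrix never has to fit in memory at once.

```python
import numpy as np

def per_class_min_distances(queries, refs, ref_labels, chunk_size=1024):
    """Distance from every query to the nearest reference sample of each class.

    Returns (classes, out), where out has shape (num_queries, num_classes)."""
    queries = np.asarray(queries, dtype=float)
    refs = np.asarray(refs, dtype=float)
    ref_labels = np.asarray(ref_labels).reshape(-1)
    classes, class_idx = np.unique(ref_labels, return_inverse=True)
    out = np.full((len(queries), len(classes)), np.inf)
    ref_sq = (refs ** 2).sum(axis=1)
    for start in range(0, len(queries), chunk_size):
        q = queries[start:start + chunk_size]
        # Squared Euclidean distances via ||q||^2 - 2 q.r + ||r||^2
        d2 = (q ** 2).sum(axis=1)[:, None] - 2 * q @ refs.T + ref_sq[None, :]
        d = np.sqrt(np.maximum(d2, 0.0))  # clip tiny negative rounding errors
        for c in range(len(classes)):
            out[start:start + len(q), c] = d[:, class_idx == c].min(axis=1)
    return classes, out
```

Memory per chunk is roughly chunk_size times the number of reference samples, so `chunk_size` trades memory for speed.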

As for the fewer-than-k results: I think it's because the search pulled 213 neighbours, but those were individual samples rather than classes, so whenever several neighbours shared a class, the dictionary entry for that class name was overwritten with a new distance. The batch size has no impact, so I think it was the overwriting, but I'll double-check.
