-
Notifications
You must be signed in to change notification settings - Fork 108
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fixes #726: Implemented two accuracy functions #842
Conversation
All contributors have signed the CLA ✍️ ✅ |
Can anyone tell me how I can sign the CLA,please? I commented “I have read the CLA Document and I hereby sign the CLA”,but the CLA Assistant Lite bot won‘t pass. |
Hi @1160300918 The commits and this pull request are created by different accounts, you have to sign the CLA with both accounts. |
I have read the CLA Document and I hereby sign the CLA |
1 similar comment
I have read the CLA Document and I hereby sign the CLA |
Thanks. The bot passed, but the CLA Assistant in workflows still failed, should I commit again to activate it? |
ignore that. |
OK, is there anything else I should do besides waiting for the code review results? |
just be patient |
def balanced_accuracy_score(y_true, y_pred, labels, sample_weight=None, adjusted=False): | ||
"""calculate balanced accuracy score""" | ||
C = confusion_matrix(y_true, y_pred, labels, sample_weight=sample_weight) | ||
with np.errstate(divide="ignore", invalid="ignore"): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The context management can not work in SPU, can just delete this
|
||
from sml.preprocessing.preprocessing import label_binarize | ||
from spu.ops.groupby import groupby, groupby_sum | ||
|
||
from .auc import binary_clf_curve, binary_roc_auc | ||
|
||
|
||
def confusion_matrix(y_true, y_pred, labels, sample_weight=None, normalize=None): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not implement the logic about sample_weight
and normalize
cm = jnp.zeros((num_labels, num_labels), dtype=jnp.int32) | ||
|
||
# Calculate the confusion matrix | ||
for i, label in enumerate(labels): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Vectorized operations can replace these two for-loops, So the round complexity can be reduced.
e.g.
y_true == labels
gives an n*c matrix, where n is the number of samples, and c is the number of labels.
Same as y_pred == labels
, then the cm is just the inner product of all column-pairs of two matrics.
|
||
from sml.preprocessing.preprocessing import label_binarize | ||
from spu.ops.groupby import groupby, groupby_sum | ||
|
||
from .auc import binary_clf_curve, binary_roc_auc | ||
|
||
|
||
def confusion_matrix(y_true, y_pred, labels, sample_weight=None, normalize=None): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please add some docs for the functionality of this function, and the means of all params.
return cm | ||
|
||
|
||
def balanced_accuracy_score(y_true, y_pred, labels, sample_weight=None, adjusted=False): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please add some docs for the functionality of this function, and the means of all params.
return top_k_score | ||
|
||
def check(spu_result, sk_result): | ||
np.testing.assert_allclose(spu_result, sk_result, rtol=1, atol=1e-5) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
rtol
and atol
can be set to 1e-3
return balanced_score | ||
|
||
def check(spu_result, sk_result): | ||
np.testing.assert_allclose(spu_result, sk_result, rtol=1, atol=1e-5) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
rtol
and atol
can be set to 1e-3
|
||
def test_balanced_accuracy(self): | ||
sim = spsim.Simulator.simple( | ||
3, spu_pb2.ProtocolKind.ABY3, spu_pb2.FieldType.FM128 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
FM64
is enough
|
||
def test_top_k_accuracy(self): | ||
sim = spsim.Simulator.simple( | ||
3, spu_pb2.ProtocolKind.ABY3, spu_pb2.FieldType.FM128 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
FM64
is enough
y_true = jnp.array([0, 1, 1, 0, 1, 1]) | ||
y_pred = jnp.array([0, 0, 1, 0, 1, 1]) | ||
labels = jnp.array([0, 1]) | ||
spu_result = spsim.sim_jax(sim, proc)(y_true, y_pred, labels) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Test the param sample_weight
and adjusted
;
Test larger datasets, please (maybe ~1000 samples are enough).
Stale pull request message. Please comment to remove stale tag. Otherwise this pr will be closed soon. |
Stale pull request message. Please comment to remove stale tag. Otherwise this pr will be closed soon. |
Stale pull request message. Please comment to remove stale tag. Otherwise this pr will be closed soon. |
Pull Request
What problem does this PR solve?
Issue Number: Fixed #726
Implemented two accuracy functions for binary classification and multi-class classification.