sample_weight with multiple, named losses #20992

Open
ZankoNT opened this issue Mar 6, 2025 · 0 comments
ZankoNT commented Mar 6, 2025

Hello!

Today I was trying to train a model on the JAX backend with multiple losses while using the sample_weight argument, which seems to expect something like:

self.model.compile(
    optimizer=keras.optimizers.AdamW(learning_rate=self.learning_rate),
    loss={
        'policy_output': 'categorical_crossentropy',
        'value_output': 'mse'
    },
    loss_weights={'policy_output': np.mean(value_targets) / np.mean(policy_targets), 'value_output': 1.0}
)

return self.model.fit(
    x=states, 
    y={'policy_output': policy_targets, 'value_output': value_targets},
    sample_weight={'policy_output': np.ones(len(states)), 'value_output': value_confidence,},
    **kwargs
)

However, I got a KeyError pointing at a method in trainers/compile_utils.py: it was looking for key 0 and failing to find it at the line

_sample_weight = resolve_path(path, sample_weight)

Looking at it, it seems to me that resolve_path was written with self._flat_losses in mind, taking for granted that path can be used to index into the structure. In my case, however, sample_weight had not been preprocessed into a matching structure; it still had exactly the shape I passed to fit (a vanilla dict with the traced arrays inside), so the path lookup failed.
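The mismatch can be reproduced in isolation. The sketch below uses a stand-in resolve_path (not the actual Keras implementation) that simply walks a nested structure along a path of keys, which is enough to show how a path built against one structure raises KeyError on a plain dict:

```python
def resolve_path(path, structure):
    # Stand-in for the Keras helper: follow each key/index in `path`
    # down into the nested `structure`.
    for key in path:
        structure = structure[key]
    return structure

sample_weight = {"policy_output": [1.0, 1.0], "value_output": [0.5, 0.9]}

# A path derived from the same dict structure resolves fine:
print(resolve_path(("policy_output",), sample_weight))

# A path built against a different structure (e.g. a list of outputs)
# indexes with an integer and fails on the dict:
try:
    resolve_path((0,), sample_weight)
except KeyError as e:
    print("KeyError:", e)  # KeyError: 0
```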

I suppose a proper fix would be to attach the path to sample_weight in an appropriate spot, but for now I resorted to editing the library code to do something like:

for (path, loss_fn, loss_weight, loss_name), metric in zip(
    self._flat_losses, metrics
):
    y_t, y_p = resolve_path(path, y_true), resolve_path(path, y_pred)
    # Look the per-output weight up by loss name instead of by path,
    # since sample_weight is still the plain dict passed to fit():
    if sample_weight is not None and tree.is_nested(sample_weight):
        _sample_weight = sample_weight[loss_name]
    else:
        _sample_weight = sample_weight

This is pretty coarse, but it works for me.
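Standalone, the workaround behaves as below. This is a minimal sketch with stand-in values for the internals it touches: `resolve_path`, `tree.is_nested`, and `self._flat_losses` (here a hand-built list of `(path, loss_fn, loss_weight, loss_name)` tuples) all belong to Keras and are only mocked:

```python
import numpy as np

def resolve_path(path, structure):
    # Stand-in for the Keras helper: walk `structure` along `path`.
    for key in path:
        structure = structure[key]
    return structure

def is_nested(x):
    # Stand-in for keras.tree.is_nested.
    return isinstance(x, (dict, list, tuple))

# Stand-in for self._flat_losses: (path, loss_fn, loss_weight, loss_name).
flat_losses = [
    (("policy_output",), "categorical_crossentropy", 1.0, "policy_output"),
    (("value_output",), "mse", 1.0, "value_output"),
]
y_true = {"policy_output": np.ones(3), "value_output": np.zeros(3)}
sample_weight = {"policy_output": np.ones(3), "value_output": np.full(3, 0.5)}

resolved = {}
for path, loss_fn, loss_weight, loss_name in flat_losses:
    y_t = resolve_path(path, y_true)  # path lookup works on y_true
    # Look the weight up by loss name rather than by structural path:
    if sample_weight is not None and is_nested(sample_weight):
        resolved[loss_name] = sample_weight[loss_name]
    else:
        resolved[loss_name] = sample_weight
```

Each output ends up paired with the weights keyed under its name, without ever indexing sample_weight by path.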

Just wanted to give a heads-up in case it's a real issue; if I'm doing something comically wrong, I apologize for the time wasted!

python 3.11.11
keras 3.9
jax 0.5.2

Have a great day!
