Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Modify CLI command morphoclass train #89

Open
wants to merge 13 commits into
base: main
Choose a base branch
from
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
morphoclass.console.cmd\_extract\_features\_and\_train module
=============================================================

.. automodule:: morphoclass.console.cmd_extract_features_and_train
:members:
:undoc-members:
:show-inheritance:
1 change: 1 addition & 0 deletions docs/source/api/morphoclass.console.rst
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ Submodules

morphoclass.console.cmd_evaluate
morphoclass.console.cmd_extract_features
morphoclass.console.cmd_extract_features_and_train
morphoclass.console.cmd_morphometrics
morphoclass.console.cmd_organise_dataset
morphoclass.console.cmd_performance_table
Expand Down
1 change: 1 addition & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ python_requires = >=3.8,<3.9
install_requires =
PyYAML
captum
cleanlab==1.0.1
click
dash
dash-bootstrap-components
Expand Down
23 changes: 23 additions & 0 deletions src/morphoclass/console/cmd_extract_features.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,29 @@ def cli(
no_simplify_graph: bool,
keep_diagram: bool,
force: bool,
) -> None:
"""Extract morphology features."""
return extract_features(
csv_path,
neurite_type,
feature,
output_dir,
orient,
no_simplify_graph,
keep_diagram,
force,
)


def extract_features(
csv_path: StrPath,
neurite_type: str,
feature: str,
output_dir: StrPath,
orient: bool,
no_simplify_graph: bool,
keep_diagram: bool,
force: bool,
) -> None:
"""Extract morphology features."""
output_dir = pathlib.Path(output_dir)
Expand Down
143 changes: 143 additions & 0 deletions src/morphoclass/console/cmd_extract_features_and_train.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
# Copyright © 2022-2022 Blue Brain Project/EPFL
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Implementation of the `morphoclass train` CLI command."""
from __future__ import annotations

import logging
import pathlib
from typing import Literal

import click

from morphoclass.types import StrPath

logger = logging.getLogger(__name__)


@click.command(name="train", help="Train a morphology classification model.")
@click.argument("csv_path", type=click.Path(dir_okay=False))
@click.argument("neurite_type", type=click.Choice(["apical", "axon", "basal", "all"]))
@click.argument(
"feature",
type=click.Choice(
[
"graph-rd",
"graph-proj",
"diagram-tmd-rd",
"diagram-tmd-proj",
"diagram-deepwalk",
"image-tmd-rd",
"image-tmd-proj",
"image-deepwalk",
]
),
)
@click.option(
"--orient",
is_flag=True,
help="Orient the neurons so that the apicals are aligned with the positive y-axis.",
)
@click.option(
"--no-simplify-graph",
is_flag=True,
help="""
By default the neurite graph is reduced to branching nodes only. With this
flag the full neurite graph will be preserved.
""",
)
@click.option(
"--keep-diagram",
is_flag=True,
help="After converting the diagram to persistence image don't discard the diagram.",
)
@click.option(
"--model-config",
type=click.Path(exists=True, dir_okay=False),
required=True,
help="""
The model configuration file.
For inspiration, model configuration files can be found under
dvc/training/configs/
""",
)
@click.option(
"--splitter-config",
type=click.Path(exists=True, dir_okay=False),
required=True,
help="""
The splitter configuration file.
For inspiration, splitter configuration files can be found under
dvc/training/configs/
""",
)
@click.option(
"--output-dir",
type=click.Path(file_okay=False),
required=True,
help="The output directory.",
)
@click.option(
"-f",
"--force",
type=click.BOOL,
default=False,
is_flag=True,
help="Don't ask for overwriting existing output files.",
)
def cli(
csv_path: StrPath,
neurite_type: Literal["apical", "axon", "basal", "all"],
feature: Literal[
"graph-rd",
"graph-proj",
"diagram-tmd-rd",
"diagram-tmd-proj",
"diagram-deepwalk",
"image-tmd-rd",
"image-tmd-proj",
"image-deepwalk",
],
orient: bool,
no_simplify_graph: bool,
keep_diagram: bool,
model_config: StrPath,
splitter_config: StrPath,
output_dir: StrPath,
force: bool,
) -> None:
"""Extract features and train the model."""
from morphoclass.console.cmd_extract_features import extract_features
from morphoclass.console.cmd_train import train

input_csv = pathlib.Path(csv_path).resolve()
output_dir = pathlib.Path(output_dir).resolve()

extract_features(
input_csv,
neurite_type,
feature,
output_dir / "features",
orient,
no_simplify_graph,
keep_diagram,
force,
)

train(
output_dir / "features",
model_config,
splitter_config,
output_dir / "checkpoints",
force,
)
26 changes: 24 additions & 2 deletions src/morphoclass/console/cmd_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Implementation of the `morphoclass train` CLI command."""
"""Implementation of the `morphoclass train-after-extraction` CLI command."""
from __future__ import annotations

import logging
Expand All @@ -25,7 +25,12 @@
logger = logging.getLogger(__name__)


@click.command(name="train", help="Train a morphology classification model.")
@click.command(
name="train-after-extraction",
help="""
Train a morphology classification model.
Features need to be first extracted.""",
)
@click.option(
"--features-dir",
type=click.Path(exists=True, file_okay=False),
Expand Down Expand Up @@ -64,6 +69,23 @@ def cli(
splitter_config: StrPath,
checkpoint_dir: StrPath,
force: bool,
) -> None:
"""Training and evaluation of the model."""
return train(
features_dir,
model_config,
splitter_config,
checkpoint_dir,
force,
)


def train(
features_dir: StrPath,
model_config: StrPath,
splitter_config: StrPath,
checkpoint_dir: StrPath,
force: bool,
) -> None:
"""Training and evaluation of the model.

Expand Down
2 changes: 2 additions & 0 deletions src/morphoclass/console/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import morphoclass
from morphoclass.console import cmd_evaluate
from morphoclass.console import cmd_extract_features
from morphoclass.console import cmd_extract_features_and_train
from morphoclass.console import cmd_morphometrics
from morphoclass.console import cmd_organise_dataset
from morphoclass.console import cmd_performance_table
Expand Down Expand Up @@ -143,3 +144,4 @@ def cli(verbose: int, log_file_path: pathlib.Path | None) -> None:
cli.add_command(cmd_performance_table.cli)
cli.add_command(cmd_extract_features.cli)
cli.add_command(cmd_morphometrics.cli)
cli.add_command(cmd_extract_features_and_train.cli)