Skip to content

Commit

Permalink
remove pin update tree_method
Browse files Browse the repository at this point in the history
  • Loading branch information
robsdavis committed Sep 13, 2024
1 parent 4380b4a commit c5fb4ba
Show file tree
Hide file tree
Showing 5 changed files with 10 additions and 6 deletions.
2 changes: 0 additions & 2 deletions .github/workflows/test_pr.yml
Original file line number Diff line number Diff line change
Expand Up @@ -63,14 +63,12 @@ jobs:
python -m pip install -U pip
pip install -r prereq.txt
pip install pytest-timeout
pip install xgboost==1.7.6
- name: Set environment variables for MNIST
run: echo "MNIST_DATA_DIR=${{ github.workspace }}/mnist_data" >> $GITHUB_ENV
- name: Dump GitHub Action environment
run: |
echo "OS: $(uname -a)"
python --version
pip freeze
- name: Test Core
run: |
pip install .[testing]
Expand Down
3 changes: 1 addition & 2 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,7 @@ install_requires =
pydantic<2.0
cloudpickle
scipy
xgboost<3.0.0; platform_system != "Darwin"
xgboost==1.7.6; platform_system == "Darwin"
xgboost<3.0.0
geomloss
pgmpy
redis
Expand Down
3 changes: 2 additions & 1 deletion src/synthcity/metrics/eval_performance.py
Original file line number Diff line number Diff line change
Expand Up @@ -1064,8 +1064,9 @@ def evaluate(

elif self._task_type == "classification":
model = XGBClassifier(
tree_method="approx",
n_jobs=2,
verbosity=0,
verbosity=2,
depth=3,
random_state=self._random_state,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,7 @@ def _fit(
raise ValueError(f"unsupported strategy {self.strategy}")

xgb_params = {
"tree_method": "approx",
"n_jobs": 2,
"verbosity": 0,
"depth": 3,
Expand Down
7 changes: 6 additions & 1 deletion src/synthcity/utils/compression.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,12 @@ def compress_dataset(
y = df[column]

if len(df[column].unique()) < cat_limit:
model = XGBClassifier()
model = model = XGBClassifier(
tree_method="approx",
n_jobs=2,
verbosity=2,
depth=3,
)
try:
score = evaluate_classifier(model, X, y)["clf"]["aucroc"][0]
except BaseException:
Expand Down

0 comments on commit c5fb4ba

Please sign in to comment.