Skip to content

Commit

Permalink
Update to ML 1.0
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewdalpino committed May 8, 2021
1 parent bbb78e2 commit 765b1b7
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 19 deletions.
10 changes: 1 addition & 9 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ $ composer create-project rubix/iris
```

## Requirements
- [PHP](https://php.net) 7.2 or above
- [PHP](https://php.net) 7.4 or above

## Tutorial

Expand Down Expand Up @@ -74,21 +74,13 @@ use Rubix\ML\CrossValidation\Metrics\Accuracy;
$metric = new Accuracy();

$score = $metric->score($predictions, $testing->labels());

echo 'Accuracy is ' . (string) ($score * 100.0) . '%' . PHP_EOL;
```

Now you're ready to run the training script from the command line.
```sh
php train.php
```

**Output**

```sh
Accuracy is 90%
```

### Next Steps
Congratulations on completing the introduction to machine learning in PHP with Rubix ML using the Iris dataset. Now you're ready to experiment on your own. For example, you may want to try different values of `k` or swap out the default [Euclidean](https://docs.rubixml.com/latest/kernels/distance/euclidean.html) distance kernel for another one such as [Manhattan](https://docs.rubixml.com/latest/kernels/distance/manhattan.html) or [Minkowski](https://docs.rubixml.com/latest/kernels/distance/minkowski.html).

Expand Down
4 changes: 2 additions & 2 deletions composer.json
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@
}
],
"require": {
"php": ">=7.2",
"rubix/ml": "^0.3.0"
"php": ">=7.4",
"rubix/ml": "^1.0"
},
"scripts": {
"train": "@php train.php"
Expand Down
15 changes: 7 additions & 8 deletions train.php
Original file line number Diff line number Diff line change
Expand Up @@ -2,33 +2,32 @@

include __DIR__ . '/vendor/autoload.php';

use Rubix\ML\Loggers\Screen;
use Rubix\ML\Datasets\Labeled;
use Rubix\ML\Extractors\NDJSON;
use Rubix\ML\Classifiers\KNearestNeighbors;
use Rubix\ML\CrossValidation\Metrics\Accuracy;

echo 'Loading data into memory ...' . PHP_EOL;
$logger = new Screen();

$logger->info('Loading data into memory');

$training = Labeled::fromIterator(new NDJSON('dataset.ndjson'));

$testing = $training->randomize()->take(10);

$estimator = new KNearestNeighbors(5);

echo 'Training ...' . PHP_EOL;
$logger->info('Training');

$estimator->train($training);

echo 'Making predictions ...' . PHP_EOL;
$logger->info('Making predictions');

$predictions = $estimator->predict($testing);

echo 'Example predictions:' . PHP_EOL;

print_r(array_slice($predictions, 0, 3));

$metric = new Accuracy();

$score = $metric->score($predictions, $testing->labels());

echo 'Accuracy is ' . (string) ($score * 100.0) . '%' . PHP_EOL;
$logger->info("Accuracy is $score");

0 comments on commit 765b1b7

Please sign in to comment.