Skip to content

Commit

Permalink
fix(lib): Fix splitTest not correctly translating to test data length (#17)
Browse files Browse the repository at this point in the history

* refactor(lib): Liberate splitTestData and improve its tests

* fix(lib): Fix splitTest not correctly translating to test data length
  • Loading branch information
isair authored Aug 7, 2020
1 parent e9646b3 commit d5d057d
Show file tree
Hide file tree
Showing 3 changed files with 105 additions and 18 deletions.
19 changes: 1 addition & 18 deletions src/loadCsv.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,27 +5,10 @@ import { shuffle } from 'shuffle-seed';

import { CsvReadOptions, CsvTable } from './loadCsv.models';
import filterColumns from './filterColumns';
import splitTestData from './splitTestData';

const defaultShuffleSeed = 'mncv9340ur';

/**
 * Cuts the feature and label tables in two at a single row index.
 *
 * Rows before the cut index are returned as `features` / `labels`; rows from
 * the cut index onwards are returned as `testFeatures` / `testLabels`.
 *
 * @param features - Feature rows; its length determines the cut index.
 * @param labels - Label rows; sliced at the same index as `features`.
 * @param splitTest - When a number, it is used as the cut index itself,
 *   clamped to the range `[0, features.length - 1]`. Any non-number value
 *   (`true` or `false`) cuts at `Math.floor(features.length / 2)`.
 * @returns The four slices described above.
 */
const splitTestData = (
  features: CsvTable,
  labels: CsvTable,
  splitTest: boolean | number
) => {
  let cutIndex;
  if (typeof splitTest !== 'number') {
    cutIndex = Math.floor(features.length / 2);
  } else {
    // Clamp from above first, then from below.
    const highestIndex = features.length - 1;
    cutIndex = Math.max(0, Math.min(splitTest, highestIndex));
  }

  const testFeatures = features.slice(cutIndex);
  const testLabels = labels.slice(cutIndex);

  return {
    testFeatures,
    testLabels,
    features: features.slice(0, cutIndex),
    labels: labels.slice(0, cutIndex),
  };
};

const loadCsv = (filename: string, options: CsvReadOptions) => {
const {
featureColumns,
Expand Down
23 changes: 23 additions & 0 deletions src/splitTestData.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import { CsvTable } from './loadCsv.models';

/**
 * Splits feature and label tables into a leading training portion and a
 * trailing test portion.
 *
 * Both tables are sliced at the same index, which is derived from
 * `features.length`; the test rows are always taken from the end.
 *
 * @param features - Feature rows; its length determines the split index.
 * @param labels - Label rows; sliced at the same index as `features`.
 * @param splitTest - `true` reserves half of the rows (rounded down) for
 *   testing. A number is the count of rows to reserve, clamped to
 *   `[0, features.length]`; `NaN` is treated as `0`.
 * @returns `features` / `labels` holding the rows before the split index and
 *   `testFeatures` / `testLabels` holding the rows from it onwards.
 */
const splitTestData = (
  features: CsvTable,
  labels: CsvTable,
  splitTest: true | number
) => {
  const dataLength = features.length;
  const requestedLength =
    typeof splitTest === 'number' ? splitTest : Math.floor(dataLength / 2);
  // NaN passes straight through Math.min/Math.max; the slices below would then
  // treat the split index as 0 and silently move every row into the test set.
  const testLength = Number.isNaN(requestedLength)
    ? 0
    : Math.max(0, Math.min(requestedLength, dataLength));
  const testStartIndex = dataLength - testLength;

  return {
    features: features.slice(0, testStartIndex),
    labels: labels.slice(0, testStartIndex),
    testFeatures: features.slice(testStartIndex),
    testLabels: labels.slice(testStartIndex),
  };
};

export default splitTestData;
81 changes: 81 additions & 0 deletions tests/splitTestData.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
import splitTestData from '../src/splitTestData';

// Fixture: four two-column feature rows paired, by position, with four
// single-column label rows.
const tables = {
  features: [[1, 2], [3, 4], [5, 6], [7, 8]],
  labels: [[9], [10], [11], [12]],
};

// Passing `true` should leave the first two rows as training data and move the
// last two rows into the test set.
test('Default splitting, splits in half', () => {
  const split = splitTestData(tables.features, tables.labels, true);

  expect(split.features).toMatchObject([
    [1, 2],
    [3, 4],
  ]);
  expect(split.labels).toMatchObject([[9], [10]]);
  expect(split.testFeatures).toMatchObject([
    [5, 6],
    [7, 8],
  ]);
  expect(split.testLabels).toMatchObject([[11], [12]]);
});

// A numeric argument of 1 should reserve exactly one row, the last one, for
// testing.
test('Splitting a fixed amount works', () => {
  const split = splitTestData(tables.features, tables.labels, 1);

  expect(split.features).toMatchObject([
    [1, 2],
    [3, 4],
    [5, 6],
  ]);
  expect(split.labels).toMatchObject([[9], [10], [11]]);
  expect(split.testFeatures).toMatchObject([[7, 8]]);
  expect(split.testLabels).toMatchObject([[12]]);
});

// Requesting twice as many test rows as exist should leave no training rows
// and put every row in the test set.
test('Splitting more than row length splits all rows into test data', () => {
  const oversizedCount = tables.features.length * 2;
  const split = splitTestData(tables.features, tables.labels, oversizedCount);

  expect(split.features).toMatchObject([]);
  expect(split.labels).toMatchObject([]);
  expect(split.testFeatures).toMatchObject([
    [1, 2],
    [3, 4],
    [5, 6],
    [7, 8],
  ]);
  expect(split.testLabels).toMatchObject([[9], [10], [11], [12]]);
});

// Zero and negative counts should both leave the test set empty and keep every
// row as training data.
test('Splitting less than or equal to 0 places all rows into normal data', () => {
  for (const splitLength of [0, -1]) {
    const split = splitTestData(tables.features, tables.labels, splitLength);

    expect(split.features).toMatchObject([
      [1, 2],
      [3, 4],
      [5, 6],
      [7, 8],
    ]);
    expect(split.labels).toMatchObject([[9], [10], [11], [12]]);
    expect(split.testFeatures).toMatchObject([]);
    expect(split.testLabels).toMatchObject([]);
  }
});

0 comments on commit d5d057d

Please sign in to comment.