Skip to content

Commit

Permalink
Merge pull request #21 from PAIR-code/iislucas-2024-08-23-fix-run-script
Browse files Browse the repository at this point in the history
fix transformer run script on toy worlds
  • Loading branch information
iislucas authored Aug 23, 2024
2 parents 550e394 + bb8a7f3 commit a52ffb3
Show file tree
Hide file tree
Showing 3 changed files with 159 additions and 41 deletions.
150 changes: 145 additions & 5 deletions animated-transformer/package-lock.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 3 additions & 2 deletions animated-transformer/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
"watch": "ng build --watch --configuration development",
"test": "ng test --watch"
},
"node": ">=20.11.1",
"node": ">=22.7.0",
"private": true,
"dependencies": {
"@angular/animations": "^18.0.1",
Expand All @@ -33,6 +33,7 @@
"json5": "^2.2.3",
"msgpackr": "^1.11.0",
"rxjs": "~7.5.0",
"ts-node": "^10.9.2",
"tslib": "^2.3.0",
"underscore": "^1.13.6",
"yargs": "^17.7.2",
Expand All @@ -57,4 +58,4 @@
"karma-jasmine-html-reporter": "~2.0.0",
"typescript": "~5.5.0"
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ limitations under the License.
Tiny Worlds, run with (gtensor-based) transformers.
Run:
ts-node src/lib/seqtasks/tiny_worlds.run_with_transformer.script.ts
npx ts-node src/lib/seqtasks/tiny_worlds.run_with_transformer.script.ts
*/

Expand Down Expand Up @@ -55,7 +55,6 @@ import {
prepareBasicTaskTokenRep,
BasicTaskTokenRep,
} from '../tokens/token_gemb';
import { GTensorTree, GVariableTree } from '../gtensor/gtensor_tree';
import { layer } from '@tensorflow/tfjs-vis/dist/show/model';
import { example } from 'yargs';

Expand Down Expand Up @@ -97,11 +96,7 @@ function getTransformerConfig(): TransformerConfig {
return config;
}

function* dataGenerator(
task: TinyWorldTask,
batchNum: number,
batchSize: number
) {
function* dataGenerator(task: TinyWorldTask, batchNum: number, batchSize: number) {
for (let batchId = 0; batchId < batchNum; batchId += 1) {
let batchOriginal = task.exampleIter.takeOutN(batchSize);
let batchInput = batchOriginal.map((example) => example.input);
Expand All @@ -115,7 +110,7 @@ function unbindedLossFn(
batchOutput: string[][],
tokenRep: BasicTaskTokenRep,
transformerConfig: TransformerConfig,
decoderParamsTree: GVariableTree<TransformerParams>
decoderParamsTree: TransformerParams
): tf.Scalar {
let spec = transformerConfig.spec;
let computation: TransformerComputation = computeDecoder(
Expand All @@ -125,18 +120,15 @@ function unbindedLossFn(
decoderParamsTree,
batchInput
);
let singleNextTokenIdx = singleNextTokenIdxOutputPrepFn(
tokenRep,
batchOutput
);
let singleNextTokenIdx = singleNextTokenIdxOutputPrepFn(tokenRep, batchOutput);
let entropyLoss: tf.Scalar = transformerLastTokenCrossEntropyLoss(
computation,
decoderParamsTree.obj.tokenEmbedding,
decoderParamsTree.tokenEmbedding,
singleNextTokenIdx
);
let accuracy: tf.Scalar = transformerAccuracy(
computation,
decoderParamsTree.obj.tokenEmbedding,
decoderParamsTree.tokenEmbedding,
singleNextTokenIdx
);

Expand Down Expand Up @@ -171,13 +163,7 @@ function unbindedLossFn(

let [batchInput, batchOutput] = batch;
let bindedLossFn = () =>
unbindedLossFn(
batchInput,
batchOutput,
tokenRep,
transformerConfig,
decoderParamsTree
);
unbindedLossFn(batchInput, batchOutput, tokenRep, transformerConfig, decoderParamsTree);
optimizer.minimize(bindedLossFn);
}

Expand All @@ -200,16 +186,9 @@ function unbindedLossFn(
decoderParamsTree,
batchInput
);
let singleNextTokenIdx = singleNextTokenIdxOutputPrepFn(
tokenRep,
batchOutput
);
let singleNextTokenIdxArrayData =
singleNextTokenIdx.tensor.arraySync() as number[];
let logits = transformerLastTokenLogits(
computation,
decoderParamsTree.obj.tokenEmbedding
);
let singleNextTokenIdx = singleNextTokenIdxOutputPrepFn(tokenRep, batchOutput);
let singleNextTokenIdxArrayData = singleNextTokenIdx.tensor.arraySync() as number[];
let logits = transformerLastTokenLogits(computation, decoderParamsTree.tokenEmbedding);
let probs = logits.softmax('tokenId');
let probsArrayData = probs.tensor.arraySync() as number[][];

Expand Down Expand Up @@ -237,8 +216,6 @@ function unbindedLossFn(
batchInput = batchInput.map((subArray, batchIndex) =>
subArray.slice(1).concat(batchOutput[batchIndex])
);
batchOutput = batchOutputAll.map((subArray) =>
subArray.slice(inferStep, inferStep + 1)
);
batchOutput = batchOutputAll.map((subArray) => subArray.slice(inferStep, inferStep + 1));
}
}

0 comments on commit a52ffb3

Please sign in to comment.