Skip to content

Commit

Permalink
Address PR comments (2)
Browse files Browse the repository at this point in the history
- declare a proper type for generation arguments that user can override when calling the genienlp parser
- some refactoring for better reading
  • Loading branch information
Mehrad0711 committed Oct 7, 2021
1 parent 58edbfb commit 457507e
Show file tree
Hide file tree
Showing 6 changed files with 55 additions and 50 deletions.
9 changes: 5 additions & 4 deletions lib/prediction/localparserclient.ts
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@ import {
PredictionCandidate,
PredictionResult,
GenerationResult,
ExactMatcher
ExactMatcher,
GENERATION_OPTIONS
} from './types';

const SEMANTIC_PARSING_TASK = 'almond';
Expand Down Expand Up @@ -264,9 +265,9 @@ export default class LocalParserClient {
});
}

async translateUtterance(input : string[], contextEntities : EntityMap|undefined, translationOptions : Record<string, unknown>) : Promise<GenerationResult[]> {
input = Utils.qpisEntities(input, contextEntities);
const candidates = await this._predictor.predict('', input.join(' '), undefined, TRANSLATION_TASK, 'id-null', translationOptions);
async translateUtterance(input : string[], entities : string[]|undefined, generationOptions : GENERATION_OPTIONS) : Promise<GenerationResult[]> {
input = Utils.qpisEntities(input, entities);
const candidates = await this._predictor.predict('', input.join(' '), undefined, TRANSLATION_TASK, 'id-null', generationOptions);
return candidates.map((cand) => {
return {
answer: cand.answer,
Expand Down
11 changes: 6 additions & 5 deletions lib/prediction/predictor.ts
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import * as child_process from 'child_process';
import * as Tp from 'thingpedia';

import JsonDatagramSocket from '../utils/json_datagram_socket';
import {GENERATION_OPTIONS} from "./types";

const DEFAULT_QUESTION = 'translate from english to thingtalk';

Expand Down Expand Up @@ -151,7 +152,7 @@ class LocalWorker extends events.EventEmitter {
this._requests.clear();
}

request(task : string, minibatch : Example[], options : Record<string, unknown>) : Promise<RawPredictionCandidate[][]> {
request(task : string, minibatch : Example[], options : GENERATION_OPTIONS) : Promise<RawPredictionCandidate[][]> {
const id = this._nextId ++;

return new Promise((resolve, reject) => {
Expand Down Expand Up @@ -179,7 +180,7 @@ class RemoteWorker extends events.EventEmitter {
start() {}
stop() {}

async request(task : string, minibatch : Example[], options : Record<string, unknown>) : Promise<RawPredictionCandidate[][]> {
async request(task : string, minibatch : Example[], options : GENERATION_OPTIONS) : Promise<RawPredictionCandidate[][]> {
const response = await Tp.Helpers.Http.post(this._url, JSON.stringify({
task,
instances: minibatch,
Expand Down Expand Up @@ -244,7 +245,7 @@ export default class Predictor {
});
}

private _startRequest(ex : Example, task : string, options : Record<string, unknown>, now : number) {
private _startRequest(ex : Example, task : string, options : GENERATION_OPTIONS, now : number) {
assert(this._minibatch.length === 0);
this._minibatch.push(ex);
this._minibatchTask = task;
Expand All @@ -257,7 +258,7 @@ export default class Predictor {
}, this._maxLatency);
}

private _addRequest(ex : Example, task : string, options : Record<string, unknown>) {
private _addRequest(ex : Example, task : string, options : GENERATION_OPTIONS) {
const now = Date.now();
if (this._minibatch.length === 0) {
this._startRequest(ex, task, options, now);
Expand All @@ -271,7 +272,7 @@ export default class Predictor {
}
}

predict(context : string, question : string = DEFAULT_QUESTION, answer ?: string, task = 'almond', example_id ?: string, options : Record<string, unknown> = {}) : Promise<RawPredictionCandidate[]> {
predict(context : string, question : string = DEFAULT_QUESTION, answer ?: string, task = 'almond', example_id ?: string, options : GENERATION_OPTIONS = {}) : Promise<RawPredictionCandidate[]> {

// ensure we have a worker, in case it recently died
if (!this._worker)
Expand Down
19 changes: 11 additions & 8 deletions lib/prediction/remoteparserclient.ts
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ import {
PredictionCandidate,
PredictionResult,
GenerationResult,
ExactMatcher,
ExactMatcher, GENERATION_OPTIONS,
} from './types';
import ExactMatcherBuilder from './exactbuilder';

Expand Down Expand Up @@ -170,16 +170,19 @@ export default class RemoteParserClient {
return parsed.candidates;
}

async translateUtterance(input : string[], contextEntities : EntityMap|undefined, translationOptions : Record<string, unknown>) : Promise<GenerationResult[]> {
input = Utils.qpisEntities(input, contextEntities);
async translateUtterance(input : string[], entities : string[]|undefined, generationOptions : GENERATION_OPTIONS) : Promise<GenerationResult[]> {
input = Utils.qpisEntities(input, entities);

const data = {
input: input.join(' '),
tgt_locale: translationOptions.tgt_locale,
entities: contextEntities,
alignment: translationOptions.alignment,
src_locale: translationOptions.src_locale,
align_remove_output_quotation: translationOptions.align_remove_output_quotation
tgt_locale: generationOptions.tgt_locale,
entities: entities,
alignment: generationOptions.do_alignment,
src_locale: generationOptions.src_locale,
// always remove quotation marks in the output string used to mark entity boundaries
align_remove_output_quotation: true,
// always break input utterance into individual sentences before translation
translate_example_split: true
};

const response = await Tp.Helpers.Http.post(`${this._baseUrl}/translate`, JSON.stringify(data), {
Expand Down
20 changes: 20 additions & 0 deletions lib/prediction/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,26 @@ export interface ExactMatcher {
get(tokens : string[]) : string[][]|null;
}

export interface GENERATION_OPTIONS {
num_beams ?: number
num_beam_groups ?: number
diversity_penalty ?: number
num_outputs ?: number
no_repeat_ngram_size ?: number
top_p ?: number
top_k ?: number
repetition_penalty ?: number
temperature ?: number
max_output_length ?: number
src_locale ?: string
tgt_locale ?: string
do_alignment ?: boolean
align_preserve_input_quotation ?: boolean
align_remove_output_quotation ?: boolean
translate_example_split ?: boolean
}


export interface ParseOptions {
thingtalk_version ?: string;
store ?: string;
Expand Down
15 changes: 8 additions & 7 deletions lib/utils/misc-utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ import {
makeDummyEntity,
makeDummyEntities,
renumberEntities,
EntityMap,
} from './entity-utils';

class ValidationError extends Error {
Expand Down Expand Up @@ -211,20 +210,22 @@ function substringSpan(sequence : string[], substring : string[]) : [number, num
}
}
if (found)
return [i, i + substring.length + 1];
return [i, i + substring.length];
}
return null;
}


function qpisEntities(input : string[], contextEntities : EntityMap|undefined) : string[] {
if (contextEntities) {
const allEntities = Object.keys(contextEntities).map((ent) => ent.split(' '));
for (const entity of allEntities) {
function qpisEntities(input : string[], entities : string[]|undefined) : string[] {
if (entities) {
const entityTokens = entities.map((ent) => ent.split(' '));
for (const entity of entityTokens) {
const span = substringSpan(input, entity);
if (span) {
input.splice(span[0], 0, '"');
input.splice(span[1], 0, '"');

// add 1 cause previous splice shift tokens to the right
input.splice(span[1] + 1, 0, '"');
}
}
}
Expand Down
31 changes: 5 additions & 26 deletions tool/server.ts
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ import * as Utils from '../lib/utils/misc-utils';
import { EntityMap } from '../lib/utils/entity-utils';
import LocalParserClient from '../lib/prediction/localparserclient';
import * as I18n from '../lib/i18n';
import {GENERATION_OPTIONS} from "../lib/prediction/types";

interface Backend {
schemas : ThingTalk.SchemaRetriever;
Expand Down Expand Up @@ -153,7 +154,7 @@ async function queryNLG(params : Record<string, string>,
interface TranslationData {
input : string;
tgt_locale : string
entities ?: EntityMap;
entities ?: string[];
limit ?: string;
alignment ?: boolean;
src_locale ?: string;
Expand All @@ -162,35 +163,13 @@ interface TranslationData {
const Translation_PARAMS = {
input: 'string',
tgt_locale: 'string',
entities: '?object',
entities: '?array',
limit: '?number',
alignment: '?boolean',
src_locale: '?string',
align_remove_output_quotation: '?boolean'
};


// const VALID_PARSER_OPTIONS = new Set([
// "num_beams",
// "num_beam_groups",
// "diversity_penalty",
// "num_outputs",
// "no_repeat_ngram_size",
// "top_p",
// "top_k",
// "repetition_penalty",
// "temperature",
// "max_output_length",
// "reduce_metrics",
// "database_dir",
// "do_alignment",
// "align_preserve_input_quotation",
// "align_remove_output_quotation",
// "src_locale",
// "tgt_locale",
// "translate_example_split"
// ]);

async function queryTranslate(params : Record<string, string>,
data : TranslationData,
res : express.Response) {
Expand All @@ -204,7 +183,7 @@ async function queryTranslate(params : Record<string, string>,
if (!data.src_locale)
data.src_locale = 'en-US';

const translationOptions : Record<string, unknown> = {
const generationOptions : GENERATION_OPTIONS = {
'src_locale': data.src_locale,
'tgt_locale': data.tgt_locale,
'do_alignment': data.alignment,
Expand All @@ -213,7 +192,7 @@ async function queryTranslate(params : Record<string, string>,
};

const result = await res.app.backend.translator!.translateUtterance(
data.input.split(' '), data.entities, translationOptions);
data.input.split(' '), data.entities, generationOptions);
res.json({
candidates: result.slice(0, data.limit ? parseInt(data.limit) : undefined),
});
Expand Down

0 comments on commit 457507e

Please sign in to comment.