Skip to content

Commit

Permalink
Add support for translation API
Browse files Browse the repository at this point in the history
  • Loading branch information
Mehrad0711 committed Oct 7, 2021
1 parent 7b1a8f0 commit dbaa1a1
Show file tree
Hide file tree
Showing 3 changed files with 123 additions and 18 deletions.
37 changes: 37 additions & 0 deletions lib/prediction/localparserclient.ts
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ import {
const SEMANTIC_PARSING_TASK = 'almond';
const NLU_TASK = 'almond_dialogue_nlu';
const NLG_TASK = 'almond_dialogue_nlg';
const Translation_TASK = 'almond_translate';
const NLG_QUESTION = 'what should the agent say ?';

export interface LocalParserOptions {
Expand All @@ -59,6 +60,22 @@ function compareScore(a : PredictionCandidate, b : PredictionCandidate) : number
return b.score - a.score;
}

/**
 * Find the first occurrence of `substring` as a contiguous run of tokens
 * inside `sequence`.
 *
 * Returns `[start, end]` where `start` is the index of the first matched
 * token and `end` is `start + substring.length + 1`. The extra `+1` is
 * deliberate: the caller first inserts an opening quote token at `start`,
 * which shifts the rest of the array right by one, making `end` the correct
 * insertion point for the closing quote after that first insertion.
 * Returns `null` when the substring does not occur (or is empty).
 */
function substringSpan(sequence : string[], substring : string[]) : [number, number] | null {
    // An empty substring would trivially "match" at index 0; treat it as absent.
    if (substring.length === 0)
        return null;
    // Stop scanning once the substring can no longer fit in the remainder.
    for (let i = 0; i + substring.length <= sequence.length; i++) {
        let found = true;
        for (let j = 0; j < substring.length; j++) {
            if (sequence[i+j] !== substring[j]) {
                found = false;
                break;
            }
        }
        if (found)
            return [i, i + substring.length + 1];
    }
    return null;
}


export default class LocalParserClient {
private _locale : string;
private _langPack : I18n.LanguagePack;
Expand Down Expand Up @@ -262,4 +279,24 @@ export default class LocalParserClient {
};
});
}

async translateUtterance(input : string[], contextEntities : EntityMap|undefined, translationOptions : Record<string, any>) : Promise<GenerationResult[]> {
if (contextEntities) {
const allEntities = Object.keys(contextEntities).map((ent) => ent.split(' '));
for (const entity of allEntities) {
const span = substringSpan(input, entity);
if (span) {
input.splice(span[0], 0, '"');
input.splice(span[1], 0, '"');
}
}
}
const candidates = await this._predictor.predict('', input.join(' '), undefined, Translation_TASK, 'id-null', translationOptions);
return candidates.map((cand) => {
return {
answer: cand.answer,
score: cand.score.confidence ?? 1
};
});
}
}
32 changes: 17 additions & 15 deletions lib/prediction/predictor.ts
Original file line number Diff line number Diff line change
Expand Up @@ -151,14 +151,14 @@ class LocalWorker extends events.EventEmitter {
this._requests.clear();
}

request(task : string, minibatch : Example[]) : Promise<RawPredictionCandidate[][]> {
request(task : string, minibatch : Example[], options : Record<string, any>) : Promise<RawPredictionCandidate[][]> {
const id = this._nextId ++;

return new Promise((resolve, reject) => {
this._requests.set(id, { resolve, reject });
//console.error(`${this._requests.size} pending requests`);

this._stream!.write({ id, task, instances: minibatch }, (err : Error | undefined | null) => {
this._stream!.write({ id, task, instances: minibatch, options: options }, (err : Error | undefined | null) => {
if (err) {
console.error(err);
reject(err);
Expand All @@ -179,10 +179,11 @@ class RemoteWorker extends events.EventEmitter {
start() {}
stop() {}

async request(task : string, minibatch : Example[]) : Promise<RawPredictionCandidate[][]> {
async request(task : string, minibatch : Example[], options : Record<string, any>) : Promise<RawPredictionCandidate[][]> {
const response = await Tp.Helpers.Http.post(this._url, JSON.stringify({
task,
instances: minibatch
instances: minibatch,
options: options
}), { dataContentType: 'application/json', accept: 'application/json' });
return JSON.parse(response).predictions.map((instance : any) : RawPredictionCandidate[] => {
if (instance.candidates) {
Expand All @@ -209,6 +210,7 @@ export default class Predictor {
private _maxLatency : number;

private _minibatchTask = '';
private _minitbatchOptions = {};
private _minibatch : Example[] = [];
private _minibatchStartTime = 0;

Expand All @@ -225,13 +227,14 @@ export default class Predictor {
private _flushRequest() {
const minibatch = this._minibatch;
const task = this._minibatchTask;
const options = this._minitbatchOptions;

this._minibatch = [];
this._minibatchTask = '';
this._minitbatchOptions = {};
this._minibatchStartTime = 0;

//console.error(`minibatch: ${minibatch.length} instances`);

this._worker!.request(task, minibatch).then((candidates) => {
this._worker!.request(task, minibatch, options).then((candidates) => {
assert(candidates.length === minibatch.length);
for (let i = 0; i < minibatch.length; i++)
minibatch[i].resolve(candidates[i]);
Expand All @@ -241,10 +244,11 @@ export default class Predictor {
});
}

private _startRequest(ex : Example, task : string, now : number) {
private _startRequest(ex : Example, task : string, options : Record<string, any>, now : number) {
assert(this._minibatch.length === 0);
this._minibatch.push(ex);
this._minibatchTask = task;
this._minitbatchOptions = options;
this._minibatchStartTime = now;

setTimeout(() => {
Expand All @@ -253,23 +257,21 @@ export default class Predictor {
}, this._maxLatency);
}

private _addRequest(ex : Example, task : string) {
private _addRequest(ex : Example, task : string, options : Record<string, any>) {
const now = Date.now();
if (this._minibatch.length === 0) {
this._startRequest(ex, task, now);
this._startRequest(ex, task, options, now);
} else if (this._minibatchTask === task &&
(now - this._minibatchStartTime < this._maxLatency) &&
this._minibatch.length < this._minibatchSize) {
this._minibatch.push(ex);
} else {
this._flushRequest();
this._startRequest(ex, task, now);
this._startRequest(ex, task, options, now);
}
}

predict(context : string, question = DEFAULT_QUESTION, answer ?: string, task = 'almond', example_id ?: string) : Promise<RawPredictionCandidate[]> {
assert(typeof context === 'string');
assert(typeof question === 'string');
predict(context : string, question : string = DEFAULT_QUESTION, answer ?: string, task = 'almond', example_id ?: string, options : Record<string, any> = {}) : Promise<RawPredictionCandidate[]> {

// ensure we have a worker, in case it recently died
if (!this._worker)
Expand All @@ -281,7 +283,7 @@ export default class Predictor {
resolve = _resolve;
reject = _reject;
});
this._addRequest({ context, question, answer, resolve, reject }, task);
this._addRequest({ context, question, answer, example_id, resolve, reject }, task, options);

return promise;
}
Expand Down
72 changes: 69 additions & 3 deletions tool/server.ts
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ interface Backend {
tokenizer : I18n.BaseTokenizer;
nlu : LocalParserClient;
nlg ?: LocalParserClient;
translator ?: LocalParserClient;
}

declare global {
Expand Down Expand Up @@ -148,6 +149,53 @@ async function queryNLG(params : Record<string, string>,
});
}


// Request body for the /:locale/translate endpoint (see Translation_PARAMS
// for the matching runtime validation schema).
interface TranslationData {
    // Space-separated tokenized utterance to translate.
    input : string;
    // Target locale tag for the translation.
    tgt_locale : string
    // Optional entity map; entity names found in the input are quoted so
    // they survive translation untouched.
    entities ?: EntityMap;
    // Maximum number of candidates to return.
    // NOTE(review): typed as string here and parseInt'd in queryTranslate,
    // but validated as '?number' in Translation_PARAMS — confirm which is right.
    limit ?: string;
    // Whether to request word alignment from the model.
    alignment ?: boolean;
    // Source locale; queryTranslate defaults this to 'en-US' when omitted.
    src_locale ?: string;
}
// Runtime validation schema for the /:locale/translate POST body (consumed
// by qv.validatePOST); a leading '?' marks a field optional. Must stay in
// sync with the TranslationData interface above.
const Translation_PARAMS = {
    input: 'string',
    tgt_locale: 'string',
    entities: '?object',
    limit: '?number',
    alignment: '?boolean',
    src_locale: '?string',
};

/**
 * Handle a /:locale/translate request: quote entities, run the translation
 * model, and respond with (optionally truncated) candidate translations.
 *
 * @param params route parameters (`locale`)
 * @param data validated request body (see Translation_PARAMS)
 * @param res express response used to emit the JSON reply
 */
async function queryTranslate(params : Record<string, string>,
                              data : TranslationData,
                              res : express.Response) {
    const app = res.app;

    // The server is loaded for a single source locale.
    if (params.locale !== app.args.locale) {
        res.status(400).json({ error: 'Unsupported language' });
        return;
    }

    // Fail cleanly if the server was started without --translation-model,
    // instead of crashing on a non-null assertion below.
    if (!app.backend.translator) {
        res.status(500).json({ error: 'Translation is not supported by this server' });
        return;
    }

    if (!data.src_locale)
        data.src_locale = 'en-US';

    const translationOptions : Record<string, any> = {
        'do_alignment': data.alignment,
        'align_remove_output_quotation': true,
        'src_locale': data.src_locale,
        'tgt_locale': data.tgt_locale
    };

    const result = await app.backend.translator.translateUtterance(
        data.input.split(' '), data.entities, translationOptions);
    res.json({
        // `limit` arrives as a string; undefined means "return everything".
        candidates: result.slice(0, data.limit ? parseInt(data.limit) : undefined),
    });
}



export function initArgparse(subparsers : argparse.SubParser) {
const parser = subparsers.add_parser('server', {
add_help: true,
Expand All @@ -159,13 +207,17 @@ export function initArgparse(subparsers : argparse.SubParser) {
default: 8400,
});
parser.add_argument('--nlu-model', {
required: true,
required: false,
help: "Path to the NLU model, pointing to a model directory.",
});
parser.add_argument('--nlg-model', {
required: false,
help: "Path to the NLG model, pointing to a model directory.",
});
parser.add_argument('--translation-model', {
required: false,
help: "Path to the Translation model, pointing to a model directory.",
});
parser.add_argument('--thingpedia', {
required: true,
help: 'Path to ThingTalk file containing class definitions.'
Expand Down Expand Up @@ -199,14 +251,21 @@ export async function execute(args : any) {
tokenizer: i18n.getTokenizer(),
nlu: new LocalParserClient(args.nlu_model, args.locale, undefined, undefined, tpClient)
};
app.backend.nlu.start();

if (args.nlu_model)
app.backend.nlu.start();
if (args.nlg_model && args.nlg_model !== args.nlu_model) {
app.backend.nlg = new LocalParserClient(args.nlg_model, args.locale, undefined, undefined, tpClient);
app.backend.nlg.start();
} else {
app.backend.nlg = app.backend.nlu;
}

if (args.translation_model) {
app.backend.translator = new LocalParserClient(args.translation_model, args.locale, undefined, undefined, tpClient);
app.backend.translator.start();
}

app.args = args;

app.set('port', args.port);
Expand All @@ -227,6 +286,10 @@ export async function execute(args : any) {
queryNLG(req.params, req.body, res).catch(next);
});

app.post('/:locale/translate', qv.validatePOST(Translation_PARAMS, { accept: 'application/json' }), (req, res, next) => {
queryTranslate(req.params, req.body, res).catch(next);
});

app.post('/:locale/tokenize', qv.validatePOST({ q: 'string', entities: '?object' }, { accept: 'application/json' }), (req, res, next) => {
tokenize(req.params, req.body, res).catch(next);
});
Expand All @@ -246,8 +309,11 @@ export async function execute(args : any) {
process.on('SIGTERM', resolve);
});

await app.backend.nlu.stop();
if (app.backend.nlu)
await app.backend.nlu.stop();
if (app.backend.nlg !== app.backend.nlu)
await app.backend.nlg.stop();
if (app.backend.translator)
await app.backend.translator.stop();
server.close();
}

0 comments on commit dbaa1a1

Please sign in to comment.