diff --git a/tony-core/src/main/java/com/linkedin/tony/ApplicationMaster.java b/tony-core/src/main/java/com/linkedin/tony/ApplicationMaster.java index 8985f475..a6116c23 100644 --- a/tony-core/src/main/java/com/linkedin/tony/ApplicationMaster.java +++ b/tony-core/src/main/java/com/linkedin/tony/ApplicationMaster.java @@ -650,6 +650,8 @@ private Set getUnregisteredTasks() { } private void stop() { + stopRunningContainers(); + FinalApplicationStatus status = session.getFinalStatus(); String appMessage = session.getFinalMessage(); try { @@ -658,8 +660,20 @@ private void stop() { LOG.error("Failed to unregister application", e); } - // stop remaining containers and give them time to finish so we can collect their task metrics and emit a - // TASK_FINISHED event + nmClientAsync.stop(); + amRMClient.stop(); + // Poll until TonyClient signals we should exit + boolean result = Utils.poll(() -> clientSignalToStop, 1, 15); + if (!result) { + LOG.warn("TonyClient didn't signal Tony AM to stop."); + } + } + + /** + * Stops any remaining running containers and gives them time to finish so we can collect their task metrics and emit + * a TASK_FINISHED event. + */ + private void stopRunningContainers() { List allContainers = sessionContainersMap.get(session.sessionId); if (allContainers != null) { for (Container container : allContainers) { @@ -673,16 +687,8 @@ private void stop() { // Give 15 seconds for containers to exit boolean result = Utils.poll(() -> session.getNumCompletedTasks() == session.getTotalTasks(), 1, 15); if (!result) { - LOG.warn("Not all containers were stopped or completed. Only " + session.getNumCompletedTasks() + " out of " + session.getTotalTasks() + " finished."); - } - - nmClientAsync.stop(); - amRMClient.waitForServiceToStop(5000); - amRMClient.stop(); - // Poll until TonyClient signals we should exit - result = Utils.poll(() -> clientSignalToStop, 1, 15); - if (!result) { - LOG.warn("TonyClient didn't signal Tony AM to stop."); + LOG.warn("Not all containers were stopped or completed. Only " + session.getNumCompletedTasks() + " out of " + + session.getTotalTasks() + " finished."); } } @@ -1056,7 +1062,7 @@ public void run() { task.setTaskInfo(container); TaskInfo taskInfo = task.getTaskInfo(); - taskInfo.setState(TaskStatus.READY); + taskInfo.setStatus(TaskStatus.READY); // Add job type specific resources Map containerResources = new ConcurrentHashMap<>(localResources); @@ -1121,7 +1127,7 @@ public void run() { Utils.printTaskUrl(task.getTaskInfo(), LOG); nmClientAsync.startContainerAsync(container, ctx); - taskInfo.setState(TaskStatus.RUNNING); + taskInfo.setStatus(TaskStatus.RUNNING); eventHandler.emitEvent(new Event(EventType.TASK_STARTED, new TaskStarted(task.getJobName(), Integer.parseInt(task.getTaskIndex()), container.getNodeHttpAddress().split(":")[0]), diff --git a/tony-core/src/main/java/com/linkedin/tony/TonyClient.java b/tony-core/src/main/java/com/linkedin/tony/TonyClient.java index c6e6952f..12132f24 100644 --- a/tony-core/src/main/java/com/linkedin/tony/TonyClient.java +++ b/tony-core/src/main/java/com/linkedin/tony/TonyClient.java @@ -129,6 +129,7 @@ public class TonyClient implements AutoCloseable { private Path appResourcesPath; private int hbInterval; private int maxHbMisses; + private boolean isTaskUrlsPrinted = false; private CallbackHandler callbackHandler; private CopyOnWriteArrayList listeners = new CopyOnWriteArrayList<>(); @@ -840,15 +841,13 @@ private ByteBuffer getTokens() throws IOException, URISyntaxException, YarnExcep /** * Monitor the submitted application for completion. * Kill application if time expires. - * @return true if application completed successfully + * @return true if application completed successfully and false otherwise * @throws YarnException * @throws java.io.IOException */ @VisibleForTesting - public boolean monitorApplication() - throws YarnException, IOException, InterruptedException { - - boolean isTaskUrlsPrinted = false; + public boolean monitorApplication() throws YarnException, IOException, InterruptedException { + boolean result; while (true) { // Check app status every 1 second. Thread.sleep(1000); @@ -856,47 +855,31 @@ public boolean monitorApplication() // Get application report for the appId we are interested in ApplicationReport report = yarnClient.getApplicationReport(appId); - YarnApplicationState state = report.getYarnApplicationState(); + YarnApplicationState appState = report.getYarnApplicationState(); - FinalApplicationStatus dsStatus = report.getFinalApplicationStatus(); + FinalApplicationStatus finalApplicationStatus = report.getFinalApplicationStatus(); initRpcClient(report); - if (amRpcClient != null) { - Set receivedInfos = amRpcClient.getTaskInfos(); - Set taskInfoDiff = receivedInfos.stream() - .filter(taskInfo -> !taskInfos.contains(taskInfo)) - .collect(Collectors.toSet()); - // If task status is changed, invoke callback for all listeners. - if (!taskInfoDiff.isEmpty()) { - for (TaskInfo taskInfo : taskInfoDiff) { - LOG.info("Tasks Status Updated: " + taskInfo); - } - for (TaskUpdateListener listener : listeners) { - listener.onTaskInfosUpdated(receivedInfos); - } - taskInfos = receivedInfos; - } + updateTaskInfos(); - // Query AM for taskInfos if taskInfos is empty. - if (amRpcServerInitialized && !isTaskUrlsPrinted) { - if (!taskInfos.isEmpty()) { - // Print TaskUrls - new TreeSet<>(taskInfos).forEach(task -> Utils.printTaskUrl(task, LOG)); - isTaskUrlsPrinted = true; - } - } + if (YarnApplicationState.KILLED == appState) { + LOG.warn("Application " + appId.getId() + " was killed. YarnState: " + appState + ". " + + "FinalApplicationStatus = " + finalApplicationStatus + "."); + // Set amRpcClient to null so client does not try to connect to a killed AM. + amRpcClient = null; + result = false; + break; } - if (YarnApplicationState.FINISHED == state || YarnApplicationState.FAILED == state - || YarnApplicationState.KILLED == state) { - LOG.info("Application " + appId.getId() + " finished with YarnState=" + state.toString() - + ", DSFinalStatus=" + dsStatus.toString() + ", breaking monitoring loop."); - // Set amRpcClient to null so client does not try to connect to it after completion. - amRpcClient = null; + if (YarnApplicationState.FINISHED == appState || YarnApplicationState.FAILED == appState) { + updateTaskInfos(); + LOG.info("Application " + appId.getId() + " finished with YarnState=" + appState + + ", DSFinalStatus=" + finalApplicationStatus + ", breaking monitoring loop."); String tonyPortalUrl = tonyConf.get(TonyConfigurationKeys.TONY_PORTAL_URL, TonyConfigurationKeys.DEFAULT_TONY_PORTAL_URL); Utils.printTonyPortalUrl(tonyPortalUrl, appId.toString(), LOG); - return FinalApplicationStatus.SUCCEEDED == dsStatus; + result = FinalApplicationStatus.SUCCEEDED == finalApplicationStatus; + break; } if (appTimeout > 0) { @@ -904,7 +887,44 @@ public boolean monitorApplication() LOG.info("Reached client specified timeout for application. Killing application" + ". Breaking monitoring loop : ApplicationId:" + appId.getId()); forceKillApplication(); - return false; + result = false; + break; + } + } + } + + if (amRpcClient != null) { + amRpcClient.finishApplication(); + LOG.info("Sent message to AM to stop."); + amRpcClient = null; + } + + return result; + } + + private void updateTaskInfos() throws IOException, YarnException { + if (amRpcClient != null) { + Set receivedInfos = amRpcClient.getTaskInfos(); + Set taskInfoDiff = receivedInfos.stream() + .filter(taskInfo -> !taskInfos.contains(taskInfo)) + .collect(Collectors.toSet()); + // If task status is changed, invoke callback for all listeners. + if (!taskInfoDiff.isEmpty()) { + for (TaskInfo taskInfo : taskInfoDiff) { + LOG.info("Task status updated: " + taskInfo); + } + for (TaskUpdateListener listener : listeners) { + listener.onTaskInfosUpdated(receivedInfos); + } + taskInfos = receivedInfos; + } + + // Query AM for taskInfos if taskInfos is empty. + if (amRpcServerInitialized && !isTaskUrlsPrinted) { + if (!taskInfos.isEmpty()) { + // Print TaskUrls + new TreeSet<>(taskInfos).forEach(task -> Utils.printTaskUrl(task, LOG)); + isTaskUrlsPrinted = true; } } } @@ -1066,7 +1086,7 @@ public void removeListener(TaskUpdateListener listener) { } public static void main(String[] args) { - int exitCode = 0; + int exitCode; // Adds hadoop-inject.xml as a default resource so Azkaban metadata will be present in the new Configuration created HadoopConfigurationInjector.injectResources(new Props() /* ignored */); @@ -1075,16 +1095,10 @@ public static void main(String[] args) { if (!sanityCheck) { LOG.fatal("Failed to init client."); exitCode = -1; - } - - if (exitCode == 0) { + } else { exitCode = client.start(); - if (client.amRpcClient != null) { - client.amRpcClient.finishApplication(); - LOG.info("Sent message to AM to stop."); - } } - } catch (ParseException | IOException | YarnException e) { + } catch (ParseException | IOException e) { LOG.fatal("Encountered exception while initializing client or finishing application.", e); exitCode = -1; } diff --git a/tony-core/src/main/java/com/linkedin/tony/client/TaskUpdateListener.java b/tony-core/src/main/java/com/linkedin/tony/client/TaskUpdateListener.java index 2a2db5e2..3c031c33 100644 --- a/tony-core/src/main/java/com/linkedin/tony/client/TaskUpdateListener.java +++ b/tony-core/src/main/java/com/linkedin/tony/client/TaskUpdateListener.java @@ -10,5 +10,5 @@ public interface TaskUpdateListener { // Called when TonyClient gets a set of taskUrls from TonyAM. - public void onTaskInfosUpdated(Set taskInfoSet); + void onTaskInfosUpdated(Set taskInfoSet); } diff --git a/tony-core/src/main/java/com/linkedin/tony/rpc/TaskInfo.java b/tony-core/src/main/java/com/linkedin/tony/rpc/TaskInfo.java index 3d4177d6..494fe76b 100644 --- a/tony-core/src/main/java/com/linkedin/tony/rpc/TaskInfo.java +++ b/tony-core/src/main/java/com/linkedin/tony/rpc/TaskInfo.java @@ -10,7 +10,7 @@ /** - * Contains the name, index, and URL for a task. + * Contains the name, index, URL, and status for a task. */ public class TaskInfo implements Comparable { private final String name; // The name (worker or ps) of the task @@ -24,7 +24,7 @@ public TaskInfo(String name, String index, String url) { this.url = url; } - public void setState(TaskStatus status) { + public void setStatus(TaskStatus status) { this.status = status; } @@ -75,7 +75,7 @@ public int hashCode() { @Override public String toString() { return String.format( - "[TaskInfo] name: %s index: %s url: %s status: %s", + "[TaskInfo] name: %s, index: %s, url: %s, status: %s", this.name, this.index, this.url, this.status.toString()); } } diff --git a/tony-core/src/main/java/com/linkedin/tony/tensorflow/TonySession.java b/tony-core/src/main/java/com/linkedin/tony/tensorflow/TonySession.java index 13e4dad8..d5ba697a 100644 --- a/tony-core/src/main/java/com/linkedin/tony/tensorflow/TonySession.java +++ b/tony-core/src/main/java/com/linkedin/tony/tensorflow/TonySession.java @@ -233,7 +233,7 @@ public Map> getClusterSpec() { } /** - * Refresh task status on each TaskExecutor registers its exit code with AM. + * Refresh task status when a TaskExecutor registers its exit code with AM. */ public void onTaskCompleted(String jobName, String jobIndex, int exitCode) { LOG.info(String.format("Job %s:%s exited with %d", jobName, jobIndex, exitCode)); @@ -430,17 +430,18 @@ synchronized int getExitStatus() { } synchronized void setExitStatus(int status) { + // Only set exit status if it hasn't been set yet if (exitStatus == -1) { this.exitStatus = status; switch (status) { case ContainerExitStatus.SUCCESS: - taskInfo.setState(TaskStatus.SUCCEEDED); + taskInfo.setStatus(TaskStatus.SUCCEEDED); break; case ContainerExitStatus.KILLED_BY_APPMASTER: - taskInfo.setState(TaskStatus.FINISHED); + taskInfo.setStatus(TaskStatus.FINISHED); break; default: - taskInfo.setState(TaskStatus.FAILED); + taskInfo.setStatus(TaskStatus.FAILED); break; } this.completed = true; diff --git a/tony-core/src/main/java/com/linkedin/tony/util/ProtoUtils.java b/tony-core/src/main/java/com/linkedin/tony/util/ProtoUtils.java index 9a768574..88322cab 100644 --- a/tony-core/src/main/java/com/linkedin/tony/util/ProtoUtils.java +++ b/tony-core/src/main/java/com/linkedin/tony/util/ProtoUtils.java @@ -12,7 +12,7 @@ public class ProtoUtils { public static TaskInfo taskInfoProtoToTaskInfo(TaskInfoProto taskInfoProto) { TaskInfo taskInfo = new TaskInfo(taskInfoProto.getName(), taskInfoProto.getIndex(), taskInfoProto.getUrl()); - taskInfo.setState(TaskStatus.values()[taskInfoProto.getTaskStatus().ordinal()]); + taskInfo.setStatus(TaskStatus.values()[taskInfoProto.getTaskStatus().ordinal()]); return taskInfo; }