diff --git a/pom.xml b/pom.xml
index 99ab099..208a13a 100644
--- a/pom.xml
+++ b/pom.xml
@@ -37,6 +37,12 @@
4.8.2
test
+
+ org.mockito
+ mockito-all
+ 1.9.0
+ test
+
org.scala-tools.testing
specs_${scala.version}
diff --git a/src/main/scala/scalang/epmd/Epmd.scala b/src/main/scala/scalang/epmd/Epmd.scala
index df994c6..90580d2 100644
--- a/src/main/scala/scalang/epmd/Epmd.scala
+++ b/src/main/scala/scalang/epmd/Epmd.scala
@@ -21,30 +21,81 @@ import netty.bootstrap._
import netty.channel._
import socket.nio._
import overlock.threadpool._
+import com.codahale.logula.Logging
-object Epmd {
+case class EpmdConfig(
+ host: String,
+ port: Int,
+ connectOnInit: Boolean = true,
+ retries: Option[Int] = Some(Epmd.defaultRetries),
+ retryInterval: Option[Double] = None,
+ connectionTimeout: Option[Int] = None
+)
+
+object Epmd extends Logging {
+ private val futurePollingInterval = 100 /* nanos */
val defaultPort = 4369
+ val defaultRetries = 5
+ val defaultRetryInterval = 2.0 /*seconds*/
+
lazy val bossPool = ThreadPool.instrumentedElastic("scalang.epmd", "boss", 1, 20)
lazy val workerPool = ThreadPool.instrumentedElastic("scalang.epmd", "worker", 1, 20)
- def apply(host : String) : Epmd = {
- val port = Option(System.getenv("ERL_EPMD_PORT")).map(_.toInt).getOrElse(defaultPort)
- new Epmd(host, port)
+ def apply(host: String, port: Option[Int] = None): Epmd = {
+ val epmdPort = port match {
+ case Some(p) => p
+ case None => Option(System.getenv("ERL_EPMD_PORT")).map(_.toInt).getOrElse(defaultPort)
+ }
+ Epmd(new EpmdConfig(host, epmdPort))
}
- def apply(host : String, port : Int) : Epmd = {
- new Epmd(host, port)
+ def apply(cfg: EpmdConfig): Epmd = {
+ val epmd = new Epmd(cfg.host, cfg.port, cfg.connectionTimeout)
+ if (cfg.connectOnInit) {
+ connectWithRetries(epmd, cfg)
+ }
+ epmd
+ }
+
+ def connectWithRetries(epmd: Epmd, cfg: EpmdConfig) {
+ var future = epmd.connect
+ while (!future.isDone) {
+ Thread.sleep(futurePollingInterval)
+ }
+
+ if (cfg.retries.isDefined) {
+ val retries = cfg.retries.get
+
+ val retryInterval = cfg.retryInterval.getOrElse(defaultRetryInterval)
+ val retryIntervalMillis = (retryInterval * 1000.0).toInt
+
+ var numRetries = 0
+ while (!epmd.connected && numRetries < retries) {
+ // Retry the connection
+ if (!future.isSuccess) {
+ log.warn("epmd connection failed. Retrying in %.1f seconds", retryInterval)
+ Thread.sleep(retryIntervalMillis)
+ future = epmd.connect
+ numRetries += 1
+ }
+ // Poll the future
+ while (!future.isDone) {
+ Thread.sleep(futurePollingInterval)
+ }
+ }
+ }
}
}
-class Epmd(val host : String, val port : Int) {
+class Epmd(val host : String, val port : Int, val defaultTimeout: Option[Int] = None) extends Logging {
+ var channel: Channel = null
+ val handler = new EpmdHandler
+
val bootstrap = new ClientBootstrap(
new NioClientSocketChannelFactory(
Epmd.bossPool,
Epmd.workerPool))
- val handler = new EpmdHandler
-
bootstrap.setPipelineFactory(new ChannelPipelineFactory {
def getPipeline : ChannelPipeline = {
Channels.pipeline(
@@ -53,29 +104,63 @@ class Epmd(val host : String, val port : Int) {
handler)
}
})
+ setTimeout(defaultTimeout)
+
+
+ def setTimeout(timeout: Option[Int]) {
+ if (timeout.isDefined) {
+ bootstrap.setOption("connectTimeoutMillis", timeout.get * 1000)
+ }
+ }
+
+ def connect: ChannelFuture = {
+ val connectFuture = bootstrap.connect(new InetSocketAddress(host, port))
+ connectFuture.addListener(new ChannelFutureListener {
+ def operationComplete(future: ChannelFuture) {
+ if (!future.isSuccess) {
+ log.error(future.getCause, "Failed to connect to epmd on %s:%s", host, port)
+ } else {
+ channel = future.getChannel
+ }
+ }
+ })
+ connectFuture
+ }
- val connectFuture = bootstrap.connect(new InetSocketAddress(host, port))
- val channel = connectFuture.awaitUninterruptibly.getChannel
- if(!connectFuture.isSuccess) {
- throw connectFuture.getCause
+ def connectBlocking: Epmd = {
+ val connectFuture = bootstrap.connect(new InetSocketAddress(host, port))
+ channel = connectFuture.awaitUninterruptibly.getChannel
+ this
}
def close {
channel.close
}
+ def connected = (channel != null)
+
def alive(portNo : Int, nodeName : String) : Option[Int] = {
+ if (!connected) {
+ log.error("'alive(%s, %s)' called before Epmd connected!", portNo, nodeName)
+ return None
+ }
+
channel.write(AliveReq(portNo,nodeName))
val response = handler.response.call.asInstanceOf[AliveResp]
if (response.result == 0) {
Some(response.creation)
} else {
- error("Epmd response was: " + response.result)
+ log.error("Epmd response was: " + response.result)
None
}
}
def lookupPort(nodeName : String) : Option[Int] = {
+ if(!connected) {
+ log.error("'lookupPort(%s)' called before Epmd connected!", nodeName)
+ return None
+ }
+
channel.write(PortPleaseReq(nodeName))
handler.response.call match {
case PortPleaseResp(portNo, _) => Some(portNo)
@@ -84,3 +169,4 @@ class Epmd(val host : String, val port : Int) {
}
}
+
diff --git a/src/main/scala/scalang/epmd/EpmdMessages.scala b/src/main/scala/scalang/epmd/EpmdMessages.scala
index 6bbed07..becf9e6 100644
--- a/src/main/scala/scalang/epmd/EpmdMessages.scala
+++ b/src/main/scala/scalang/epmd/EpmdMessages.scala
@@ -15,10 +15,12 @@
//
package scalang.epmd
+// Requests:
case class AliveReq(portNo : Int, nodeName : String)
case class AliveResp(result : Int, creation : Int)
+// Responses:
case class PortPleaseReq(nodeName : String)
case class PortPleaseError(result : Int)
diff --git a/src/test/scala/scalang/TestHelper.scala b/src/test/scala/scalang/TestHelper.scala
index 97f54f3..63075e6 100644
--- a/src/test/scala/scalang/TestHelper.scala
+++ b/src/test/scala/scalang/TestHelper.scala
@@ -29,9 +29,18 @@ object Escript {
object EpmdCmd {
def apply() : SysProcess = {
- val builder = new ProcessBuilder("epmd")
+ val osName = System.getProperty("os.name").toLowerCase
+ var builder : ProcessBuilder = null
+ if (!osName.contains("win")) {
+ builder = new ProcessBuilder("bash", "-c", "export PATH=" + formatPath + " && epmd")
+ } else {
+ builder = new ProcessBuilder("epmd")
+ }
builder.start
}
+
+ val additionalPaths = List("/usr/local/bin", "/usr/local/sbin")
+ def formatPath: String = additionalPaths.mkString(":") + ":" + System.getenv("PATH")
}
object ReadLine {
diff --git a/src/test/scala/scalang/epmd/EpmdSpec.scala b/src/test/scala/scalang/epmd/EpmdSpec.scala
index 3fc143d..1e0eafa 100644
--- a/src/test/scala/scalang/epmd/EpmdSpec.scala
+++ b/src/test/scala/scalang/epmd/EpmdSpec.scala
@@ -1,11 +1,12 @@
package scalang.epmd
import org.specs._
-import org.specs.runner._
+import mock.Mockito
import java.lang.{Process => SysProcess}
import scalang._
+import org.jboss.netty.channel.ChannelFuture
-class EpmdSpec extends SpecificationWithJUnit {
+class EpmdSpec extends SpecificationWithJUnit with Mockito {
"Epmd" should {
var proc : SysProcess = null
doBefore {
@@ -36,4 +37,47 @@ class EpmdSpec extends SpecificationWithJUnit {
epmdQuery.close
}
}
+
+ "Epmd object" should {
+ "return an Epmd directly" in {
+ val noConnectConfig = new EpmdConfig("localhost", Epmd.defaultPort, connectOnInit = false)
+ val epmd = Epmd(noConnectConfig)
+ epmd.connected must(be(false))
+ }
+
+ "connect with retries" in {
+ val epmd = mock[Epmd]
+ val future = mock[ChannelFuture]
+ epmd.connect.returns(future)
+
+ // Always returns true so that it's never polled
+ future.isDone.returns(true)
+
+ // 'connected' is checked before 'future.isSuccess' since there should be a polling step
+ // that's being skipped in each test.
+ epmd.connected
+ .returns(false)
+ .thenReturns(false)
+ .thenReturns(false)
+ .thenReturns(false)
+ .thenReturns(true)
+ future.isSuccess
+ .returns(false)
+ .thenReturns(false)
+ .thenReturns(false)
+ .thenReturns(true)
+
+ val retryCfg = new EpmdConfig(
+ "localhost",
+ Epmd.defaultPort,
+ connectOnInit = true,
+ connectionTimeout = Some(1),
+ retries = Some(10),
+ retryInterval = Some(1)
+ )
+ Epmd.connectWithRetries(epmd, retryCfg)
+
+ there was 4.times(epmd).connect
+ }
+ }
}