diff --git a/pom.xml b/pom.xml index 99ab099..208a13a 100644 --- a/pom.xml +++ b/pom.xml @@ -37,6 +37,12 @@ 4.8.2 test + + org.mockito + mockito-all + 1.9.0 + test + org.scala-tools.testing specs_${scala.version} diff --git a/src/main/scala/scalang/epmd/Epmd.scala b/src/main/scala/scalang/epmd/Epmd.scala index df994c6..90580d2 100644 --- a/src/main/scala/scalang/epmd/Epmd.scala +++ b/src/main/scala/scalang/epmd/Epmd.scala @@ -21,30 +21,81 @@ import netty.bootstrap._ import netty.channel._ import socket.nio._ import overlock.threadpool._ +import com.codahale.logula.Logging -object Epmd { +case class EpmdConfig( + host: String, + port: Int, + connectOnInit: Boolean = true, + retries: Option[Int] = Some(Epmd.defaultRetries), + retryInterval: Option[Double] = None, + connectionTimeout: Option[Int] = None +) + +object Epmd extends Logging { + private val futurePollingInterval = 100 /* nanos */ val defaultPort = 4369 + val defaultRetries = 5 + val defaultRetryInterval = 2.0 /*seconds*/ + lazy val bossPool = ThreadPool.instrumentedElastic("scalang.epmd", "boss", 1, 20) lazy val workerPool = ThreadPool.instrumentedElastic("scalang.epmd", "worker", 1, 20) - def apply(host : String) : Epmd = { - val port = Option(System.getenv("ERL_EPMD_PORT")).map(_.toInt).getOrElse(defaultPort) - new Epmd(host, port) + def apply(host: String, port: Option[Int] = None): Epmd = { + val epmdPort = port match { + case Some(p) => p + case None => Option(System.getenv("ERL_EPMD_PORT")).map(_.toInt).getOrElse(defaultPort) + } + Epmd(new EpmdConfig(host, epmdPort)) } - def apply(host : String, port : Int) : Epmd = { - new Epmd(host, port) + def apply(cfg: EpmdConfig): Epmd = { + val epmd = new Epmd(cfg.host, cfg.port, cfg.connectionTimeout) + if (cfg.connectOnInit) { + connectWithRetries(epmd, cfg) + } + epmd + } + + def connectWithRetries(epmd: Epmd, cfg: EpmdConfig) { + var future = epmd.connect + while (!future.isDone) { + Thread.sleep(futurePollingInterval) + } + + if (cfg.retries.isDefined) { + val retries = cfg.retries.get + + val retryInterval = cfg.retryInterval.getOrElse(defaultRetryInterval) + val retryIntervalMillis = (retryInterval * 1000.0).toInt + + var numRetries = 0 + while (!epmd.connected && numRetries < retries) { + // Retry the connection + if (!future.isSuccess) { + log.warn("epmd connection failed. Retrying in %.1f seconds", retryInterval) + Thread.sleep(retryIntervalMillis) + future = epmd.connect + numRetries += 1 + } + // Poll the future + while (!future.isDone) { + Thread.sleep(futurePollingInterval) + } + } + } } } -class Epmd(val host : String, val port : Int) { +class Epmd(val host : String, val port : Int, val defaultTimeout: Option[Int] = None) extends Logging { + var channel: Channel = null + val handler = new EpmdHandler + val bootstrap = new ClientBootstrap( new NioClientSocketChannelFactory( Epmd.bossPool, Epmd.workerPool)) - val handler = new EpmdHandler - bootstrap.setPipelineFactory(new ChannelPipelineFactory { def getPipeline : ChannelPipeline = { Channels.pipeline( @@ -53,29 +104,63 @@ class Epmd(val host : String, val port : Int) { handler) } }) + setTimeout(defaultTimeout) + + + def setTimeout(timeout: Option[Int]) { + if (timeout.isDefined) { + bootstrap.setOption("connectTimeoutMillis", timeout.get * 1000) + } + } + + def connect: ChannelFuture = { + val connectFuture = bootstrap.connect(new InetSocketAddress(host, port)) + connectFuture.addListener(new ChannelFutureListener { + def operationComplete(future: ChannelFuture) { + if (!future.isSuccess) { + log.error(future.getCause, "Failed to connect to epmd on %s:%s", host, port) + } else { + channel = future.getChannel + } + } + }) + connectFuture + } - val connectFuture = bootstrap.connect(new InetSocketAddress(host, port)) - val channel = connectFuture.awaitUninterruptibly.getChannel - if(!connectFuture.isSuccess) { - throw connectFuture.getCause + def connectBlocking: Epmd = { + val connectFuture = bootstrap.connect(new InetSocketAddress(host, port)) + channel = connectFuture.awaitUninterruptibly.getChannel + this } def close { channel.close } + def connected = (channel != null) + def alive(portNo : Int, nodeName : String) : Option[Int] = { + if (!connected) { + log.error("'alive(%s, %s)' called before Epmd connected!", portNo, nodeName) + return None + } + channel.write(AliveReq(portNo,nodeName)) val response = handler.response.call.asInstanceOf[AliveResp] if (response.result == 0) { Some(response.creation) } else { - error("Epmd response was: " + response.result) + log.error("Epmd response was: " + response.result) None } } def lookupPort(nodeName : String) : Option[Int] = { + if(!connected) { + log.error("'lookupPort(%s)' called before Epmd connected!", nodeName) + return None + } + channel.write(PortPleaseReq(nodeName)) handler.response.call match { case PortPleaseResp(portNo, _) => Some(portNo) @@ -84,3 +169,4 @@ class Epmd(val host : String, val port : Int) { } } + diff --git a/src/main/scala/scalang/epmd/EpmdMessages.scala b/src/main/scala/scalang/epmd/EpmdMessages.scala index 6bbed07..becf9e6 100644 --- a/src/main/scala/scalang/epmd/EpmdMessages.scala +++ b/src/main/scala/scalang/epmd/EpmdMessages.scala @@ -15,10 +15,12 @@ // package scalang.epmd +// Requests: case class AliveReq(portNo : Int, nodeName : String) case class AliveResp(result : Int, creation : Int) +// Responses: case class PortPleaseReq(nodeName : String) case class PortPleaseError(result : Int) diff --git a/src/test/scala/scalang/TestHelper.scala b/src/test/scala/scalang/TestHelper.scala index 97f54f3..63075e6 100644 --- a/src/test/scala/scalang/TestHelper.scala +++ b/src/test/scala/scalang/TestHelper.scala @@ -29,9 +29,18 @@ object Escript { object EpmdCmd { def apply() : SysProcess = { - val builder = new ProcessBuilder("epmd") + val osName = System.getProperty("os.name").toLowerCase + var builder : ProcessBuilder = null + if (!osName.contains("win")) { + builder = new ProcessBuilder("bash", "-c", "export PATH=" + formatPath + " && epmd") + } else { + builder = new ProcessBuilder("epmd") + } builder.start } + + val additionalPaths = List("/usr/local/bin", "/usr/local/sbin") + def formatPath: String = additionalPaths.mkString(":") + ":" + System.getenv("PATH") } object ReadLine { diff --git a/src/test/scala/scalang/epmd/EpmdSpec.scala b/src/test/scala/scalang/epmd/EpmdSpec.scala index 3fc143d..1e0eafa 100644 --- a/src/test/scala/scalang/epmd/EpmdSpec.scala +++ b/src/test/scala/scalang/epmd/EpmdSpec.scala @@ -1,11 +1,12 @@ package scalang.epmd import org.specs._ -import org.specs.runner._ +import mock.Mockito import java.lang.{Process => SysProcess} import scalang._ +import org.jboss.netty.channel.ChannelFuture -class EpmdSpec extends SpecificationWithJUnit { +class EpmdSpec extends SpecificationWithJUnit with Mockito { "Epmd" should { var proc : SysProcess = null doBefore { @@ -36,4 +37,47 @@ class EpmdSpec extends SpecificationWithJUnit { epmdQuery.close } } + + "Epmd object" should { + "return an Epmd directly" in { + val noConnectConfig = new EpmdConfig("localhost", Epmd.defaultPort, connectOnInit = false) + val epmd = Epmd(noConnectConfig) + epmd.connected must(be(false)) + } + + "connect with retries" in { + val epmd = mock[Epmd] + val future = mock[ChannelFuture] + epmd.connect.returns(future) + + // Always returns true so that it's never polled + future.isDone.returns(true) + + // 'connected' is checked before 'future.isSuccess' since there should be a polling step + // that's being skipped in each test. + epmd.connected + .returns(false) + .thenReturns(false) + .thenReturns(false) + .thenReturns(false) + .thenReturns(true) + future.isSuccess + .returns(false) + .thenReturns(false) + .thenReturns(false) + .thenReturns(true) + + val retryCfg = new EpmdConfig( + "localhost", + Epmd.defaultPort, + connectOnInit = true, + connectionTimeout = Some(1), + retries = Some(10), + retryInterval = Some(1) + ) + Epmd.connectWithRetries(epmd, retryCfg) + + there was 4.times(epmd).connect + } + } }