Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

first draft of OxApp with extension companion traits #157

Merged
merged 8 commits into from
Jul 9, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
143 changes: 143 additions & 0 deletions core/src/main/scala/ox/OxApp.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
package ox

import scala.util.boundary.*
import scala.util.control.NonFatal

enum ExitCode(val code: Int) {
case Success extends ExitCode(0)
case Failure(exitCode: Int = 1) extends ExitCode(exitCode)
}

trait OxApp {

import OxApp.AppSettings

protected def settings: AppSettings = AppSettings.defaults

final def main(args: Array[String]): Unit =
unsupervised {
val cancellableMainFork = forkCancellable(supervised(handleRun(args.toVector)))

val interruptThread = new Thread(() => {
cancellableMainFork.cancel()
()
})

interruptThread.setName("ox-interrupt-hook")

mountShutdownHook(interruptThread)

cancellableMainFork.joinEither() match
case Left(iex: InterruptedException) => exit(settings.gracefulShutdownExitCode)
case Left(fatalErr) => throw fatalErr
case Right(exitCode) => exit(exitCode)
}

/** For testing - trapping System.exit is impossible due to SecurityManager removal so it's just overrideable in tests.
*
* @param code
* Int exit code
*/
private[ox] def exit(exitCode: ExitCode): Unit = System.exit(exitCode.code)

/** For testing - allows to trigger shutdown hook without actually stopping the jvm.
*
* @param thread
* Thread
*/
private[ox] def mountShutdownHook(thread: Thread): Unit =
try Runtime.getRuntime.addShutdownHook(thread)
catch case _: IllegalStateException => ()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

shouldn't we re-throw the exception if mounding the hook fails?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

no, it's thrown only if we're already in shut down of the vm so there's nothing more to do

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

for reference: CE IOApp only shuts down it's thread pools when it gets ISE here


/** For testing - allows to capture the stack trace printed to the console
*
* @param t
* Throwable
*/
private[ox] def printStackTrace(t: Throwable): Unit = t.printStackTrace()

private[OxApp] final def handleRun(args: Vector[String])(using Ox): ExitCode =
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is Ox needed here?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ah we need it for run, ok

try run(args)
catch
case NonFatal(err) =>
printStackTrace(err)
ExitCode.Failure()

def run(args: Vector[String])(using Ox): ExitCode

}

object OxApp {

case class AppSettings(
/** This value is returned to the operating system as the exit code when the app receives SIGINT and shuts itself down gracefully.
* Default value is `ExitCode.Success` (0). JVM itself returns code `130` when it receives `SIGINT`.
*/
gracefulShutdownExitCode: ExitCode = ExitCode.Success

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I generally prefer to avoid default parameters for case classes - they are hard to change later. But doesn't seem likely to change.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, since we have the default instance let's move this there

)

object AppSettings {
lazy val defaults: AppSettings = AppSettings()
}

/** Simple variant of OxApp does not pass command line arguments and exits with exit code 0 if no exceptions were thrown.
*/
trait Simple extends OxApp {
override final def run(args: Vector[String])(using Ox): ExitCode =
run
ExitCode.Success

def run(using Ox): Unit
}

/** WithErrorMode variant of OxApp allows to specify what kind of error handling for the main function should be used. Base trait for
* integrations.
*
* @tparam E
* Error type
* @tparam F
* wrapper type for given ErrorMode
*/
trait WithErrorMode[E, F[_]](em: ErrorMode[E, F]) extends OxApp {
override final def run(args: Vector[String])(using ox: Ox): ExitCode =
val result = runWithErrors(args)
if em.isError(result) then handleError(em.getError(result))
else ExitCode.Success

/** Allows implementor of this trait to translate an error that app finished with into a concrete ExitCode.
*
* @param e
* E Error type
* @return
* ExitCode
*/
def handleError(e: E): ExitCode

/** This template method is to be implemented by abstract classes that add integration for particular error handling data structure of
* type F[_].
*
* @param args
* List[String]
* @return
* F[ExitCode]
*/
def runWithErrors(args: Vector[String])(using Ox): F[ExitCode]
}

/** WithEitherErrors variant of OxApp integrates OxApp with an `either` block and allows for usage of `.ok()` combinators in the body of
* the main function.
*
* @tparam E
* Error type
*/
abstract class WithEitherErrors[E] extends WithErrorMode(EitherMode[E]()) {

type EitherError[Err] = Label[Either[Err, ExitCode]]

override final def runWithErrors(args: Vector[String])(using ox: Ox): Either[E, ExitCode] =
either[E, ExitCode](label ?=> run(args)(using ox, label))

def run(args: Vector[String])(using Ox, EitherError[E]): ExitCode
}

}
255 changes: 255 additions & 0 deletions core/src/test/scala/ox/OxAppTest.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,255 @@
package ox

import org.scalatest.flatspec.AnyFlatSpec
import org.scalatest.matchers.should.Matchers
import ox.ExitCode.*

import java.io.{PrintWriter, StringWriter}
import java.util.concurrent.CountDownLatch
import scala.util.boundary.*
import scala.concurrent.duration.*

class OxAppTest extends AnyFlatSpec with Matchers:

"OxApp" should "work in happy case" in {
var ec = Int.MinValue

object Main1 extends OxApp:
override def exit(exitCode: ExitCode): Unit =
ec = exitCode.code

def run(args: Vector[String])(using Ox): ExitCode = Success

Main1.main(Array.empty)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we could also assert that the result is 0 here

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sure

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

but main is Unit, we check ec to be equal 0 below


ec shouldEqual 0
}

"OxApp" should "work in interrupted case" in {
var ec = Int.MinValue
val shutdownLatch = CountDownLatch(1)

object Main2 extends OxApp:
override private[ox] def mountShutdownHook(thread: Thread): Unit =
val damoclesThread = Thread(() => {
shutdownLatch.await()
thread.start()
thread.join()
})

damoclesThread.start()

override private[ox] def exit(exitCode: ExitCode): Unit =
ec = exitCode.code

def run(args: Vector[String])(using Ox): ExitCode =
forever: // will never finish
sleep(10.millis)

Success

supervised:
fork(Main2.main(Array.empty))
sleep(10.millis)
shutdownLatch.countDown()

ec shouldEqual 0
}

"OxApp" should "work in failed case" in {
var ec = Int.MinValue
var stackTrace = ""

object Main3 extends OxApp:
override def run(args: Vector[String])(using Ox): ExitCode =
Failure(23)

override private[ox] def exit(exitCode: ExitCode): Unit =
ec = exitCode.code

Main3.main(Array.empty)

ec shouldEqual 23

ec = Int.MinValue

object Main4 extends OxApp:
override def run(args: Vector[String])(using Ox): ExitCode =
throw Exception("oh no")

override private[ox] def printStackTrace(t: Throwable): Unit =
val sw = StringWriter()
val pw = PrintWriter(sw)
t.printStackTrace(pw)
stackTrace = sw.toString

override private[ox] def exit(exitCode: ExitCode): Unit =
ec = exitCode.code

Main4.main(Array.empty)

ec shouldEqual 1
assert(stackTrace.contains("oh no"))
}

"OxApp.Simple" should "work in happy case" in {
var ec = Int.MinValue

object Main5 extends OxApp.Simple:
override def exit(exitCode: ExitCode): Unit =
ec = exitCode.code

override def run(using Ox): Unit = ()

Main5.main(Array.empty)

ec shouldEqual 0
}

"OxApp.Simple" should "work in interrupted case" in {
var ec = Int.MinValue
val shutdownLatch = CountDownLatch(1)

object Main6 extends OxApp.Simple:
override private[ox] def mountShutdownHook(thread: Thread): Unit =
val damoclesThread = Thread(() => {
shutdownLatch.await()
thread.start()
thread.join()
})

damoclesThread.start()

override def exit(exitCode: ExitCode): Unit =
ec = exitCode.code

override def run(using Ox): Unit =
forever:
sleep(10.millis)

supervised:
fork(Main6.main(Array.empty))
sleep(10.millis)
shutdownLatch.countDown()

ec shouldEqual 0
}

"OxApp.Simple" should "work in failed case" in {
var ec = Int.MinValue
var stackTrace = ""

object Main7 extends OxApp.Simple:
override def run(using Ox): Unit = throw Exception("oh no")

override private[ox] def printStackTrace(t: Throwable): Unit =
val sw = StringWriter()
val pw = PrintWriter(sw)
t.printStackTrace(pw)
stackTrace = sw.toString

override private[ox] def exit(exitCode: ExitCode): Unit =
ec = exitCode.code

Main7.main(Array.empty)

ec shouldEqual 1
assert(stackTrace.contains("oh no"))
}

case class FunException(code: Int) extends Exception("")

import ox.either.*

"OxApp.WithErrors" should "work in happy case" in {
var ec = Int.MinValue
val errOrEc: Either[FunException, ExitCode] = Right(Success)

object Main8 extends OxApp.WithEitherErrors[FunException]:
override def exit(exitCode: ExitCode): Unit =
ec = exitCode.code

override def handleError(e: FunException): ExitCode = Failure(e.code)

override def run(args: Vector[String])(using Ox, EitherError[FunException]): ExitCode =
errOrEc.ok()

Main8.main(Array.empty)

ec shouldEqual 0
}

"OxApp.WithErrors" should "work in interrupted case" in {
var ec = Int.MinValue
val shutdownLatch = CountDownLatch(1)
val errOrEc: Either[FunException, ExitCode] = Left(FunException(23))

object Main9 extends OxApp.WithEitherErrors[FunException]:
override private[ox] def mountShutdownHook(thread: Thread): Unit =
val damoclesThread = Thread(() => {
shutdownLatch.await()
thread.start()
thread.join()
})

damoclesThread.start()

override def handleError(e: FunException): ExitCode = Failure(e.code)

override private[ox] def exit(exitCode: ExitCode): Unit =
ec = exitCode.code

override def run(args: Vector[String])(using Ox, EitherError[FunException]): ExitCode =
forever: // will never finish
sleep(10.millis)

errOrEc.ok()

supervised:
fork(Main9.main(Array.empty))
sleep(10.millis)
shutdownLatch.countDown()

ec shouldEqual 0
}

"OxApp.WithErrors" should "work in failed case" in {
var ec = Int.MinValue
val errOrEc: Either[FunException, ExitCode] = Left(FunException(23))
var stackTrace = ""

object Main10 extends OxApp.WithEitherErrors[FunException]:
override def run(args: Vector[String])(using Ox, EitherError[FunException]): ExitCode =
errOrEc.ok()

override private[ox] def exit(exitCode: ExitCode): Unit =
ec = exitCode.code

override def handleError(e: FunException): ExitCode = Failure(e.code)

Main10.main(Array.empty)

ec shouldEqual 23

ec = Int.MinValue

object Main11 extends OxApp.WithEitherErrors[FunException]:
override def run(args: Vector[String])(using Ox, EitherError[FunException]): ExitCode =
throw Exception("oh no")

override private[ox] def exit(exitCode: ExitCode): Unit =
ec = exitCode.code

override private[ox] def printStackTrace(t: Throwable): Unit =
val sw = StringWriter()
val pw = PrintWriter(sw)
t.printStackTrace(pw)
stackTrace = sw.toString

override def handleError(e: FunException): ExitCode = ??? // should not get called!

Main11.main(Array.empty)

ec shouldEqual 1
assert(stackTrace.contains("oh no"))
}
Loading
Loading