Immutable Session State in Scala

by simbo1905

In the servlet world the HttpSession object is a workhorse which few developers could live without. Recently I have been taking a hard look at Socko, a minimal webserver which does not come with a session object.

“What!?” I hear you cry. “Why would you use a webserver on the JVM that forgot to implement the J2EE standard HttpSession?” Well, Socko is a fresh look at what a JVM webserver needs to be in the age of REST, WebSockets and Actors. If you need to write an Actor back-end exposed as WebSockets to a single-page/mobile app, do you really need an HttpSession as a separate concept? Are Actors not sufficient to encapsulate user session state?

Furthermore, wasn’t HTTP supposed to be stateless? Isn’t clustering the HttpSession using proprietary JEE features the toothache in scaling out servlets? Wouldn’t you rather have Let-It-Crash Actor error handling, clustering and failover for both business services and user sessions?

If you go along with me so far then you may agree that we should be using an Actor for the user’s session state. Scaling that out across a cluster, or making it durable and fault tolerant, we can leave to deploy-time actor configuration. Stick a mutable map into a UserSessionActor and “job done”.

Not so fast! Let’s peek over the fence back at JEE and see what HttpSession gave us:

  1. Thread safe access. No problem. Our UserSessionActor will give us that.
  2. Deploy time clustering and fault tolerance. No problem. Akka (beyond the core jar) can give us that.
  3. Session timeout. Hmm. Not really a problem? To reduce the memory footprint, our UserSessionActor will have to schedule a message to itself 2 minutes after the last seen user activity, and stop itself when that message arrives only if no further user activity has happened in the meantime.

Wow, that’s a lot of CPU cycles and complexity for something as simple as a session timeout. Also, a mutable map is a bad idea; if the reference escapes the Actor in an outbound message you will introduce a concurrency bug.
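To make the escaping-reference problem concrete, here is a minimal sketch in plain Scala with no Akka involved. The names MutableHolder and ImmutableHolder are hypothetical stand-ins for an actor's private state, not anything from Socko or Akka. The demo is single threaded, so it only shows the aliasing; in a real actor system the receiver of the message would be running on another thread, which is where the concurrency bug comes from.

```scala
import scala.collection.mutable

object EscapeDemo {
  // Hypothetical stand-in for an actor that keeps its state in a mutable map.
  final class MutableHolder {
    private val sessions = mutable.Map[String, String]()
    def put(k: String, v: String): Unit = sessions.update(k, v)
    // Problem: this hands out the live map, not a copy.
    def snapshot: mutable.Map[String, String] = sessions
    def size: Int = sessions.size
  }

  // Same idea, but the state is an immutable Map behind a private var.
  final class ImmutableHolder {
    private var sessions = Map.empty[String, String]
    def put(k: String, v: String): Unit = { sessions = sessions + (k -> v) }
    // Safe to hand out: nobody can change this value in place.
    def snapshot: Map[String, String] = sessions
    def size: Int = sessions.size
  }

  // The receiver clears the escaped map and the holder loses its sessions.
  def leaked: Int = {
    val holder = new MutableHolder
    holder.put("alice", "token-1")
    val escaped = holder.snapshot
    escaped.clear()
    holder.size
  }

  // The receiver can only derive a new map; the holder is unaffected.
  // Returns (size seen by the holder, size of the receiver's derived map).
  def safe: (Int, Int) = {
    val holder = new ImmutableHolder
    holder.put("alice", "token-1")
    val escaped = holder.snapshot
    val changed = escaped - "alice"
    (holder.size, changed.size)
  }
}
```

With the mutable version the holder ends up with no sessions even though it never removed one itself; with the immutable version the holder still sees its one entry.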

What seems a much better idea is an immutable SessionState object which bumps the session expiry for a given key as it is accessed. During writes it can purge sessions which have timed out to keep the memory footprint down.

Below is my take on such a data structure, adapted from Jamie Pullar’s blog.

import scala.util.Try
import scala.collection.immutable.Vector
import scala.collection.immutable.Map
import scala.util.Failure
import scala.util.Success
import scala.concurrent.duration.FiniteDuration

/**
 * SessionState[K, V] is an immutable key-value collection where
 * entries expire if they are not accessed or refreshed within
 * a given period.
 * Typically the three operators "+", "-" and "apply" manipulate
 * or access the data structure.
 * Adapted with changes from Jamie Pullar's SessionState at
 * http://higher-state.blogspot.co.uk/2013/02/session-state-in-scala1-immutable.html
 */
trait SessionState[K, V] {
  /**
   * Expiry time of entries in milliseconds
   * @return
   */
  def expiryInterval: Int

  /**
   * Returns the corresponding value for the key.
   * Does not refresh the session expiry time. Consider using apply
   * instead to increase the session time with mySessionState(someKey).
   * @param key
   * @param datetime current time in milliseconds
   * @return
   */
  def getValueNoRefresh(key: K)(implicit datetime: Long): Option[V]

  /**
   * Returns the corresponding value for the key and rejuvenates the
   * session, increasing its expiry time to (datetime + expiryInterval).
   * The "apply" method delegates to this method so that you can
   * invoke it as mySessionState(someKey).
   * @param key
   * @param datetime current datetime in milliseconds
   * @return the updated SessionState and the value, if present and unexpired
   */
  def getValueWithRefresh(key: K)(implicit datetime: Long): (SessionState[K, V], Option[V])

  /**
   * Shorthand for getValueWithRefresh so that you can invoke it as
   * mySessionState(someKey). Returns the corresponding value for the
   * key and rejuvenates the session, increasing its expiry time to
   * (datetime + expiryInterval).
   * @param sessionKey
   * @param datetime current datetime in milliseconds
   * @return the updated SessionState and the value, if present and unexpired
   */
  def apply(sessionKey: K)(implicit datetime: Long): (SessionState[K, V], Option[V]) =
    getValueWithRefresh(sessionKey)

  /**
   * Adds a session value under a session key.
   * Does rejuvenate the session expiry time
   * @param key
   * @param value
   * @param datetime current datetime in milliseconds
   * @return New SessionState if successful, else SessionAlreadyExistsException
   */
  def put(key: K, value: V)(implicit datetime: Long): Try[SessionState[K, V]]

  /**
   * Operator alias which invokes put(key: K, value: V)
   * @param sessionKey
   * @param value
   * @param datetime current datetime in milliseconds
   * @return New SessionState if successful, else SessionAlreadyExistsException
   */
  def +(sessionKey: K, value: V)(implicit datetime: Long): Try[SessionState[K, V]] =
    put(sessionKey, value)

  /**
   * Removes the session value if found
   * @param sessionKey
   * @param datetime current datetime in milliseconds
   * @return New SessionState with the session key removed
   */
  def expire(sessionKey: K)(implicit datetime: Long): SessionState[K, V]

  /**
   * Operator alias for expire. Removes the session value if found
   * @param sessionKey
   * @param datetime current datetime in milliseconds
   * @return New SessionState with the session key removed
   */
  def -(sessionKey: K)(implicit datetime: Long): SessionState[K, V] =
    expire(sessionKey)
}

object SessionState {
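  // note: toInt silently overflows for durations above Int.MaxValue
  // milliseconds (roughly 24 days), so keep session durations short
  def apply[K, V](duration: FiniteDuration): SessionState[K, V] =
    SessionStateInstance[K, V](duration.toMillis.toInt, Vector.empty, Map.empty)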
}

private case class SessionStateInstance[K, V](expiryInterval: Int,
  sessionVector: Vector[(K, Long)], valuesWithExpiryMap: Map[K, (V, Long)])
  extends SessionState[K, V] {

  // vanilla access no refresh of expiry
  def getValueNoRefresh(sessionKey: K)(implicit datetime: Long): Option[V] =
    valuesWithExpiryMap.get(sessionKey) collect {
      case (value, expiry) if (expiry > datetime) => value
    }

  // gets the value and bumps the expiry by interval
  def getValueWithRefresh(sessionKey: K)(implicit datetime: Long): (SessionState[K, V], Option[V]) = {
    valuesWithExpiryMap.get(sessionKey) collect {
      case (value, expiry) if (expiry > datetime) =>
        (SessionStateInstance(this.expiryInterval, sessionVector,
          valuesWithExpiryMap + (sessionKey ->
            (value, datetime + this.expiryInterval))), Some(value))
    } getOrElse {
      (this, None)
    }
  }

  // puts a value into the session and expires any old session keys
  def put(sessionKey: K, value: V)(implicit datetime: Long): Try[SessionState[K, V]] =
    valuesWithExpiryMap.get(sessionKey) collect {
      // an unexpired entry already exists; "_" avoids shadowing the value parameter
      case (_, expiry) if (expiry > datetime) =>
        Failure(SessionAlreadyExistsException)
    } getOrElse {
      val cleared = clearedExpiredSessions(datetime)
      Success(SessionStateInstance(this.expiryInterval,
        cleared.sessionVector :+ (sessionKey, datetime + this.expiryInterval),
        cleared.valuesWithExpiryMap + (sessionKey ->
          (value, datetime + this.expiryInterval))))
    }

  // fast delete can leave the queue and map out of sync which will be fixed up on next clear operation
  def expire(sessionKey: K)(implicit datetime: Long): SessionState[K, V] = {
    val cleared = clearedExpiredSessions(datetime)
    if (cleared.valuesWithExpiryMap.contains(sessionKey))
      SessionStateInstance(this.expiryInterval,
        cleared.sessionVector,
        cleared.valuesWithExpiryMap - sessionKey)
    else cleared
  }

  // used for unit testing
  def size = {
    valuesWithExpiryMap.size
  }

  private def clearedExpiredSessions(datetime: Long): SessionStateInstance[K, V] =
    clearedExpiredSessions(datetime, sessionVector, valuesWithExpiryMap)

  // forward scans the vector dropping expired values and puts it back
  // into the same state as the updated map. halts when it finds unexpired
  // values
  private def clearedExpiredSessions(now: Long, sessionQueue: Vector[(K, Long)], valuesWithExpiryMap: Map[K, (V, Long)]): SessionStateInstance[K, V] = {
    sessionQueue.headOption collect {
      // if we have an expired key
      case (key, end) if (end < now) =>
        // double check with map
        valuesWithExpiryMap.get(key) map {
          case (_, expiry) if (expiry < now) =>
            // drop the expired value
            clearedExpiredSessions(now, sessionQueue.drop(1), valuesWithExpiryMap - key)
          case (_, expiry) =>
            // out of order extended session forces a sort
            clearedExpiredSessions(now, (sessionQueue.drop(1) :+ (key, expiry)).sortBy(_._2), valuesWithExpiryMap)
        } getOrElse {
          // if it was not found in the map drop it from the front and check the next value
          clearedExpiredSessions(now, sessionQueue.drop(1), valuesWithExpiryMap)
        }
    } getOrElse {
      // no more expired keys at the front of the vector stop recursion
      SessionStateInstance(this.expiryInterval, sessionQueue, valuesWithExpiryMap)
    }
  }
}

case object SessionAlreadyExistsException extends Exception("Session key already exists")

case object SessionNotFound extends Exception("Session key not found")

Access is fast as it only uses the map, and deletion only removes the key from the map, although in the code above it runs the purge first. The neat thing about Jamie’s algorithm is the lazy purge of expired values, which is run when new values are put into the map. The “put” appends the expiry to the back of the vector. The purge drops expired values from the front of the vector and deletes the entry from the map if it is still there and has really expired. If an “apply” has extended the expiry of the entry due to user access then the purge drops the old expiry from the front and re-adds the key with the correct expiry. As long as writes keep triggering the purge, any expired entries will eventually be deleted within twice the expiry interval. This is done without taking the cost of maintaining exact ordering to purge immediately.
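To see the lazy purge at work on a small timeline, here is a cut-down sketch of the same idea. It is not the listing above: Expiring, touch and purge are hypothetical names, time is passed explicitly rather than implicitly, and put simply overwrites instead of failing with SessionAlreadyExistsException. The purge logic follows the same three cases as clearedExpiredSessions.

```scala
object LazyPurgeSketch {
  // queue: (key, expiry recorded at insert), oldest at the front.
  // entries: key -> (value, current expiry), bumped on access.
  final case class Expiring[K, V](
      interval: Long,
      queue: Vector[(K, Long)] = Vector.empty[(K, Long)],
      entries: Map[K, (V, Long)] = Map.empty[K, (V, Long)]) {

    // Read a live value and bump its expiry in the map only;
    // the queue keeps the stale expiry until the next purge.
    def touch(key: K, now: Long): (Expiring[K, V], Option[V]) =
      entries.get(key) match {
        case Some((v, expiry)) if expiry > now =>
          (copy(entries = entries + (key -> ((v, now + interval)))), Some(v))
        case _ => (this, None)
      }

    // Purge first, then append the new expiry to the back of the queue.
    def put(key: K, value: V, now: Long): Expiring[K, V] = {
      val p = purge(now)
      p.copy(queue = p.queue :+ ((key, now + interval)),
             entries = p.entries + (key -> ((value, now + interval))))
    }

    // Scan from the front, stopping at the first unexpired queue entry.
    def purge(now: Long): Expiring[K, V] = queue.headOption match {
      case Some((key, end)) if end < now =>
        entries.get(key) match {
          case Some((_, expiry)) if expiry < now =>
            // really expired: drop from both queue and map
            copy(queue = queue.tail, entries = entries - key).purge(now)
          case Some((_, expiry)) =>
            // expiry was extended by a touch: re-queue with the real expiry
            copy(queue = (queue.tail :+ ((key, expiry))).sortBy(_._2)).purge(now)
          case None =>
            // already deleted from the map: drop the stale queue entry
            copy(queue = queue.tail).purge(now)
        }
      case _ => this
    }
  }

  // Timeline with a 100 ms expiry interval.
  val s0 = Expiring[String, String](interval = 100L)
  val s1 = s0.put("a", "A", now = 0L).put("b", "B", now = 10L)
  val (s2, seen) = s1.touch("a", now = 90L)         // "a" now expires at 190
  val afterC = s2.put("c", "C", now = 150L)         // purges "b", re-queues "a"
  val afterD = afterC.put("d", "D", now = 300L)     // purges "a" and "c"
}
```

At time 150 the purge finds “a” at the front with a stale expiry of 100, sees in the map that it was extended to 190, and re-queues it rather than deleting it; “b” was never touched so it is dropped. By time 300 both “a” and “c” have expired and the next put clears them out.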

In the next post we will put this all into practice to perform OpenID verification in a Socko web app by wrapping a SessionState in an Akka Actor.

End.
