I am trying to modify the source code of a simple DistBelief framework implemented with Akka actors. The original source code is here: http://alexminnaar.com/implementing-the-distbelief-deep-neural-network-training-framework-with-akka.html . The original implementation uses plain Akka actors, and I want to extend it to run in distributed mode. I think Akka Cluster Sharding is the right tool for this task, but I am unsure where incoming messages should be handled: in the receive method, or in extractShardId and extractEntityId of an actor class (for example the ParameterShard actor below; the full source code is at the link above). Akka's official docs say: "The extractEntityId and extractShardId are two application specific functions to extract the entity identifier and the shard identifier from incoming messages."
    object ParameterShard {
      case class ParameterRequest(dataShardId: Int, layerId: Int)
      case class LatestParameters(weights: DenseMatrix[Double])
    }

    class ParamServer(shardId: Int,
                      numberOfShards: Int,
                      learningRate: Double,
                      initialWeight: LayerWeight) extends Actor with ActorLogging {

      val shardName: String = "ParamServer"

      val extractEntityId: ShardRegion.ExtractEntityId = {
        //case ps: ParameterRequest => (ps.dataShardId.toString, ps)
      }

      val extractShardId: ShardRegion.ExtractShardId = {
        //case ps: ParameterRequest => (ps.dataShardId % numberOfShards).toString
      }

      //weights initialize randomly
      var latestParameter: LayerWeight = initialWeight

      def receive = {
        //A layer corresponding to this shardId in some model replica has requested the latest version of the parameters.
        case ParameterRequest(shardId, layerId) => {
          log.info(s"layer ${layerId} weights read by model replica ${shardId}")
          context.sender() ! LatestParameters(latestParameter)
        }

        /*
          A layer corresponding to this shardId in some model replica has computed a gradient, so we must update our
          parameters according to this gradient.
        */
        case Gradient(g, replicaId, layerId) => {
          log.info(s"layer ${layerId} weights updated by model replica ${replicaId}")
          latestParameter = latestParameter + g.t * learningRate
        }
      }
    }
extractShardId is called to decide which shard is responsible for processing a message, so that the message can be routed to that shard. extractEntityId is called to decide which actor (entity) will process the message. So you just need to implement your application-specific logic in those two functions, and Akka will handle the routing. – Mustafa Simav
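To make the split described in this comment concrete, here is a minimal sketch using classic Akka Cluster Sharding. It is not the original project's code. The object name ParamServerSharding, the entity class ParamServerEntity, the Vector[Double] payloads (standing in for the Breeze matrices) and numberOfShards = 10 are all assumptions made for illustration. It also assumes each entity holds one layer's weights, so messages are keyed by layerId; if dataShardId is the intended key, swap it in. The two extractors only route, and receive keeps the actual work.

```scala
import akka.actor.{Actor, ActorLogging, ActorRef, ActorSystem, Props}
import akka.cluster.sharding.{ClusterSharding, ClusterShardingSettings, ShardRegion}

object ParamServerSharding {
  // Simplified stand-ins for the question's messages (payload types are placeholders).
  final case class ParameterRequest(dataShardId: Int, layerId: Int)
  final case class LatestParameters(weights: Vector[Double])
  final case class Gradient(g: Vector[Double], replicaId: Int, layerId: Int)

  val typeName = "ParamServer"
  val numberOfShards = 10 // assumed value, pick one that suits the cluster size

  // Routing only: which entity receives the message, and what is delivered to it.
  val extractEntityId: ShardRegion.ExtractEntityId = {
    case msg @ ParameterRequest(_, layerId) => (layerId.toString, msg)
    case msg @ Gradient(_, _, layerId)      => (layerId.toString, msg)
  }

  // Routing only: which shard hosts that entity.
  val extractShardId: ShardRegion.ExtractShardId = {
    case ParameterRequest(_, layerId) => (layerId % numberOfShards).toString
    case Gradient(_, _, layerId)      => (layerId % numberOfShards).toString
  }

  // Business logic stays in receive, exactly as in a non-sharded actor.
  class ParamServerEntity(learningRate: Double, initial: Vector[Double])
      extends Actor with ActorLogging {
    private var weights: Vector[Double] = initial

    def receive: Receive = {
      case ParameterRequest(replicaId, layerId) =>
        log.info(s"layer $layerId weights read by model replica $replicaId")
        sender() ! LatestParameters(weights)
      case Gradient(g, replicaId, layerId) =>
        log.info(s"layer $layerId weights updated by model replica $replicaId")
        weights = weights.zip(g).map { case (w, d) => w + d * learningRate }
    }
  }

  // Starts the shard region on this node; send all messages to the returned ActorRef.
  def startRegion(system: ActorSystem, entityProps: Props): ActorRef =
    ClusterSharding(system).start(
      typeName = typeName,
      entityProps = entityProps,
      settings = ClusterShardingSettings(system),
      extractEntityId = extractEntityId,
      extractShardId = extractShardId)
}
```

With this layout the extractors live outside the actor (here in an object), because the shard region calls them before any entity exists. Model replicas then send ParameterRequest and Gradient to the region's ActorRef instead of to a specific ParamServer actor.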