I have been experimenting with Spark and MLlib to train a word2vec model, but I don't seem to be getting the performance benefits of distributed machine learning on large datasets. My understanding is the following: given w workers, if I create an RDD with n partitions (n > w) and call the fit function of Word2Vec with that RDD as a parameter, Spark should distribute the data uniformly across the w workers, train a separate word2vec model on each, and apply some sort of reducer function at the end to merge the w models into a single output model. This should reduce computation time, since w chunks of data are processed simultaneously rather than one. The trade-off would be some loss of precision, depending on the reducer function used. Does Word2Vec in Spark actually work this way or not? If it does, I might need to play with the configurable parameters.
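For reference, the parameters I suspect are relevant: the RDD-based MLlib Word2Vec exposes setNumPartitions and setNumIterations, which (if I read the API correctly) control how training is split up. A minimal sketch of what I mean, with values that are illustrative, not the ones I have tested:

```java
import org.apache.spark.mllib.feature.Word2Vec;

// Sketch only: setNumPartitions controls how many partitions the training
// runs over (the docs warn that values > 1 trade accuracy for speed),
// and setNumIterations sets the number of passes over the data.
Word2Vec word2Vec = new Word2Vec()
        .setNumPartitions(10)  // illustrative: e.g. one per worker
        .setNumIterations(1)   // should be <= numPartitions per the docs
        .setVectorSize(100);   // illustrative vector size
```

Is tuning numPartitions like this the intended way to get the per-worker parallelism described above?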
EDIT
Adding the reason behind this question. I ran Java Spark word2vec code on 10 worker machines and, after going through the documentation, set suitable values for executor-memory, driver-memory, and num-executors. The input was a 2.5 GB text file, mapped to RDD partitions that were then used as training data for an MLlib word2vec model. The training part took multiple hours, and the number of worker nodes doesn't seem to have much effect on the training time. The same code runs successfully on smaller data files (on the order of tens of MBs).
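For context, the job was submitted roughly like this (the class name, paths, and memory values here are placeholders, not my exact settings):

```shell
# Illustrative spark-submit invocation; all values are placeholders
spark-submit \
  --class com.example.SampleWord2Vec \
  --master yarn \
  --num-executors 10 \
  --executor-memory 8g \
  --driver-memory 8g \
  sample-word2vec.jar /path/to/input.txt /path/to/output
```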
Code
// Configure Kryo serialization for the classes shipped between workers
SparkConf conf = new SparkConf().setAppName("SampleWord2Vec");
conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer");
conf.registerKryoClasses(new Class[]{String.class, List.class});
JavaSparkContext jsc = new JavaSparkContext(conf);

// Split each comma-separated line of the input into a list of tokens
JavaRDD<List<String>> jrdd = jsc.textFile(inputFile, 3).map(new Function<String, List<String>>() {
    @Override
    public List<String> call(String s) throws Exception {
        return Arrays.asList(s.split(","));
    }
});
jrdd.persist(StorageLevel.MEMORY_AND_DISK());

Word2Vec word2Vec = new Word2Vec()
        .setWindowSize(20)
        .setMinCount(20);

// Training step: this is the part that takes multiple hours
Word2VecModel model = word2Vec.fit(jrdd);
jrdd.unpersist(false);

model.save(jsc.sc(), outputfile);
jsc.stop();
jsc.close();