// Sharing my solution. For simplicity, integers stand in for documents; a document is fully processed once its integer reaches 8.
import static java.util.concurrent.TimeUnit.MILLISECONDS;
import static java.util.concurrent.TimeUnit.SECONDS;
import static java.util.stream.Collectors.toList;
import static org.junit.Assert.fail;
import java.util.List;
import java.util.concurrent.CompletionStage;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.stream.IntStream;
import akka.NotUsed;
import akka.actor.ActorSystem;
import akka.stream.ActorMaterializer;
import akka.stream.ClosedShape;
import akka.stream.DelayOverflowStrategy;
import akka.stream.Graph;
import akka.stream.KillSwitches;
import akka.stream.OverflowStrategy;
import akka.stream.SharedKillSwitch;
import akka.stream.UniformFanInShape;
import akka.stream.UniformFanOutShape;
import akka.stream.javadsl.Broadcast;
import akka.stream.javadsl.Flow;
import akka.stream.javadsl.GraphDSL;
import akka.stream.javadsl.Merge;
import akka.stream.javadsl.RunnableGraph;
import akka.stream.javadsl.Sink;
import akka.stream.javadsl.Source;
import scala.concurrent.duration.Duration;
import scala.concurrent.duration.FiniteDuration;
import org.assertj.core.api.SoftAssertions;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
/**
 * Exercises a bulk-processing stream that feeds not-yet-finished items back into itself.
 *
 * <p>Integers stand in for documents. Each pass through the bulk-processing stage increments an
 * item by one. Items greater than 7 are emitted as done. All other items are buffered, delayed
 * and merged back in front of the bulk-processing stage for another pass.
 */
public class BulkRetryFlowProviderTest {

    /** How long the cyclic graph may run before the kill switch closes its output branch. */
    private static final long RUN_DURATION_SECONDS = 10;

    /** Upper bound for awaiting the collected result; deliberately larger than the run duration. */
    private static final long RESULT_TIMEOUT_SECONDS = RUN_DURATION_SECONDS + 20;

    private ActorSystem actorSystem;
    private SharedKillSwitch killSwitch;

    @Before
    public void setUp() {
        actorSystem = ActorSystem.create();
        killSwitch = KillSwitches.shared("my-kill-switch");
    }

    @After
    public void tearDown() {
        actorSystem.terminate();
    }

    /**
     * Runs 5 inputs, each expanded to 5 items, through the retry loop. It expects every one of
     * the 25 items to leave the loop with the value 8, the first value the "done" filter accepts.
     */
    @Test
    public void test() throws Exception {
        Source<Integer, NotUsed> source = Source.from(IntStream.range(0, 5)
                .boxed()
                .collect(toList()));

        // Each input creates 5 items to process: the values 0..4 (the input value itself is ignored).
        Flow<Integer, Integer, NotUsed> createDocuments = Flow.of(Integer.class)
                .mapConcat(i -> IntStream.range(0, 5).boxed().collect(toList()));

        // Buffer items into batches and "process" each batch by incrementing every item.
        Flow<Integer, Integer, NotUsed> bulkProcess = Flow.of(Integer.class)
                .groupedWithin(50, Duration.create(10, MILLISECONDS))
                .mapConcat(integers -> integers.stream()
                        .map(i -> i + 1)
                        .collect(toList()));

        // Items that are finished.
        Flow<Integer, Integer, NotUsed> filterProcessed = Flow.of(Integer.class)
                .filter(i -> i > 7)
                .map(i -> {
                    System.out.println("Done: " + i);
                    return i;
                });

        // Items that must go through bulk processing again.
        Flow<Integer, Integer, NotUsed> filterRecoverable = Flow.of(Integer.class)
                .filter(i -> i <= 7);

        // Small buffer plus delay on the feedback edge, so the cycle keeps some slack between passes.
        Flow<Integer, Integer, NotUsed> bufferRetry = Flow.of(Integer.class)
                .buffer(3, OverflowStrategy.backpressure())
                .delay(FiniteDuration.apply(10, MILLISECONDS), DelayOverflowStrategy.backpressure());

        Graph<ClosedShape, CompletionStage<List<Integer>>> graph = GraphDSL.create(Sink.seq(), (builder, out) -> {
            UniformFanOutShape<Integer, Integer> broadcast =
                    builder.add(Broadcast.create(2));
            UniformFanInShape<Integer, Integer> merge = builder.add(Merge.create(2));

            // Main path: source -> documents -> merge -> bulk -> broadcast -> done -> kill switch -> sink
            builder
                    .from(builder.add(source).out())
                    .via(builder.add(createDocuments))
                    .viaFanIn(merge)
                    .via(builder.add(bulkProcess))
                    .viaFanOut(broadcast)
                    .via(builder.add(filterProcessed))
                    .via(builder.add(killSwitch.flow()))
                    .to(out);

            // Feedback path: broadcast -> not yet done -> buffer/delay -> back into merge
            builder
                    .from(broadcast)
                    .via(builder.add(filterRecoverable))
                    .via(builder.add(bufferRetry))
                    .toFanIn(merge);

            return ClosedShape.getInstance();
        });

        CompletionStage<List<Integer>> completionStage = RunnableGraph
                .fromGraph(graph)
                .run(ActorMaterializer.create(actorSystem))
                .exceptionally(e -> {
                    fail("Stream failed: " + e);
                    return null;
                });

        // Because of the merge <-> broadcast cycle the stream does not complete by itself, so the
        // kill switch closes the output branch after a fixed running time. The scheduler is
        // always shut down so no thread outlives the test.
        ScheduledExecutorService scheduler = Executors.newSingleThreadScheduledExecutor();
        List<Integer> result;
        try {
            scheduler.schedule(killSwitch::shutdown, RUN_DURATION_SECONDS, SECONDS);
            result = completionStage.toCompletableFuture().get(RESULT_TIMEOUT_SECONDS, SECONDS);
        } finally {
            scheduler.shutdownNow();
        }

        SoftAssertions.assertSoftly(softly -> {
            softly
                    .assertThat(result)
                    .hasSize(25);
            softly
                    .assertThat(result)
                    .allMatch(i -> i == 8);
        });
    }
}