6
votes

Here is a simple code example to illustrate my question:

case class Record( key: String, value: Int )

object Job extends App
{
  val env = StreamExecutionEnvironment.getExecutionEnvironment

  val source = env.fromElements(
    Record("01", 1), Record("02", 2), Record("03", 3), Record("04", 4), Record("05", 5)
  )

  // Three branches over the same source. The filter drops the record whose value is 3,
  // so key "03" ends up with only two records after the union.
  val filtered = source.filter(r => r.value % 3 != 0) // introduces some data loss
  val doubled  = source.map(r => Record(r.key, r.value * 2))
  val tripled  = source.map(r => Record(r.key, r.value * 3))

  // Per key, sum the value field (position 1) over count windows of three records.
  filtered
    .union(doubled, tripled)
    .keyBy(0)
    .countWindow(3)
    .sum(1)
    .print()

  env.execute("test")
}

This produces the following result:

Record(01,6)
Record(02,12)
Record(04,24)
Record(05,30)

As expected, no result is produced for key "03" because the count window expects 3 elements and only two are present in the stream (the filter in step1 drops Record("03",3)).

What I would like is a count window with a timeout: if the expected number of elements has not arrived by the time the timeout expires, a partial result is produced from the elements received so far.

With this behavior, in my example, a Record(03,15) would be produced when the timeout is reached.

3

3 Answers

6
votes

You could also do this with a custom window Trigger that fires either when the count has been reached or when the timeout expires -- effectively blending the built-in CountTrigger and EventTimeTrigger.

6
votes

I tried both David's and Nirav's approaches; here are the results.

1) Using a custom trigger:

Here I have reversed my initial logic. Instead of a 'count window', I use a 'time window' whose duration corresponds to the timeout. It is combined with a custom trigger that fires as soon as the expected number of elements has arrived.

case class Record( key: String, value: Int )

object Job extends App
{
  val env = StreamExecutionEnvironment.getExecutionEnvironment

  val source = env.fromElements(
    Record("01", 1), Record("02", 2), Record("03", 3), Record("04", 4), Record("05", 5)
  )

  // Three branches over the same source; the filter drops the record whose value is 3.
  val filtered = source.filter(r => r.value % 3 != 0) // introduces some data loss
  val doubled  = source.map(r => Record(r.key, r.value * 2))
  val tripled  = source.map(r => Record(r.key, r.value * 3))

  // The 50 ms time window plays the role of the timeout; the custom trigger can fire the
  // pane earlier once it has counted 3 records.
  filtered
    .union(doubled, tripled)
    .keyBy(0)
    .timeWindow(Time.milliseconds(50))
    .trigger(new CountTriggerWithTimeout(3, env.getStreamTimeCharacteristic))
    .sum(1)
    .print()

  env.execute("test")
}

And here is the trigger code:

import org.apache.flink.annotation.PublicEvolving
import org.apache.flink.api.common.functions.ReduceFunction
import org.apache.flink.api.common.functions.RuntimeContext
import org.apache.flink.api.common.state.ReducingState
import org.apache.flink.api.common.state.ReducingStateDescriptor
import org.apache.flink.api.common.typeutils.base.LongSerializer
import org.apache.flink.streaming.api.TimeCharacteristic
import org.apache.flink.streaming.api.windowing.triggers._
import org.apache.flink.streaming.api.windowing.triggers.Trigger.TriggerContext
import org.apache.flink.streaming.api.windowing.windows.TimeWindow

/**
 * A trigger that fires when the count of elements in a pane reaches `maxCount`, or when the
 * pane's time window ends (the timeout), whichever happens first.
 *
 * When the pane fires on count, it is purged and its count is reset. Further elements of the
 * same window therefore start a new count instead of firing one by one.
 *
 * @param maxCount           number of elements that fires the pane before the window ends
 * @param timeCharacteristic time characteristic of the job. `ProcessingTime` arms a
 *                           processing-time timeout; event or ingestion time arms an
 *                           event-time timeout.
 * @tparam W the time window type this trigger is attached to
 */
class CountTriggerWithTimeout[W <: TimeWindow](maxCount: Long, timeCharacteristic: TimeCharacteristic) extends Trigger[Object,W]
{
  private val countState: ReducingStateDescriptor[java.lang.Long] = new ReducingStateDescriptor[java.lang.Long]( "count", new Sum(), LongSerializer.INSTANCE)

  // Ingestion time is driven by event-time timers, so only pure processing time differs.
  private def usesProcessingTime: Boolean = timeCharacteristic == TimeCharacteristic.ProcessingTime

  override def onElement(element: Object, timestamp: Long, window: W, ctx: TriggerContext): TriggerResult =
  {
      // Arm the timeout at the window's last timestamp. This is done for every element, as
      // Flink's EventTimeTrigger does; Flink deduplicates identical timers.
      if (usesProcessingTime) ctx.registerProcessingTimeTimer(window.maxTimestamp)
      else ctx.registerEventTimeTimer(window.maxTimestamp)

      val count: ReducingState[java.lang.Long] = ctx.getPartitionedState(countState)
      count.add( 1L )
      if ( count.get >= maxCount )
      {
          // Reset the count; otherwise every later element of this window would fire again.
          count.clear()
          TriggerResult.FIRE_AND_PURGE
      }
      else TriggerResult.CONTINUE
  }

  override def onProcessingTime(time: Long, window: W, ctx: TriggerContext): TriggerResult =
      fireOnTimeout(usesProcessingTime, time, window, ctx)

  override def onEventTime(time: Long, window: W, ctx: TriggerContext): TriggerResult =
      fireOnTimeout(!usesProcessingTime, time, window, ctx)

  /**
   * Fires and purges the pane when the timer that went off is the timeout armed in `onElement`,
   * i.e. it belongs to this trigger's time domain and equals the window's last timestamp.
   * All other timers are ignored.
   */
  private def fireOnTimeout(inTimeDomain: Boolean, time: Long, window: W, ctx: TriggerContext): TriggerResult =
      if ( inTimeDomain && time == window.maxTimestamp )
      {
          ctx.getPartitionedState( countState ).clear()
          TriggerResult.FIRE_AND_PURGE
      }
      else TriggerResult.CONTINUE

  override def clear(window: W, ctx: TriggerContext): Unit =
  {
      // Remove the timeout timer and the count for this window.
      if (usesProcessingTime) ctx.deleteProcessingTimeTimer(window.maxTimestamp)
      else ctx.deleteEventTimeTimer(window.maxTimestamp)
      ctx.getPartitionedState( countState ).clear()
  }

  /** Adds partial counts together; used as the reduce function of the count state. */
  class Sum extends ReduceFunction[java.lang.Long]
  {
      def reduce(value1: java.lang.Long, value2: java.lang.Long): java.lang.Long = value1 + value2
  }
}

2) Using a process function:

case class Record( key: String, value: Int )

object Job extends App
{
  val env = StreamExecutionEnvironment.getExecutionEnvironment
  // Ingestion time attaches a timestamp to every record. TimeCountWindowProcessFunction
  // uses that timestamp to schedule its event-time timeout.
  env.setStreamTimeCharacteristic(TimeCharacteristic.IngestionTime)

  val source = env.fromElements(
    Record("01", 1), Record("02", 2), Record("03", 3), Record("04", 4), Record("05", 5)
  )

  // Three branches over the same source; the filter drops the record whose value is 3.
  val filtered = source.filter(r => r.value % 3 != 0) // introduces some data loss
  val doubled  = source.map(r => Record(r.key, r.value * 2))
  val tripled  = source.map(r => Record(r.key, r.value * 3))

  // Per key: emit after 3 records, or after the 100 ms timeout with whatever has arrived.
  filtered
    .union(doubled, tripled)
    .keyBy(0)
    .process(new TimeCountWindowProcessFunction(3, 100))
    .print()

  env.execute("test")
}

With all the logic (i.e., windowing, triggering, and summing) going into the function:

import org.apache.flink.streaming.api.functions._
import org.apache.flink.util._
import org.apache.flink.api.common.state._

/** Running aggregate for one key: how many records were seen, their key, and the sum of their values. */
case class Status(count: Int, key: String, value: Long)

/**
 * Sums `Record.value` per key and emits one `Record` either as soon as `count` records have
 * been seen for that key or, failing that, when an event-time timer set `windowSize` after
 * the first record of the current batch fires. Whichever happens first wins.
 *
 * Must be applied to a keyed stream, because `state` and `timerState` are per-key state.
 * Records must carry timestamps (e.g. ingestion time), since the timeout is
 * `ctx.timestamp + windowSize`.
 *
 * @param count      number of records that completes a batch early
 * @param windowSize timeout, in event-time units, measured from the batch's first record
 */
class TimeCountWindowProcessFunction( count: Long, windowSize: Long ) extends ProcessFunction[Record,Record]
{
    // Running aggregate of the key's current, still open batch.
    lazy val state: ValueState[Status] = getRuntimeContext
      .getState(new ValueStateDescriptor[Status]("state", classOf[Status]))

    // Timestamp of the timeout timer armed for the key's current batch. A batch emitted on
    // count leaves its timer pending; onTimer uses this value to ignore such stale timers.
    // Otherwise a stale timer would flush the next batch early.
    lazy val timerState: ValueState[java.lang.Long] = getRuntimeContext
      .getState(new ValueStateDescriptor[java.lang.Long]("timer", classOf[java.lang.Long]))

    override def processElement( input: Record, ctx: ProcessFunction[Record,Record]#Context, out: Collector[Record] ): Unit =
    {
        val updated: Status = Option( state.value ) match {
            case None =>
                // First record of a new batch: arm its timeout and remember which timer it is.
                val timeout: Long = ctx.timestamp + windowSize
                ctx.timerService().registerEventTimeTimer( timeout )
                timerState.update( timeout )
                Status( 1, input.key, input.value )
            case Some( current ) => Status( current.count + 1, input.key, input.value + current.value )
        }
        if ( updated.count >= count )
        {
            // Batch complete. Record.value is an Int, so the Long sum is narrowed back.
            out.collect( Record( input.key, updated.value.toInt ) )
            state.clear()
            timerState.clear()
        }
        else
        {
            state.update( updated )
        }
    }

    override def onTimer( timestamp: Long, ctx: ProcessFunction[Record,Record]#OnTimerContext, out: Collector[Record] ): Unit =
    {
        // Only the timer armed for the current batch may flush it; stale timers are ignored.
        val isCurrentTimer = Option( timerState.value ).exists( _.longValue == timestamp )
        if ( isCurrentTimer )
        {
            Option( state.value ).foreach { status =>
                out.collect( Record( status.key, status.value.toInt ) )
            }
            state.clear()
            timerState.clear()
        }
    }
}
3
votes

I think you can implement this use case with a ProcessFunction.

In it, keep a count and a window-end timestamp; using those, you can decide when to emit the collected data.

/**
 * Emits a per-key aggregate as soon as {@code count} elements have been seen for the key.
 * Otherwise it emits when the event-time timer fires at the end of the {@code windowSize}-wide
 * tumbling window that contains the batch's first element.
 *
 * <p>Must be applied to a keyed stream, because {@code state} and {@code timerState} are
 * per-key state. Elements must carry timestamps (event or ingestion time).
 *
 * <p>NOTE(review): the input element itself is not folded into the aggregate; only the
 * count and last-modified timestamp of the {@code CountPojo} are updated. Confirm this is
 * intended.
 */
public class TimeCountWindowProcessFunction extends ProcessFunction<CountPojo, CountPojo> {

    protected long windowSize;
    // Window bounds of the most recently processed element, of any key. These are kept for
    // subclasses only; the per-key timeout decision uses timerState, since all keys share these.
    protected long windowStart;
    protected long windowEnd;
    protected long count;
    private ValueState<CountPojo> state;
    // Timestamp of the timeout timer armed for the key's current batch, or null if none.
    private ValueState<Long> timerState;

    public TimeCountWindowProcessFunction(long windowSize, long count) {
        if (windowSize <= 0) {
            throw new IllegalArgumentException("windowSize must be positive: " + windowSize);
        }
        this.windowSize = windowSize;
        this.count = count;
    }

    @Override
    public void open(Configuration parameters) {
        TypeInformation<CountPojo> typeInformation = TypeInformation.of(new TypeHint<CountPojo>() {
        });
        ValueStateDescriptor<CountPojo> descriptor = new ValueStateDescriptor<>("test", typeInformation);
        state = getRuntimeContext().getState(descriptor);
        timerState = getRuntimeContext().getState(new ValueStateDescriptor<Long>("timer", Long.class));
    }

    @Override
    public void processElement(CountPojo input, Context ctx, Collector<CountPojo> out)
            throws Exception {

        Long timestamp = ctx.timestamp();
        if (timestamp == null) {
            // Event-time timers need element timestamps; fail clearly instead of with an unboxing NPE.
            throw new IllegalStateException(
                    "TimeCountWindowProcessFunction requires element timestamps (event or ingestion time)");
        }
        // floorMod keeps the alignment correct for negative timestamps as well.
        windowStart = timestamp - Math.floorMod(timestamp, windowSize);
        windowEnd = windowStart + windowSize;

        // Retrieve the current batch for this key.
        CountPojo current = state.value();

        if (current == null) {
            current = new CountPojo();
            current.count = 1;
            // First element of a batch: arm its timeout and remember which timer it is.
            ctx.timerService().registerEventTimeTimer(windowEnd);
            timerState.update(windowEnd);
        } else {
            current.count += 1;
        }

        // Update before emitting, so the emitted object is not mutated after collect().
        current.setLastModified(timestamp);

        if (current.count >= count) {
            // Batch complete: emit it once and start over. The pending timer becomes stale
            // and is ignored by onTimer.
            out.collect(current);
            state.clear();
            timerState.clear();
        } else {
            state.update(current);
        }
    }

    @Override
    public void onTimer(long timestamp, OnTimerContext ctx, Collector<CountPojo> out)
            throws Exception {

        // Only the timer armed for the key's current batch may flush it.
        Long armed = timerState.value();
        if (armed == null || armed != timestamp) {
            return;
        }

        CountPojo current = state.value();
        if (current != null) {
            out.collect(current);
        }
        state.clear();
        timerState.clear();
    }
}

I hope this is helpful.