3
votes

This question is an extension of one I asked earlier today. Basically, I am trying to write an array comprehension in Julia that calls a function f(x) whose output is a random number. As soon as a random number less than 0.5 is returned, I want the comprehension to stop. I was able to write the following code:

X=[f(i) for i in 1:1:100 if (j=f(i) ; j < 0.5 ? false : j>0.5)]

The problem with this is that this calls two separate instances of f(x), and because f(x) is random every time, the above won't kill the for loop at the correct instance. I tried

X=[J=f(i) for i in 1:1:100 if (J < 0.5 ? false : J>0.5)]

As an attempt to save that particular random number, but then it tells me J is not defined. Is there any way to save this particular random number to perform my array comprehension?

5
Why not just write a loop? This is a very simple loop. - Chris Rackauckas
This is really an abuse of array comprehensions. Write a loop. - Fengyang Wang

5 Answers

6
votes

Insisting on a one-line solution, and inspired by @TasosPapastylianou, a fast solution would be:

X = ( r=Vector{Float64}() ; 
  any(i->(v=f(i) ; v>0.5 ? ( push!(r,v) ; false) : true), 1:100) 
  ; r )

[ the one-liner is split over three lines because it is a little long ;) ]

Since f is missing, you can test this by copy-pasting the following version, which uses rand:

(r=ones(0); any(i->(v=rand(); v>0.5 ? (push!(r,v); false) : true), 1:10); r)

It benchmarks about 10% slower than Fengyang's function. The clever bit is leveraging any's short-circuit implementation.
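
To see the short-circuiting in isolation, here is a minimal sketch. The names calls and noisy are hypothetical, introduced only for illustration; the point is that any stops evaluating its predicate as soon as it returns true:

```julia
# `calls` and `noisy` are illustrative names, not part of the answer above.
calls = Ref(0)
noisy(i) = (calls[] += 1; i >= 3)   # counts its calls; first true at i == 3

found = any(noisy, 1:100)

println(found)      # true
println(calls[])    # 3, i.e. the predicate never ran for 4:100
```

In the one-liner above the predicate returns true exactly when v is not greater than 0.5, which is what makes any stop at the right element.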

ADDENDUM: To generalize, here is a version of Fengyang's takewhile that abstracts the answer to this question:

collectwhilecond(f,cond,itr) = begin
    r=Vector{typeof(f(first(itr)))}()
    all(x->(y=f(x); cond(y) ? (push!(r,y); true) : false), itr)
    return r
end

Now, we can implement the answer above as (with joker as f):

julia> joker(i) = 1.0 + 4*rand() - log(i)

julia> collectwhilecond(joker, x->x>=0.5, 1:100)
3-element Array{Float64,1}:
 4.14222
 3.42955
 2.76387

collectwhilecond is also type stable if Julia infers f's return type.

EDIT: Using @tim's suggested method of inferring return type of f without pulling an element of itr and without risking an unstable f generating an error, the new collectwhilecond is:

collectwhilecond(f,cond,itr) = begin
    t = Base.promote_op(f,eltype(itr))  # unofficial and subject to change
    r = Vector{t}()
    all( x -> ( y=f(x) ; cond(y) ? (push!(r,y) ; true) : false), itr )
    return r
end
5
votes

What you're trying to do is essentially a simple filter operation:

filter(x -> x >= 0.5, [f(i) for i in 1:10])

This is basically what we had to rely on in Julia before v0.5, when the if clause was not yet available in comprehensions.


EDIT: As Dan pointed out, you may be after keeping all elements until the first element that is <0.5 instead, e.g.:

L = [f(i) for i in 1:10]; L[1 : something(findfirst(x -> x < 0.5, L), length(L) + 1) - 1]  # keeps all of L if nothing is < 0.5

However, in this case, as others have pointed out, you might as well go for a normal for loop. A list comprehension will always process the whole list first, so it will not be faster. You could use a generator, but then you'd have to create your own special mechanism to make it stop at the right state (as Fengyang has suggested with takewhile).

So to answer the question in the comment, the fastest thing you could do in this case is a normal for loop that breaks appropriately. Furthermore, it is best wrapped in a function rather than evaluated globally, and it will speed up even further if you specify the type of your variables.

4
votes

You can do this with what's usually called takewhile on a generator

X = collect(takewhile(x -> x ≥ 0.5, Base.Generator(f, 1:100)))

which requires an implementation of takewhile, such as the one in this blog post, or your own (it's not too hard to do). takewhile has a reasonably easy name to read and is as concise as you might like.
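
On Julia 1.4 and later, Base provides Iterators.takewhile, so a custom implementation is no longer required. A minimal sketch, assuming a deterministic stand-in for f so the result is reproducible (vals and demo_f are hypothetical names used only here):

```julia
# Requires Julia 1.4 or later, where Iterators.takewhile was added.
vals = [0.9, 0.7, 0.3, 0.8]
demo_f(i) = vals[i]                 # stand-in for the real (random) f

# The generator is lazy: demo_f runs once per element, up to and including
# the first element that fails the predicate, and never after it.
X = collect(Iterators.takewhile(x -> x >= 0.5, (demo_f(i) for i in 1:4)))

println(X)   # [0.9, 0.7]
```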

However, I think it's often both more readable and more convenient to write a loop:

X = Float64[]
for i = 1:100
    j = f(i)
    if j < 0.5
        break
    else
        push!(X, j)
    end
end
3
votes

As suggested, using a loop is the right way to go. But if you want to try 'other' solutions, the following is short, confusing, and educational about Channels, which are seldom mentioned on Stack Overflow:

collect(Channel(c->begin
        i=1
        while true 
            v = rand()
            if v<0.5 || i>100 return else put!(c,v) end
            i+=1 
        end 
    end, ctype=Float64))

Replace rand() with f(i) as appropriate. And BTW don't use this solution because it is 1000x slower than a simple loop. Perhaps of value if Channel is a RemoteChannel and f(i) is some big stochastic simulation.

2
votes

If you want to squeeze the maximum possible performance, I'd say there are two options depending on which is the bottleneck.

  1. f is very fast, but allocations are an issue. The following code computes f twice but saves on the allocation part (as push! will sometimes reallocate memory, see here). Note that this only reproduces the right values if f returns the same result when called again with the same argument, which is not the case for an unseeded random f:

i = findfirst(t -> f(t) < 0.5, 1:100)
w = f.(1:(i-1))

  2. f is very slow and computing it twice is too expensive. Then just do a loop with push! as already recommended. You may want to look at sizehint! to further improve performance.
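
As a rough sketch of the second option, the loop can be wrapped in a function and combined with sizehint!. The name collect_until_below and the default threshold are assumptions made for illustration, and the code assumes f returns Float64:

```julia
# `collect_until_below` is a hypothetical helper name; assumes f returns Float64.
function collect_until_below(f, n, threshold = 0.5)
    out = Float64[]
    sizehint!(out, n)            # reserve capacity for the worst case up front
    for i in 1:n
        v = f(i)                 # f is called exactly once per element
        v < threshold && break   # stop at the first value below the threshold
        push!(out, v)
    end
    return out
end

# Deterministic stand-in for f so the result is reproducible:
vals = [0.9, 0.7, 0.3, 0.8]
result = collect_until_below(i -> vals[i], length(vals))
println(result)   # [0.9, 0.7]
```

Whether the sizehint! call pays off depends on how many elements are typically kept, so it is worth measuring for your own f.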

In general, you don't need to ask what is slower and what is faster: you can simply benchmark your particular use case with the excellent BenchmarkTools package.