
I see this type of syntax a lot in some Lua source files I was reading lately. What does it mean, especially the second pair of brackets? For example, see line 8 in https://github.com/karpathy/char-rnn/blob/master/model/LSTM.lua

local LSTM = {}
function LSTM.lstm(input_size, rnn_size, n, dropout)
  dropout = dropout or 0 

  -- there will be 2*n+1 inputs
  local inputs = {}
  table.insert(inputs, nn.Identity()())  -- line 8
  -- ...

The source code of nn.Identity is here: https://github.com/torch/nn/blob/master/Identity.lua

********** UPDATE **************

The ()() pattern is used a lot with the Torch library 'nn'. The first pair of brackets creates an object of the container/node, and the second pair of brackets references the node it depends on.

For example, y = nn.Linear(2,4)(x) means that x connects to y, and that the transformation is a linear map from 1x2 to 1x4. I only understand the usage; how it is wired up seems to be answered by one of the answers below.
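The two calls in y = nn.Linear(2,4)(x) can be mimicked in plain Lua, without Torch, to see what each pair of brackets does. This is a minimal sketch under stated assumptions: Linear, Node, and the fields module and parents are hypothetical stand-ins chosen for illustration, not the real nn/nngraph API.

```lua
-- Hypothetical stand-ins for nn.Linear and nngraph.Node (illustration only).
local Node = {}
Node.__index = Node

-- A graph node wraps a module and remembers the nodes it depends on.
function Node.new(module, parents)
  return setmetatable({ module = module, parents = parents or {} }, Node)
end

local Linear = {}
Linear.__index = Linear

-- Second pair of brackets: calling a module instance returns a node
-- whose parent is the argument.
function Linear.__call(self, parent)
  return Node.new(self, { parent })
end

-- First pair of brackets: calling the class itself builds a module instance.
setmetatable(Linear, {
  __call = function(cls, inputSize, outputSize)
    return setmetatable({ inputSize = inputSize, outputSize = outputSize }, cls)
  end,
})

local x = Node.new(nil)     -- an input node: no module, no parents
local y = Linear(2, 4)(x)   -- same shape as y = nn.Linear(2,4)(x)
local z = Linear(4, 1)(y)   -- chaining: x -> y -> z

print(y.module.outputSize)          -- 4
print(y.parents[1] == x)            -- true
print(z.parents[1].parents[1] == x) -- true
```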

Anyway, the usage of the interface is well documented here: https://github.com/torch/nngraph/blob/master/README.md

If there is a question, a good way to resolve it is to pull out the expression nn.Identity() and give it a meaningful name. For example (but with a nearly meaningless name, since I don't know what to call it): local identityProvider = nn.Identity() – Tom Blodget

3 Answers

13
votes

To complement Yu Hao's answer, let me add some Torch-related details:

  • nn.Identity() creates an identity module,
  • calling () on this module triggers nn.Module's __call__ method (thanks to the Torch class system, which properly hooks it up into the metatable),
  • by default, this __call__ method performs a forward / backward pass,
  • but here torch/nngraph is used, and nngraph overrides this method, as you can see here.

Consequently, every nn.Identity()() call here returns an nngraph.Node({module=self}) node, where self refers to the current nn.Identity() instance.
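The mechanism in the bullets above (a __call__ method on the class, reached through the instance metatable's __call metamethod, and later overridden) can be sketched in plain Lua. This is a hypothetical, much-simplified class, not torch.class or nngraph; the default __call__ here only runs a forward pass, for brevity.

```lua
-- Hypothetical minimal class (not torch.class): the instance metatable forwards
-- the __call metamethod to whatever __call__ method the class has at call time.
local Module = {}
local instanceMt = {
  __index = Module,
  __call = function(self, ...) return self:__call__(...) end,
}

function Module.new()
  return setmetatable({}, instanceMt)
end

-- Default behaviour (forward pass only): return the input unchanged.
function Module:forward(input) return input end
function Module:__call__(input) return self:forward(input) end

local m = Module.new()
local before = m(5)
print(before)            -- 5

-- nngraph-style override: redefine __call__ on the class. Existing instances
-- pick it up, and calling a module now returns a node recording its parent.
function Module:__call__(parent)
  return { module = self, parents = { parent } }
end

local node = m("x")
print(node.module == m)  -- true
print(node.parents[1])   -- x
```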

--

Update: an illustration of this syntax in the context of LSTMs can be found here:

local i2h = nn.Linear(input_size, 4 * rnn_size)(input)  -- input to hidden

If you’re unfamiliar with nngraph it probably seems strange that we’re constructing a module and already calling it once more with a graph node. What actually happens is that the second call converts the nn.Module to nngraph.gModule and the argument specifies its parent in the graph.

14
votes

No, ()() has no special meaning in Lua; it's just two call operators () in a row.

The operand is typically a function that returns a function (or a table that implements the __call metamethod). For example:

function foo()
  return function() print(42) end
end

foo()()   -- 42
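The parenthetical case, a table whose metatable defines __call, can be sketched as follows. makeCounter is a hypothetical name used only for illustration; the point is that the first call returns a table, yet the second () still works.

```lua
-- A table (not a function) that is callable because its metatable defines __call.
local function makeCounter()
  local counter = { count = 0 }
  return setmetatable(counter, {
    __call = function(self, step)
      self.count = self.count + (step or 1)
      return self.count
    end,
  })
end

local c = makeCounter()
print(type(c))          -- table
print(c())              -- 1
print(c(10))            -- 11

print(makeCounter()())  -- 1: the same ()() shape as foo()()
```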

2
votes
  • The first () calls the __init function and the second () calls the __call__ function.
  • If the class doesn't possess either of these functions, the parent class's functions are called.
  • In the case of nn.Identity()(), nn.Identity has neither an __init function nor a __call__ function, hence the __init and __call__ functions of its parent, nn.Module, are called.

Attaching an illustration:

    require 'torch'
    
    -- define some dummy A class
    local A = torch.class('A')
    function A:__init(stuff)
      self.stuff = stuff
      print('inside __init of A')
    end
    
    function A:__call__(arg1)
      print('inside __call__ of A')
    end
    
    -- define some dummy B class, inheriting from A
    local B,parent = torch.class('B', 'A')
    
    function B:__init(stuff)
      self.stuff = stuff
      print('inside __init of B')
    end
    
    function B:__call__(arg1)
      print('inside __call__ of B')
    end
    a=A()()
    b=B()()
    

Output

    inside __init of A
    inside __call__ of A
    inside __init of B
    inside __call__ of B
    

Another code sample, in which B defines neither __init nor __call__, so A's versions are called:

    require 'torch'

    -- define some dummy A class
    local A = torch.class('A')
    function A:__init(stuff)
      self.stuff = stuff
      print('inside __init of A')
    end

    function A:__call__(arg1)
      print('inside __call__ of A')
    end

    -- define some dummy B class, inheriting from A
    local B,parent = torch.class('B', 'A')

    b=B()()

Output

    inside __init of A
    inside __call__ of A
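Without Torch installed, the fallback described in the bullets can be reproduced in plain Lua: a class that lacks __init or __call__ finds its parent's version through the __index chain. The helpers class and construct below are hypothetical, much-simplified stand-ins for torch.class, not its real implementation.

```lua
-- Hypothetical constructor: build an instance and run a (possibly inherited) __init.
local function construct(cls, ...)
  local obj = setmetatable({}, cls)
  if obj.__init then obj:__init(...) end
  return obj
end

-- Hypothetical, simplified stand-in for torch.class(name, parentName).
local function class(parent)
  local cls = {}
  cls.__index = cls
  -- Calling an instance forwards to its (possibly inherited) __call__ method.
  cls.__call = function(self, ...) return self:__call__(...) end
  -- Missing methods are looked up in the parent; calling the class constructs.
  setmetatable(cls, { __index = parent, __call = construct })
  return cls
end

local log = {}

local A = class()
function A:__init(stuff)
  self.stuff = stuff
  table.insert(log, 'inside __init of A')
end
function A:__call__(arg1)
  table.insert(log, 'inside __call__ of A')
end

-- B inherits from A and defines neither __init nor __call__.
local B = class(A)

B()()
print(table.concat(log, '\n'))
-- inside __init of A
-- inside __call__ of A
```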