
I am using Python 3.5.1 on 64-bit Windows. My problem is that Python seems to ignore the __eq__ and __lt__ operators of a user-defined class when sorting. Using a custom sort key is a workaround, but it doesn't work when sorting tuples that contain instances of this class.

Example:

class Symbol:
  def __init__(self, name, is_terminal = False):
    self.name = name
    self.is_terminal = is_terminal
  def __eq__(self, other):
    return (self.is_terminal, self.name) == (other.is_terminal, other.name)
  def __lt__(self, other):
    return (self.is_terminal, self.name) < (other.is_terminal, other.name)

symbols = set()
for s in "abcdef":
  symbols.add(Symbol(s))

sorted_symbols = sorted(symbols)
# sorted_symbols now contains the symbols in random order

Using the functools.total_ordering decorator does not help.

My question is: how do I define an ordering for a user-defined class in Python 3?

Sorting most definitely uses those methods. How have you determined that they are not used? And why is a sort key not a work-around? key=lambda s: (s.is_terminal, s.name) would give you the same result. - Martijn Pieters
What is the expected result? What do you get instead? - Andrea Corbellini
Your Symbol class is sortable, but not hashable as written. If you define __eq__, you also need to define __hash__ or you can't use the instances as dictionary keys or add them to sets. If you changed symbols to be a list (and use append rather than add in the loop), your code would work just fine. - Blckknght
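To make the point in the last comment concrete, here is a minimal sketch (assuming Python 3; the class mirrors the question's Symbol). Defining __eq__ without __hash__ makes Python set __hash__ to None, so putting an instance in a set raises TypeError, while a plain list sorts fine using __lt__:

```python
class Symbol:
    def __init__(self, name, is_terminal=False):
        self.name = name
        self.is_terminal = is_terminal

    def __eq__(self, other):
        return (self.is_terminal, self.name) == (other.is_terminal, other.name)

    def __lt__(self, other):
        return (self.is_terminal, self.name) < (other.is_terminal, other.name)


# Defining __eq__ without __hash__ makes Python 3 set __hash__ to None.
print(Symbol.__hash__ is None)  # True

try:
    {Symbol("a")}  # building a set requires hashing the instance
except TypeError as exc:
    print("set failed:", exc)  # set failed: unhashable type: 'Symbol'

# A list needs no hashing, so sorting via __lt__ works as expected.
symbols = [Symbol(s) for s in "fcabed"]
print([s.name for s in sorted(symbols)])  # ['a', 'b', 'c', 'd', 'e', 'f']
```

Adding a __hash__ that is consistent with __eq__ (as the answer does) restores the ability to use sets and dictionaries.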

1 Answer


Python does not ignore __eq__ and __lt__, at least not if you actually used @functools.total_ordering:

>>> from functools import total_ordering
>>> @total_ordering
... class Symbol:
...     def __init__(self, name, is_terminal=False):
...         self.name = name
...         self.is_terminal = is_terminal
...     def __repr__(self):
...         return 'Symbol({0.name!r}, is_terminal={0.is_terminal!r})'.format(self)
...     def __hash__(self):
...         return hash(self.name) ^ hash(self.is_terminal)
...     def __eq__(self, other):
...         print('{} __eq__ {}'.format(self, other))
...         return (self.is_terminal, self.name) == (other.is_terminal, other.name)
...     def __lt__(self, other):
...         print('{} __lt__ {}'.format(self, other))
...         return (self.is_terminal, self.name) < (other.is_terminal, other.name)
...
>>> symbols = set()
>>> for s in "abcdef":
...     symbols.add(Symbol(s))
...
>>> sorted(symbols)
Symbol('f', is_terminal=False) __lt__ Symbol('c', is_terminal=False)
Symbol('a', is_terminal=False) __lt__ Symbol('f', is_terminal=False)
Symbol('a', is_terminal=False) __lt__ Symbol('f', is_terminal=False)
Symbol('a', is_terminal=False) __lt__ Symbol('c', is_terminal=False)
Symbol('b', is_terminal=False) __lt__ Symbol('c', is_terminal=False)
Symbol('b', is_terminal=False) __lt__ Symbol('a', is_terminal=False)
Symbol('d', is_terminal=False) __lt__ Symbol('c', is_terminal=False)
Symbol('d', is_terminal=False) __lt__ Symbol('f', is_terminal=False)
Symbol('e', is_terminal=False) __lt__ Symbol('c', is_terminal=False)
Symbol('e', is_terminal=False) __lt__ Symbol('f', is_terminal=False)
Symbol('e', is_terminal=False) __lt__ Symbol('d', is_terminal=False)
[Symbol('a', is_terminal=False), Symbol('b', is_terminal=False), Symbol('c', is_terminal=False), Symbol('d', is_terminal=False), Symbol('e', is_terminal=False), Symbol('f', is_terminal=False)]
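What @total_ordering contributes is the rest of the comparison operators: given __eq__ and one ordering method such as __lt__, it derives __le__, __gt__ and __ge__. A short check, using a trimmed-down Symbol without the print statements:

```python
from functools import total_ordering


@total_ordering
class Symbol:
    def __init__(self, name, is_terminal=False):
        self.name = name
        self.is_terminal = is_terminal

    def __eq__(self, other):
        return (self.is_terminal, self.name) == (other.is_terminal, other.name)

    def __lt__(self, other):
        return (self.is_terminal, self.name) < (other.is_terminal, other.name)


a, b = Symbol("a"), Symbol("b")
# __gt__, __le__ and __ge__ were not written by hand; total_ordering added them.
print(b > a, a <= b, a >= b)  # True True False
print("__ge__" in vars(Symbol))  # True
```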

In fact, sorting works even without @total_ordering, because the TimSort implementation only uses __lt__; this is explicitly documented:

This method sorts the list in place, using only < comparisons between items.
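One way to check that claim is a class that defines nothing but __lt__ (OnlyLt below is a hypothetical class for illustration, not from the question) and counts how often it is called. sorted() produces the right order without ==, >, <= or >= being defined:

```python
class OnlyLt:
    """Defines only __lt__; no custom __eq__, __gt__, __le__ or __ge__."""
    calls = 0  # class-level counter of __lt__ invocations

    def __init__(self, key):
        self.key = key

    def __lt__(self, other):
        OnlyLt.calls += 1
        return self.key < other.key

    def __repr__(self):
        return "OnlyLt({!r})".format(self.key)


items = [OnlyLt(k) for k in "fcabed"]
result = sorted(items)
print(result)  # [OnlyLt('a'), OnlyLt('b'), OnlyLt('c'), OnlyLt('d'), OnlyLt('e'), OnlyLt('f')]
print("__lt__ calls:", OnlyLt.calls)
```

The exact number of __lt__ calls depends on the input order and on the TimSort implementation, so it is printed rather than predicted.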

A sort key is also an option; just return the (is_terminal, name) tuple from the key function:

>>> sorted(symbols, key=lambda s: (s.is_terminal, s.name))
[Symbol('a', is_terminal=False), Symbol('b', is_terminal=False), Symbol('c', is_terminal=False), Symbol('d', is_terminal=False), Symbol('e', is_terminal=False), Symbol('f', is_terminal=False)]

Note that now the __lt__ method is never called, because the sort key is used instead.
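The question also mentions sorting tuples that contain Symbol instances. Tuples compare element by element: == locates the first position where the elements differ, and < then decides the order at that position. So a class with working __eq__ and __lt__ sorts correctly inside tuples too, and a key function can reach into the tuple just as easily. A minimal sketch, assuming (Symbol, int) pairs chosen purely for illustration:

```python
class Symbol:
    def __init__(self, name, is_terminal=False):
        self.name = name
        self.is_terminal = is_terminal

    def __hash__(self):
        # Consistent with __eq__: equal symbols hash equally.
        return hash((self.is_terminal, self.name))

    def __eq__(self, other):
        return (self.is_terminal, self.name) == (other.is_terminal, other.name)

    def __lt__(self, other):
        return (self.is_terminal, self.name) < (other.is_terminal, other.name)


# Hypothetical data: (Symbol, priority) pairs, including a tie on the Symbol.
pairs = [(Symbol("b"), 2), (Symbol("a"), 5), (Symbol("b"), 1), (Symbol("a", True), 0)]

# Tuple comparison: Symbol.__eq__ finds the first unequal element, then < orders it.
by_tuple = sorted(pairs)
print([(s.name, s.is_terminal, n) for s, n in by_tuple])
# [('a', False, 5), ('b', False, 1), ('b', False, 2), ('a', True, 0)]

# Equivalent key-based sort that never calls Symbol.__lt__.
by_key = sorted(pairs, key=lambda p: (p[0].is_terminal, p[0].name, p[1]))
print(by_tuple == by_key)  # True
```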