5
votes

I can't seem to get my head around mocking in Python. I have a global function:

a.py:

def has_permission(args):
    ret_val = ...get-true-or-false...
    return ret_val

b.py:

class MySerializer(HyperlinkedModelSerializer):

    def get_fields(self):
        fields = super().get_fields()
        ret_val = {}
        for f in :
            if has_permission(...):
                ret_val[f.name] = fields[f]
        return ret_val

c.py:

class CountrySerializer(MySerializer):
    class Meta:
        model = Country

Question: I want to test c.py, but I want to mock the has_permission function, which is defined in a.py but called in the get_fields method of the class MySerializer, which is defined in b.py. How do I do that?

I've tried things like:

@patch('b.MySerializer.has_permission')

and

@patch('b.MySerializer.get_fields.has_permission')

and

@patch('a.has_permission')

But everything I try either has no effect, so the real has_permission is still executed, or Python complains that it can't find the attribute 'has_permission'.

The patching is done in test.py:

class TestSerializerFields(TestCase):
    @patch(... the above examples....)
    def test_my_country_serializer(self, mock_has_permission):
        s = CountrySerializer()
        self.assertTrue(issubclass(my_serializer_fields.MyCharField, type(s.get_fields()['field1'])))
Where are you applying the patch? - vks
@vks: in a test, presumably. - Martijn Pieters
@patch('b.MySerializer.has_permission') is wrong because the has_permission function does not belong to the class in any way. What you pass to patch is an import path. The subtlety is that, although has_permission is defined in a.py, once you've imported it in b.py it is also importable from b, and it's the name imported into b.py that you want to patch. - Anentropic

2 Answers

13
votes

You need to patch the global in the b module:

@patch('b.has_permission')

because that's where your code looks for it.

Also see the Where to patch section of the mock documentation.
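To see why the target matters, here is a minimal sketch that can be run as a single file. It assumes b.py pulls the function in with `from a import has_permission`; the in-memory modules and the `visible_fields` helper are hypothetical stand-ins for the question's a.py and b.py, not the asker's actual code.

```python
import sys
import types
from unittest.mock import patch

# Hypothetical in-memory stand-in for a.py: the real check denies everything.
a = types.ModuleType("a")
exec("def has_permission(*args):\n    return False\n", a.__dict__)
sys.modules["a"] = a

# Hypothetical in-memory stand-in for b.py: it copies the name into its own
# namespace with `from a import has_permission` (assumed import style).
b = types.ModuleType("b")
sys.modules["b"] = b
exec(
    "from a import has_permission\n"
    "def visible_fields(names):\n"
    "    return [n for n in names if has_permission(n)]\n",
    b.__dict__,
)

# Patching the definition site rebinds a.has_permission only;
# b still holds a reference to the original function.
with patch("a.has_permission", return_value=True):
    patched_at_definition = b.visible_fields(["name", "code"])

# Patching the name that b actually looks up takes effect.
with patch("b.has_permission", return_value=True):
    patched_at_use = b.visible_fields(["name", "code"])

print(patched_at_definition)
print(patched_at_use)
```

In this sketch the first call still runs the real function and filters everything out, while the second sees the mock and keeps both names.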

4
votes

You need to patch the function where it is looked up at the time your test runs, not where it is defined. If you patch the function at its definition after the code under test has already imported it, the patch has no effect: by the time @patch(...) executes, the module under test already holds its own reference to the function.

Here is an example:

app/util/config.py:

# This is the global method we want to mock
def is_search_enabled():
    return True

app/service/searcher.py:

# Here is where that global method will be imported 
#  when this file is first imported
from app.util.config import is_search_enabled

class Searcher:
    def __init__(self, api_service):
        self._api_service = api_service

    def search(self):
        if not is_search_enabled():
            return None
        return self._api_service.perform_request('/search')

test/service/test_searcher.py:

from unittest.mock import patch, Mock
# The next line will cause the imports of `searcher.py` to execute...
from app.service.searcher import Searcher
# At this point, searcher.py has imported is_search_enabled into its module.
# If you later try and patch the method at its definition 
#  (app.util.config.is_search_enabled), it will have no effect because 
#  searcher.py won't look there again.

class MockApiService:
    pass

class TestSearcher:

    # By the time this executes, `is_search_enabled` has already been
    #  imported into `app.service.searcher`.  So that is where we must
    #  patch it.
    @patch('app.service.searcher.is_search_enabled')
    def test_no_search_when_disabled(self, mock_is_search_enabled):
        mock_is_search_enabled.return_value = False
        mock_api_service = MockApiService()
        mock_api_service.perform_request = Mock()
        searcher = Searcher(mock_api_service)

        results = searcher.search()

        assert results is None
        mock_api_service.perform_request.assert_not_called()

    # (For completeness' sake, make sure the code actually works when search is enabled...)
    def test_search(self):
        mock_api_service = MockApiService()
        mock_api_service.perform_request = mock_perform_request = Mock()
        searcher = Searcher(mock_api_service)
        expected_results = [1, 2, 3]
        mock_perform_request.return_value = expected_results

        actual_results = searcher.search()

        assert actual_results == expected_results
        mock_api_service.perform_request.assert_called_once_with('/search')
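For contrast, patching at the definition does work if the consuming module imports the module rather than the function, because the attribute is then looked up on that module at call time. The following is only a sketch of that variant: `config_demo` and `searcher_demo` are hypothetical in-memory modules standing in for config.py and a rewritten searcher.py, not the code shown above.

```python
import sys
import types
from unittest.mock import patch

# Hypothetical in-memory stand-in for app/util/config.py.
config = types.ModuleType("config_demo")
exec("def is_search_enabled():\n    return True\n", config.__dict__)
sys.modules["config_demo"] = config

# Hypothetical variant of searcher.py that imports the module,
# not the function itself.
searcher = types.ModuleType("searcher_demo")
sys.modules["searcher_demo"] = searcher
exec(
    "import config_demo\n"
    "def search():\n"
    "    # The attribute is resolved on config_demo every time search() runs.\n"
    "    if not config_demo.is_search_enabled():\n"
    "        return None\n"
    "    return 'results'\n",
    searcher.__dict__,
)

# Because searcher_demo holds only a reference to the module, patching the
# function at its definition site is visible to it.
with patch("config_demo.is_search_enabled", return_value=False):
    result_while_patched = searcher.search()

# Outside the with block the original function is restored.
result_after_patch = searcher.search()
```

Which style to use is a design choice: `from module import name` reads more cleanly at call sites, while `import module` keeps a single patch target at the definition.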