Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
determine spec shape only at mock construction time
  • Loading branch information
carljm committed Dec 16, 2022
commit 6595272b2e5c64310c1eb4e7d1edba0689555c55
8 changes: 7 additions & 1 deletion Lib/test/test_unittest/testmock/testasync.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,9 +303,15 @@ def test_spec_normal_methods_on_class_with_mock(self):
def test_spec_async_attributes_instance(self):
async_instance = AsyncClass()
async_instance.async_func_attr = async_func
async_instance.later_async_func_attr = normal_func

mock_async_instance = Mock(spec_set=async_instance)

async_instance.later_async_func_attr = async_func

mock_async_instance = Mock(async_instance)
self.assertIsInstance(mock_async_instance.async_func_attr, AsyncMock)
# only the shape of the spec at the time of mock construction matters
self.assertNotIsInstance(mock_async_instance.later_async_func_attr, AsyncMock)

def test_spec_mock_type_kw(self):
def inner_test(mock_type):
Expand Down
16 changes: 10 additions & 6 deletions Lib/unittest/mock.py
Original file line number Diff line number Diff line change
Expand Up @@ -506,10 +506,9 @@ def _mock_add_spec(self, spec, spec_set, _spec_as_instance=False,

_spec_class = None
_spec_signature = None
_spec_obj = None
_spec_asyncs = []

if spec is not None and not _is_list(spec):
_spec_obj = spec
if isinstance(spec, type):
_spec_class = spec
else:
Expand All @@ -518,14 +517,20 @@ def _mock_add_spec(self, spec, spec_set, _spec_as_instance=False,
_spec_as_instance, _eat_self)
_spec_signature = res and res[1]

spec = dir(spec)
spec_list = dir(spec)

for attr in spec_list:
if iscoroutinefunction(getattr(spec, attr, None)):
_spec_asyncs.append(attr)

spec = spec_list

__dict__ = self.__dict__
__dict__['_spec_class'] = _spec_class
__dict__['_spec_obj'] = _spec_obj
__dict__['_spec_set'] = spec_set
__dict__['_spec_signature'] = _spec_signature
__dict__['_mock_methods'] = spec
__dict__['_spec_asyncs'] = _spec_asyncs

def __get_return_value(self):
ret = self._mock_return_value
Expand Down Expand Up @@ -1015,8 +1020,7 @@ def _get_child_mock(self, /, **kw):
For non-callable mocks the callable variant will be used (rather than
any custom subclass)."""
_new_name = kw.get("_new_name")
_spec_val = getattr(self.__dict__["_spec_obj"], _new_name, None)
if _spec_val is not None and asyncio.iscoroutinefunction(_spec_val):
if _new_name in self.__dict__['_spec_asyncs']:
return AsyncMock(**kw)

if self._mock_sealed:
Comment thread
carljm marked this conversation as resolved.
Expand Down