Python-无法模拟继承类的调用
我有这堂课Python-无法模拟继承类的调用,python,python-3.x,python-unittest,Python,Python 3.x,Python Unittest,我有这堂课 def main(args): if type == train_pipeline_type: strategy = TrainPipelineStrategy() else: strategy = TestPipelineStrategy() for table in fetch_table_information_by_region(region): split_required = DataUtils.lo
def main(args):
if type == train_pipeline_type:
strategy = TrainPipelineStrategy()
else:
strategy = TestPipelineStrategy()
for table in fetch_table_information_by_region(region):
split_required = DataUtils.load_from_dict(table, "split_required")
if split_required:
strategy.split(spark=spark, table_name=table_name,
data_loc=filtered_data_location, partition_column=partition_column,
split_output_dir= split_output_dir)
logger.info("Data Split for table : {} completed".format(table_name))
我的TrainPipelineStrategy和TestPipelineStrategy看起来像这样-
class PipelineTypeStrategy(object):
def partition_data(self, x):
# Something
def prepare_split_data(self, y):
# Something
def write_split_data(self, z):
# Something
def split(self, p):
# Something
class TrainPipelineStrategy(PipelineTypeStrategy):
""""""
class TestPipelineStrategy(PipelineTypeStrategy):
def write_split_data(self, y):
# Something else
我的测试用例-
我需要测试在main方法中模拟split功能调用split的次数
以下是我尝试过的-
@patch('module.PipelineTypeStrategy.TrainPipelineStrategy')
def test_split_data_main_split_data_call_count(self, fake_train):
fake_train_functions = mock.Mock()
fake_train_functions.split.return_value = None
fake_train.return_value = fake_train_functions
test_args = ["", "--x=6"]
SplitData.main(args=test_args)
assert fake_train_functions.split.call_count == 10
当我尝试运行测试时,它会创建模拟,但最终会调用实际的分割函数。我做错了什么?此代码的主要问题是,如果
TrainPipelineStrategy
是PipelineTypeStrategy
的嵌套类,那么设置修补程序的方法将是,但是TrainPipelineStrategy
是PipelineTypeStrategy
的子类
由于TrainPipelineStrategy
继承自PipelineTypeStrategy
它可以直接访问split
,因此您可以修补split
,而无需参考PipelineTypeStrategy
(除非您特别希望修补在PipelineTypeStrategy
中定义的split
版本)
但是,如果您只想模拟PipelineTypeStrategy
类的split
方法,则应该使用patch.object
装饰器模拟split
,而不是模拟整个类,因为它更干净。下面是一个示例:
class TestClass(unittest.TestCase):
@patch.object(TrainPipelineStrategy, 'split', return_value=None)
def test_split_data_main_split_data_call_count(self, mock_split):
test_args = ["", "--x=6"]
SplitData.main(args=test_args)
self.assertEqual(mock_split.call_count, 10)
我不能理解你的代码,但是猴子补丁比看起来更难。如果你把你想要模拟的东西作为参数传递给SUT,这会更容易。