代碼示例: pytest的簡單用法

demo文件目錄結構:html

 

demo_project/shell

|----------- __init__.pysocket

|----------- demo_test.pytcp

|----------- test_dbversion.pyspa

|----------- serialize.pycode

|----------- readme.txtorm

 

說明:htm

執行結果中包含print輸出。須要注意的是,在pytest中,print輸出不能直接使用print("xxx")的方式,由於會被pytest模塊納入capsys中。blog

由於對於pytest而言,print屬於sys.stdout,即輸出結果的類型。相似的,assert屬於sys.stderr的類型ip

具體的捕獲方式參考: https://docs.pytest.org/en/latest/capture.html#capturing-of-the-stdout-stderr-output

簡單貼一下官方示例代碼:

1 def test_myoutput(capsys):  # or use "capfd" for fd-level
2     print("hello")
3     sys.stderr.write("world\n")
4     captured = capsys.readouterr()
5     assert captured.out == "hello\n"
6     assert captured.err == "world\n"
7     print("next")
8     captured = capsys.readouterr()
9     assert captured.out == "next\n"

所以,在個人示例代碼中,使用with經過上下文管理,臨時關閉capsys,使得print的結果能被打印出來,具體效果參見最後的執行結果示例

 

各文件代碼以下:

__init__.py

 1 #

 

demo_test.py

 1 import pytest
 2 
 3 
 4 @pytest.mark.test
 5 def test_demo(capsys):
 6     with capsys.disabled():
 7         print(host_info_call_data())
 8     assert host_info_call_data() == '''hostname\tplatform\tcpus    cores\tsockets\tmemory(GB)
 9 rac1    Linux x86 64-bit    2    2    2    7.69    '''
10     return 1
11 
12 
13 def host_info_call_data():
14     data = '''hostname    platform    cpus    cores    sockets    memory(GB)
15 rac1    Linux x86 64-bit    2    2    2    7.69    '''
16     return data

 

 

test_dbversion.py

 1 import pytest
 2 
 3 from .serialize import serialize
 4 
 5 
 6 def script_data(num):
 7     data = {
 8         1: '''rdbms_version\tcluster_version
 9 11.2.0.4.0\t11.2.0.4.0
10 ''',
11         2: '''rdbms_version\tcluster_version
12 11.2.0.4.0    -
13 '''
14     }
15     return data[num]
16 
17 
18 @pytest.mark.test
19 @pytest.mark.shells
20 def test_all_version_exist(capsys):
21     data = script_data(1)
22     row, column, data = serialize(data)
23     with capsys.disabled():
24         print(data)
25     assert data == [['rdbms_version', 'cluster_version'],
26                     ['11.2.0.4.0', '11.2.0.4.0']]

 

 

serialize.py

 1 def serialize(data):
 2     if isinstance(data, list):
 3         data = str(data)
 4     data = data.split('\n')
 5     if data[-1] == '':
 6         data = data[:-1]
 7     data = [i.split('\t') for i in data if isinstance(i, str)]
 8     row = len(data)
 9     column = len(data[0])
10     return row, column, data

 

 

 

執行結果示例:

相關文章
相關標籤/搜索