test_trunks_service.py 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102
  1. import pytest
  2. from service.trunks_service import TrunksService
  3. from model.trunks_model import Trunks
  4. from sqlalchemy.exc import IntegrityError
  5. @pytest.fixture(scope="module")
  6. def trunks_service():
  7. return TrunksService()
  8. @pytest.fixture
  9. def test_trunk_data():
  10. return {
  11. "content": """本切片内容来自:《急诊与灾难医学(第4版)》 ,第二章急性发热 ,第一节概述
  12. 第二章 急性发热
  13. 发热(fever)是机体在内、外致热原作用下,或由于各种病因导致体温调节中枢功能障碍,而出现 的以体温升高超出正常范围为主要表现的临床症状。通常体表温度≥37.3℃可诊断为发热。热程在 2 周以内者为急性发热。
  14. 第一节 | 概 述
  15. 急性发热可分为感染性发热和非感染性发热。感染性发热较为多见,常见的病原体包括细菌、病 毒、衣原体、支原体、立克次体、螺旋体、真菌、原虫、蠕虫等,涉及的部位可由浅表组织到深部组织。非 感染性发热的病因包括结缔组织病、超敏反应性疾病、过敏性疾病、肿瘤性疾病、内分泌和代谢性疾 病、中枢神经系统疾病、散热障碍、创伤、烧伤、手术以及其他不明原因。
  16. 急性发热起病急骤,常有受凉、劳累或进不洁饮食史。发热及其伴随症状和体征多种多样,详细 询问病史,连续观察热程、热型,仔细查体,完善实验室检查,是发热诊断和鉴别诊断的重要依据。而 体温的高低并不是判断疾病危重程度的唯一依据。
  17. 发热是一个病因较为复杂的临床症状,而不是一种疾病,是机体对致病因素的一种全身性的代偿 反应。发热的治疗包括正确使用物理降温和解热药物,合理应用抗生素以及糖皮质激素。""",
  18. "file_path": "test_path.pdf",
  19. "type": "default"
  20. }
  21. class TestTrunksServiceCRUD:
  22. def test_create_and_get_trunk(self, trunks_service, test_trunk_data):
  23. # 测试创建和查询
  24. created = trunks_service.create_trunk(test_trunk_data)
  25. assert created.id is not None
  26. def test_update_trunk(self, trunks_service, test_trunk_data):
  27. trunk = trunks_service.create_trunk(test_trunk_data)
  28. updated = trunks_service.update_trunk(trunk.id, {"content": "更新内容"})
  29. assert updated.content == "更新内容"
  30. def test_delete_trunk(self, trunks_service, test_trunk_data):
  31. trunk = trunks_service.create_trunk(test_trunk_data)
  32. assert trunks_service.delete_trunk(trunk.id)
  33. assert trunks_service.get_trunk_by_id(trunk.id) is None
  34. class TestSearchOperations:
  35. def test_vector_search(self, trunks_service, test_trunk_data):
  36. trunks_service.create_trunk(test_trunk_data)
  37. results = trunks_service.search_by_vector("各种致病因素作用引起的有效循环血容量急剧减少")
  38. print("搜索结果:", results)
  39. assert len(results) > 0
  40. # def test_fulltext_search(self, trunks_service, test_trunk_data):
  41. # trunks_service.create_trunk(test_trunk_data)
  42. # results = trunks_service.fulltext_search("测试")
  43. # assert len(results) > 0
  44. class TestExceptionCases:
  45. def test_duplicate_id(self, trunks_service, test_trunk_data):
  46. with pytest.raises(IntegrityError):
  47. trunk1 = trunks_service.create_trunk(test_trunk_data)
  48. test_trunk_data["id"] = trunk1.id
  49. trunks_service.create_trunk(test_trunk_data)
  50. def test_invalid_vector_dimension(self, trunks_service, test_trunk_data):
  51. with pytest.raises(ValueError):
  52. invalid_data = test_trunk_data.copy()
  53. invalid_data["embedding"] = [0.1]*100
  54. trunks_service.create_trunk(invalid_data)
  55. @pytest.fixture
  56. def trunk_factory():
  57. class TrunkFactory:
  58. @staticmethod
  59. def create(**overrides):
  60. defaults = {
  61. "content": "工厂内容",
  62. "file_path": "factory_path.pdf",
  63. "type": "default"
  64. }
  65. return {**defaults, **overrides}
  66. return TrunkFactory()
  67. class TestBatchCreateFromDirectory:
  68. def test_batch_create_from_directory(self, trunks_service, test_data_dir):
  69. # 使用现有目录路径
  70. base_path = Path(r'E:\project\vscode\files')
  71. # 遍历目录并创建trunk
  72. created_ids = []
  73. for txt_path in base_path.glob('**/*_split_*.txt'):
  74. relative_path = txt_path.relative_to(base_path.parent.parent)
  75. with open(txt_path, 'r', encoding='utf-8') as f:
  76. trunk_data = {
  77. "content": f.read(),
  78. "file_path": str(relative_path).replace('\\', '/')
  79. }
  80. trunk = trunks_service.create_trunk(trunk_data)
  81. created_ids.append(trunk.id)
  82. # 验证数据库记录
  83. for trunk_id in created_ids:
  84. db_trunk = trunks_service.get_trunk_by_id(trunk_id)
  85. assert db_trunk is not None
  86. assert ".txt" in db_trunk.file_path
  87. assert "_split_" in db_trunk.file_path
  88. assert len(db_trunk.content) > 0