diff --git a/prompts/codegen.txt b/prompts/codegen.txt index 6ff20dd..cf95df9 100644 --- a/prompts/codegen.txt +++ b/prompts/codegen.txt @@ -1 +1 @@ -Write Python code that {task} +Write Python code that {input} diff --git a/prompts/custom.txt b/prompts/custom.txt new file mode 100644 index 0000000..ddd0a6d --- /dev/null +++ b/prompts/custom.txt @@ -0,0 +1 @@ +{input} \ No newline at end of file diff --git a/prompts/summarization.txt b/prompts/summarization.txt index 5598b9e..6526950 100644 --- a/prompts/summarization.txt +++ b/prompts/summarization.txt @@ -15,4 +15,4 @@ Example: [User supplies text] [Model outputs only the summary in Russian, no code or tables] -'{text}' +'{input}' diff --git a/prompts/translation.txt b/prompts/translation.txt index bb34761..b88f8aa 100644 --- a/prompts/translation.txt +++ b/prompts/translation.txt @@ -1 +1 @@ -Translate the following English text to Russian: '{text}' +Translate the following text to Russian: '{input}' diff --git a/src/benchmarks/base.py b/src/benchmarks/base.py index 81520be..f13c5f5 100644 --- a/src/benchmarks/base.py +++ b/src/benchmarks/base.py @@ -1,23 +1,91 @@ import logging import time -from typing import Dict, Any, List +import os +import re +from typing import Dict, Any, List, Optional from abc import ABC, abstractmethod from models.ollama_client import OllamaClient +from constants import TEST_SEPARATOR class Benchmark(ABC): """Базовый класс для всех бенчмарков.""" - def __init__(self, name: str): + def __init__(self, name: str, prompt_path: str, test_data_dir: str): """ Инициализация бенчмарка. Args: name: Название бенчмарка + prompt_path: Путь к файлу с промптом + test_data_dir: Путь к каталогу с тестовыми данными """ self.name = name + self.prompt_path = prompt_path + self.test_data_dir = test_data_dir self.logger = logging.getLogger(__name__) - @abstractmethod + # Загружаем и валидируем универсальный промпт + self.universal_prompt = self._load_and_validate_prompt(prompt_path) + + def _load_and_validate_prompt(self, prompt_path: str) -> str: + """ + Загрузка и валидация промпта. + + Args: + prompt_path: Путь к файлу с промптом + + Returns: + Содержимое промпта + + Raises: + ValueError: Если промпт не содержит ожидаемые параметры + """ + try: + with open(prompt_path, 'r', encoding='utf-8') as f: + prompt = f.read().strip() + + # Базовая валидация: промпт должен содержать хотя бы один параметр форматирования + if not re.search(r'\{[^}]+\}', prompt): + raise ValueError(f"Промпт {prompt_path} не содержит параметров форматирования") + + return prompt + except FileNotFoundError: + raise FileNotFoundError(f"Файл промпта не найден: {prompt_path}") + except Exception as e: + raise ValueError(f"Ошибка загрузки промпта {prompt_path}: {e}") + + def _load_test_data(self) -> List[Dict[str, Any]]: + """ + Загрузка тестовых данных. + + Returns: + Список тестовых случаев + """ + test_data = [] + + if not os.path.exists(self.test_data_dir): + self.logger.warning(f"Каталог с тестовыми данными не найден: {self.test_data_dir}") + return test_data + + for filename in os.listdir(self.test_data_dir): + if filename.endswith('.txt'): + filepath = os.path.join(self.test_data_dir, filename) + try: + with open(filepath, 'r', encoding='utf-8') as f: + content = f.read() + # Разделяем по разделителю + parts = content.split(TEST_SEPARATOR, 1) + if len(parts) == 2: + test_data.append({ + 'name': filename.replace('.txt', ''), + 'prompt': self.universal_prompt, + 'expected': parts[1] + }) + except Exception as e: + self.logger.error(f"Ошибка загрузки тестового случая {filename}: {e}") + + return test_data + def load_test_data(self) -> List[Dict[str, Any]]: """ Загрузка тестовых данных. @@ -25,12 +93,34 @@ class Benchmark(ABC): Returns: Список тестовых случаев """ - pass + test_data = [] + + if not os.path.exists(self.test_data_dir): + self.logger.warning(f"Каталог с тестовыми данными не найден: {self.test_data_dir}") + return test_data + + for filename in os.listdir(self.test_data_dir): + if filename.endswith('.txt'): + filepath = os.path.join(self.test_data_dir, filename) + try: + with open(filepath, 'r', encoding='utf-8') as f: + content = f.read() + # Разделяем по разделителю + parts = content.split(TEST_SEPARATOR, 1) + if len(parts) == 2: + test_data.append({ + 'name': filename.replace('.txt', ''), + 'prompt': self.universal_prompt.format(input=parts[0]), + 'expected': parts[1] + }) + except Exception as e: + self.logger.error(f"Ошибка загрузки тестового случая {filename}: {e}") + + return test_data - @abstractmethod def evaluate(self, model_response: str, expected: str) -> float: """ - Оценка качества ответа модели. + Оценка качества ответа модели (по умолчанию на основе F1-score). Args: model_response: Ответ от модели @@ -39,7 +129,23 @@ class Benchmark(ABC): Returns: Метрика качества (0-1) """ - pass + # Простая оценка на основе совпадения токенов + model_tokens = set(model_response.lower().split()) + expected_tokens = set(expected.lower().split()) + + if len(expected_tokens) == 0: + return 0.0 + + intersection = model_tokens.intersection(expected_tokens) + precision = len(intersection) / len(model_tokens) if model_tokens else 0.0 + recall = len(intersection) / len(expected_tokens) if expected_tokens else 0.0 + + # F1-score + if (precision + recall) == 0: + return 0.0 + f1 = 2 * (precision * recall) / (precision + recall) + + return round(f1, 3) def run(self, ollama_client: OllamaClient, model_name: str, context_size: int = 32000) -> Dict[str, Any]: """ diff --git a/src/benchmarks/codegen.py b/src/benchmarks/codegen.py index 21718d5..534db9c 100644 --- a/src/benchmarks/codegen.py +++ b/src/benchmarks/codegen.py @@ -9,61 +9,4 @@ class CodegenBenchmark(Benchmark): """Бенчмарк для тестирования генерации кода.""" def __init__(self): - super().__init__("codegen") - # Загружаем универсальный промпт - with open('prompts/codegen.txt', 'r', encoding='utf-8') as f: - self.universal_prompt = f.read().strip() - - def load_test_data(self) -> List[Dict[str, Any]]: - """ - Загрузка тестовых данных для генерации кода. - - Returns: - Список тестовых случаев - """ - test_data = [] - data_dir = "tests/codegen" - - for filename in os.listdir(data_dir): - if filename.endswith('.txt'): - with open(os.path.join(data_dir, filename), 'r', encoding='utf-8') as f: - content = f.read() - # Разделяем по разделителю - parts = content.split(TEST_SEPARATOR, 1) - if len(parts) == 2: - test_data.append({ - 'name': filename.replace('.txt', ''), - 'prompt': self.universal_prompt.format(task=parts[0]), - 'expected': parts[1] - }) - - return test_data - - def evaluate(self, model_response: str, expected: str) -> float: - """ - Оценка качества сгенерированного кода. - - Args: - model_response: Ответ от модели - expected: Ожидаемый ответ - - Returns: - Метрика качества (0-1) - """ - # Простая оценка на основе совпадения токенов - model_tokens = set(model_response.lower().split()) - expected_tokens = set(expected.lower().split()) - - if len(expected_tokens) == 0: - return 0.0 - - intersection = model_tokens.intersection(expected_tokens) - precision = len(intersection) / len(model_tokens) if model_tokens else 0.0 - recall = len(intersection) / len(expected_tokens) if expected_tokens else 0.0 - - # F1-score - if (precision + recall) == 0: - return 0.0 - f1 = 2 * (precision * recall) / (precision + recall) - - return round(f1, 3) + super().__init__("codegen", "prompts/codegen.txt", "tests/codegen") diff --git a/src/benchmarks/custom.py b/src/benchmarks/custom.py new file mode 100644 index 0000000..4e939e4 --- /dev/null +++ b/src/benchmarks/custom.py @@ -0,0 +1,8 @@ +import logging +from benchmarks.base import Benchmark + +class CustomBenchmark(Benchmark): + """Бенчмарк для универсальных тестов, где инструкция передается в самом тест-кейсе.""" + + def __init__(self): + super().__init__("custom", "prompts/custom.txt", "tests/custom") \ No newline at end of file diff --git a/src/benchmarks/summarization.py b/src/benchmarks/summarization.py index e395a9d..14b1254 100644 --- a/src/benchmarks/summarization.py +++ b/src/benchmarks/summarization.py @@ -9,61 +9,4 @@ class SummarizationBenchmark(Benchmark): """Бенчмарк для тестирования пересказов.""" def __init__(self): - super().__init__("summarization") - # Загружаем универсальный промпт - with open('prompts/summarization.txt', 'r', encoding='utf-8') as f: - self.universal_prompt = f.read().strip() - - def load_test_data(self) -> List[Dict[str, Any]]: - """ - Загрузка тестовых данных для пересказов. - - Returns: - Список тестовых случаев - """ - test_data = [] - data_dir = "tests/summarization" - - for filename in os.listdir(data_dir): - if filename.endswith('.txt'): - with open(os.path.join(data_dir, filename), 'r', encoding='utf-8') as f: - content = f.read() - # Разделяем по разделителю - parts = content.split(TEST_SEPARATOR, 1) - if len(parts) == 2: - test_data.append({ - 'name': filename.replace('.txt', ''), - 'prompt': self.universal_prompt.format(text=parts[0]), - 'expected': parts[1] - }) - - return test_data - - def evaluate(self, model_response: str, expected: str) -> float: - """ - Оценка качества пересказа. - - Args: - model_response: Ответ от модели - expected: Ожидаемый ответ - - Returns: - Метрика качества (0-1) - """ - # Простая оценка на основе совпадения токенов - model_tokens = set(model_response.lower().split()) - expected_tokens = set(expected.lower().split()) - - if len(expected_tokens) == 0: - return 0.0 - - intersection = model_tokens.intersection(expected_tokens) - precision = len(intersection) / len(model_tokens) if model_tokens else 0.0 - recall = len(intersection) / len(expected_tokens) if expected_tokens else 0.0 - - # F1-score - if (precision + recall) == 0: - return 0.0 - f1 = 2 * (precision * recall) / (precision + recall) - - return round(f1, 3) + super().__init__("summarization", "prompts/summarization.txt", "tests/summarization") diff --git a/src/benchmarks/translation.py b/src/benchmarks/translation.py index 834494b..67353d0 100644 --- a/src/benchmarks/translation.py +++ b/src/benchmarks/translation.py @@ -9,62 +9,4 @@ class TranslationBenchmark(Benchmark): """Бенчмарк для тестирования перевода.""" def __init__(self): - super().__init__("translation") - # Загружаем универсальный промпт - with open('prompts/translation.txt', 'r', encoding='utf-8') as f: - self.universal_prompt = f.read().strip() - - def load_test_data(self) -> List[Dict[str, Any]]: - """ - Загрузка тестовых данных для перевода. - - Returns: - Список тестовых случаев - """ - test_data = [] - data_dir = "tests/translation" - - for filename in os.listdir(data_dir): - if filename.endswith('.txt'): - with open(os.path.join(data_dir, filename), 'r', encoding='utf-8') as f: - content = f.read() - # Разделяем по разделителю - parts = content.split(TEST_SEPARATOR, 1) - if len(parts) == 2: - test_data.append({ - 'name': filename.replace('.txt', ''), - 'prompt': self.universal_prompt.format(text=parts[0]), - 'expected': parts[1] - }) - - return test_data - - def evaluate(self, model_response: str, expected: str) -> float: - """ - Оценка качества перевода. - - Args: - model_response: Ответ от модели - expected: Ожидаемый ответ - - Returns: - Метрика качества (0-1) - """ - # Простая оценка на основе совпадения токенов - # В реальном проекте можно использовать более сложные метрики - model_tokens = set(model_response.lower().split()) - expected_tokens = set(expected.lower().split()) - - if len(expected_tokens) == 0: - return 0.0 - - intersection = model_tokens.intersection(expected_tokens) - precision = len(intersection) / len(model_tokens) if model_tokens else 0.0 - recall = len(intersection) / len(expected_tokens) if expected_tokens else 0.0 - - # F1-score - if (precision + recall) == 0: - return 0.0 - f1 = 2 * (precision * recall) / (precision + recall) - - return round(f1, 3) + super().__init__("translation", "prompts/translation.txt", "tests/translation") diff --git a/src/main.py b/src/main.py index 5f6da71..cdcff7c 100644 --- a/src/main.py +++ b/src/main.py @@ -5,6 +5,7 @@ from models.ollama_client import OllamaClient from benchmarks.translation import TranslationBenchmark from benchmarks.summarization import SummarizationBenchmark from benchmarks.codegen import CodegenBenchmark +from benchmarks.custom import CustomBenchmark from utils.report import ReportGenerator def setup_logging(verbose: bool = False): @@ -34,7 +35,8 @@ def run_benchmarks(ollama_client: OllamaClient, model_name: str, benchmarks: Lis benchmark_classes = { 'translation': TranslationBenchmark, 'summarization': SummarizationBenchmark, - 'codegen': CodegenBenchmark + 'codegen': CodegenBenchmark, + 'custom': CustomBenchmark } results = [] @@ -58,7 +60,7 @@ def main(): parser.add_argument('-u', '--ollama-url', default='http://localhost:11434', help='URL подключения к Ollama серверу') parser.add_argument('-c', '--context-size', type=int, default=32000, help='Размер контекста для модели (по умолчанию 32000)') parser.add_argument('-b', '--benchmarks', nargs='+', default=['translation', 'summarization', 'codegen'], - help='Список бенчмарков для выполнения (translation, summarization, codegen)') + help='Список бенчмарков для выполнения (translation, summarization, codegen, custom)') parser.add_argument('-o', '--output', default='results', help='Директория для сохранения результатов') parser.add_argument('-v', '--verbose', action='store_true', help='Подробный режим вывода') diff --git a/tests/custom/10times_hello.txt b/tests/custom/10times_hello.txt new file mode 100644 index 0000000..d60b9b4 --- /dev/null +++ b/tests/custom/10times_hello.txt @@ -0,0 +1,3 @@ +Напиши только 10 раз слово "привет" через пробел +============== +привет привет привет привет привет привет привет привет привет привет