refactor: standardize prompt parameter naming and enhance benchmark base class
- Changed all prompt parameters from '{text}' to '{input}' for consistency
- Enhanced Benchmark base class with prompt loading and validation
- Added test data loading functionality with proper error handling
- Improved initialization to accept prompt path and test data directory
- Added validation for prompt format and file existence
- Implemented structured test data loading from directory
```
The commit message follows conventional commit format with a clear title and descriptive body explaining the changes and their purpose.
This commit is contained in:
@@ -1 +1 @@
|
|||||||
Write Python code that {task}
|
Write Python code that {input}
|
||||||
|
|||||||
1
prompts/custom.txt
Normal file
1
prompts/custom.txt
Normal file
@@ -0,0 +1 @@
|
|||||||
|
{input}
|
||||||
@@ -15,4 +15,4 @@ Example:
|
|||||||
[User supplies text]
|
[User supplies text]
|
||||||
[Model outputs only the summary in Russian, no code or tables]
|
[Model outputs only the summary in Russian, no code or tables]
|
||||||
|
|
||||||
'{text}'
|
'{input}'
|
||||||
|
|||||||
@@ -1 +1 @@
|
|||||||
Translate the following English text to Russian: '{text}'
|
Translate the following text to Russian: '{input}'
|
||||||
|
|||||||
@@ -1,23 +1,91 @@
|
|||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
from typing import Dict, Any, List
|
import os
|
||||||
|
import re
|
||||||
|
from typing import Dict, Any, List, Optional
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from models.ollama_client import OllamaClient
|
from models.ollama_client import OllamaClient
|
||||||
|
from constants import TEST_SEPARATOR
|
||||||
|
|
||||||
class Benchmark(ABC):
|
class Benchmark(ABC):
|
||||||
"""Базовый класс для всех бенчмарков."""
|
"""Базовый класс для всех бенчмарков."""
|
||||||
|
|
||||||
def __init__(self, name: str):
|
def __init__(self, name: str, prompt_path: str, test_data_dir: str):
|
||||||
"""
|
"""
|
||||||
Инициализация бенчмарка.
|
Инициализация бенчмарка.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
name: Название бенчмарка
|
name: Название бенчмарка
|
||||||
|
prompt_path: Путь к файлу с промптом
|
||||||
|
test_data_dir: Путь к каталогу с тестовыми данными
|
||||||
"""
|
"""
|
||||||
self.name = name
|
self.name = name
|
||||||
|
self.prompt_path = prompt_path
|
||||||
|
self.test_data_dir = test_data_dir
|
||||||
self.logger = logging.getLogger(__name__)
|
self.logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@abstractmethod
|
# Загружаем и валидируем универсальный промпт
|
||||||
|
self.universal_prompt = self._load_and_validate_prompt(prompt_path)
|
||||||
|
|
||||||
|
def _load_and_validate_prompt(self, prompt_path: str) -> str:
|
||||||
|
"""
|
||||||
|
Загрузка и валидация промпта.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prompt_path: Путь к файлу с промптом
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Содержимое промпта
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: Если промпт не содержит ожидаемые параметры
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
with open(prompt_path, 'r', encoding='utf-8') as f:
|
||||||
|
prompt = f.read().strip()
|
||||||
|
|
||||||
|
# Базовая валидация: промпт должен содержать хотя бы один параметр форматирования
|
||||||
|
if not re.search(r'\{[^}]+\}', prompt):
|
||||||
|
raise ValueError(f"Промпт {prompt_path} не содержит параметров форматирования")
|
||||||
|
|
||||||
|
return prompt
|
||||||
|
except FileNotFoundError:
|
||||||
|
raise FileNotFoundError(f"Файл промпта не найден: {prompt_path}")
|
||||||
|
except Exception as e:
|
||||||
|
raise ValueError(f"Ошибка загрузки промпта {prompt_path}: {e}")
|
||||||
|
|
||||||
|
def _load_test_data(self) -> List[Dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
Загрузка тестовых данных.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Список тестовых случаев
|
||||||
|
"""
|
||||||
|
test_data = []
|
||||||
|
|
||||||
|
if not os.path.exists(self.test_data_dir):
|
||||||
|
self.logger.warning(f"Каталог с тестовыми данными не найден: {self.test_data_dir}")
|
||||||
|
return test_data
|
||||||
|
|
||||||
|
for filename in os.listdir(self.test_data_dir):
|
||||||
|
if filename.endswith('.txt'):
|
||||||
|
filepath = os.path.join(self.test_data_dir, filename)
|
||||||
|
try:
|
||||||
|
with open(filepath, 'r', encoding='utf-8') as f:
|
||||||
|
content = f.read()
|
||||||
|
# Разделяем по разделителю
|
||||||
|
parts = content.split(TEST_SEPARATOR, 1)
|
||||||
|
if len(parts) == 2:
|
||||||
|
test_data.append({
|
||||||
|
'name': filename.replace('.txt', ''),
|
||||||
|
'prompt': self.universal_prompt,
|
||||||
|
'expected': parts[1]
|
||||||
|
})
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.error(f"Ошибка загрузки тестового случая {filename}: {e}")
|
||||||
|
|
||||||
|
return test_data
|
||||||
|
|
||||||
def load_test_data(self) -> List[Dict[str, Any]]:
|
def load_test_data(self) -> List[Dict[str, Any]]:
|
||||||
"""
|
"""
|
||||||
Загрузка тестовых данных.
|
Загрузка тестовых данных.
|
||||||
@@ -25,12 +93,34 @@ class Benchmark(ABC):
|
|||||||
Returns:
|
Returns:
|
||||||
Список тестовых случаев
|
Список тестовых случаев
|
||||||
"""
|
"""
|
||||||
pass
|
test_data = []
|
||||||
|
|
||||||
|
if not os.path.exists(self.test_data_dir):
|
||||||
|
self.logger.warning(f"Каталог с тестовыми данными не найден: {self.test_data_dir}")
|
||||||
|
return test_data
|
||||||
|
|
||||||
|
for filename in os.listdir(self.test_data_dir):
|
||||||
|
if filename.endswith('.txt'):
|
||||||
|
filepath = os.path.join(self.test_data_dir, filename)
|
||||||
|
try:
|
||||||
|
with open(filepath, 'r', encoding='utf-8') as f:
|
||||||
|
content = f.read()
|
||||||
|
# Разделяем по разделителю
|
||||||
|
parts = content.split(TEST_SEPARATOR, 1)
|
||||||
|
if len(parts) == 2:
|
||||||
|
test_data.append({
|
||||||
|
'name': filename.replace('.txt', ''),
|
||||||
|
'prompt': self.universal_prompt.format(input=parts[0]),
|
||||||
|
'expected': parts[1]
|
||||||
|
})
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.error(f"Ошибка загрузки тестового случая {filename}: {e}")
|
||||||
|
|
||||||
|
return test_data
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def evaluate(self, model_response: str, expected: str) -> float:
|
def evaluate(self, model_response: str, expected: str) -> float:
|
||||||
"""
|
"""
|
||||||
Оценка качества ответа модели.
|
Оценка качества ответа модели (по умолчанию на основе F1-score).
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
model_response: Ответ от модели
|
model_response: Ответ от модели
|
||||||
@@ -39,7 +129,23 @@ class Benchmark(ABC):
|
|||||||
Returns:
|
Returns:
|
||||||
Метрика качества (0-1)
|
Метрика качества (0-1)
|
||||||
"""
|
"""
|
||||||
pass
|
# Простая оценка на основе совпадения токенов
|
||||||
|
model_tokens = set(model_response.lower().split())
|
||||||
|
expected_tokens = set(expected.lower().split())
|
||||||
|
|
||||||
|
if len(expected_tokens) == 0:
|
||||||
|
return 0.0
|
||||||
|
|
||||||
|
intersection = model_tokens.intersection(expected_tokens)
|
||||||
|
precision = len(intersection) / len(model_tokens) if model_tokens else 0.0
|
||||||
|
recall = len(intersection) / len(expected_tokens) if expected_tokens else 0.0
|
||||||
|
|
||||||
|
# F1-score
|
||||||
|
if (precision + recall) == 0:
|
||||||
|
return 0.0
|
||||||
|
f1 = 2 * (precision * recall) / (precision + recall)
|
||||||
|
|
||||||
|
return round(f1, 3)
|
||||||
|
|
||||||
def run(self, ollama_client: OllamaClient, model_name: str, context_size: int = 32000) -> Dict[str, Any]:
|
def run(self, ollama_client: OllamaClient, model_name: str, context_size: int = 32000) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -9,61 +9,4 @@ class CodegenBenchmark(Benchmark):
|
|||||||
"""Бенчмарк для тестирования генерации кода."""
|
"""Бенчмарк для тестирования генерации кода."""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__("codegen")
|
super().__init__("codegen", "prompts/codegen.txt", "tests/codegen")
|
||||||
# Загружаем универсальный промпт
|
|
||||||
with open('prompts/codegen.txt', 'r', encoding='utf-8') as f:
|
|
||||||
self.universal_prompt = f.read().strip()
|
|
||||||
|
|
||||||
def load_test_data(self) -> List[Dict[str, Any]]:
|
|
||||||
"""
|
|
||||||
Загрузка тестовых данных для генерации кода.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Список тестовых случаев
|
|
||||||
"""
|
|
||||||
test_data = []
|
|
||||||
data_dir = "tests/codegen"
|
|
||||||
|
|
||||||
for filename in os.listdir(data_dir):
|
|
||||||
if filename.endswith('.txt'):
|
|
||||||
with open(os.path.join(data_dir, filename), 'r', encoding='utf-8') as f:
|
|
||||||
content = f.read()
|
|
||||||
# Разделяем по разделителю
|
|
||||||
parts = content.split(TEST_SEPARATOR, 1)
|
|
||||||
if len(parts) == 2:
|
|
||||||
test_data.append({
|
|
||||||
'name': filename.replace('.txt', ''),
|
|
||||||
'prompt': self.universal_prompt.format(task=parts[0]),
|
|
||||||
'expected': parts[1]
|
|
||||||
})
|
|
||||||
|
|
||||||
return test_data
|
|
||||||
|
|
||||||
def evaluate(self, model_response: str, expected: str) -> float:
|
|
||||||
"""
|
|
||||||
Оценка качества сгенерированного кода.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
model_response: Ответ от модели
|
|
||||||
expected: Ожидаемый ответ
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Метрика качества (0-1)
|
|
||||||
"""
|
|
||||||
# Простая оценка на основе совпадения токенов
|
|
||||||
model_tokens = set(model_response.lower().split())
|
|
||||||
expected_tokens = set(expected.lower().split())
|
|
||||||
|
|
||||||
if len(expected_tokens) == 0:
|
|
||||||
return 0.0
|
|
||||||
|
|
||||||
intersection = model_tokens.intersection(expected_tokens)
|
|
||||||
precision = len(intersection) / len(model_tokens) if model_tokens else 0.0
|
|
||||||
recall = len(intersection) / len(expected_tokens) if expected_tokens else 0.0
|
|
||||||
|
|
||||||
# F1-score
|
|
||||||
if (precision + recall) == 0:
|
|
||||||
return 0.0
|
|
||||||
f1 = 2 * (precision * recall) / (precision + recall)
|
|
||||||
|
|
||||||
return round(f1, 3)
|
|
||||||
|
|||||||
8
src/benchmarks/custom.py
Normal file
8
src/benchmarks/custom.py
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
import logging
|
||||||
|
from benchmarks.base import Benchmark
|
||||||
|
|
||||||
|
class CustomBenchmark(Benchmark):
|
||||||
|
"""Бенчмарк для универсальных тестов, где инструкция передается в самом тест-кейсе."""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__("custom", "prompts/custom.txt", "tests/custom")
|
||||||
@@ -9,61 +9,4 @@ class SummarizationBenchmark(Benchmark):
|
|||||||
"""Бенчмарк для тестирования пересказов."""
|
"""Бенчмарк для тестирования пересказов."""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__("summarization")
|
super().__init__("summarization", "prompts/summarization.txt", "tests/summarization")
|
||||||
# Загружаем универсальный промпт
|
|
||||||
with open('prompts/summarization.txt', 'r', encoding='utf-8') as f:
|
|
||||||
self.universal_prompt = f.read().strip()
|
|
||||||
|
|
||||||
def load_test_data(self) -> List[Dict[str, Any]]:
|
|
||||||
"""
|
|
||||||
Загрузка тестовых данных для пересказов.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Список тестовых случаев
|
|
||||||
"""
|
|
||||||
test_data = []
|
|
||||||
data_dir = "tests/summarization"
|
|
||||||
|
|
||||||
for filename in os.listdir(data_dir):
|
|
||||||
if filename.endswith('.txt'):
|
|
||||||
with open(os.path.join(data_dir, filename), 'r', encoding='utf-8') as f:
|
|
||||||
content = f.read()
|
|
||||||
# Разделяем по разделителю
|
|
||||||
parts = content.split(TEST_SEPARATOR, 1)
|
|
||||||
if len(parts) == 2:
|
|
||||||
test_data.append({
|
|
||||||
'name': filename.replace('.txt', ''),
|
|
||||||
'prompt': self.universal_prompt.format(text=parts[0]),
|
|
||||||
'expected': parts[1]
|
|
||||||
})
|
|
||||||
|
|
||||||
return test_data
|
|
||||||
|
|
||||||
def evaluate(self, model_response: str, expected: str) -> float:
|
|
||||||
"""
|
|
||||||
Оценка качества пересказа.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
model_response: Ответ от модели
|
|
||||||
expected: Ожидаемый ответ
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Метрика качества (0-1)
|
|
||||||
"""
|
|
||||||
# Простая оценка на основе совпадения токенов
|
|
||||||
model_tokens = set(model_response.lower().split())
|
|
||||||
expected_tokens = set(expected.lower().split())
|
|
||||||
|
|
||||||
if len(expected_tokens) == 0:
|
|
||||||
return 0.0
|
|
||||||
|
|
||||||
intersection = model_tokens.intersection(expected_tokens)
|
|
||||||
precision = len(intersection) / len(model_tokens) if model_tokens else 0.0
|
|
||||||
recall = len(intersection) / len(expected_tokens) if expected_tokens else 0.0
|
|
||||||
|
|
||||||
# F1-score
|
|
||||||
if (precision + recall) == 0:
|
|
||||||
return 0.0
|
|
||||||
f1 = 2 * (precision * recall) / (precision + recall)
|
|
||||||
|
|
||||||
return round(f1, 3)
|
|
||||||
|
|||||||
@@ -9,62 +9,4 @@ class TranslationBenchmark(Benchmark):
|
|||||||
"""Бенчмарк для тестирования перевода."""
|
"""Бенчмарк для тестирования перевода."""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__("translation")
|
super().__init__("translation", "prompts/translation.txt", "tests/translation")
|
||||||
# Загружаем универсальный промпт
|
|
||||||
with open('prompts/translation.txt', 'r', encoding='utf-8') as f:
|
|
||||||
self.universal_prompt = f.read().strip()
|
|
||||||
|
|
||||||
def load_test_data(self) -> List[Dict[str, Any]]:
|
|
||||||
"""
|
|
||||||
Загрузка тестовых данных для перевода.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Список тестовых случаев
|
|
||||||
"""
|
|
||||||
test_data = []
|
|
||||||
data_dir = "tests/translation"
|
|
||||||
|
|
||||||
for filename in os.listdir(data_dir):
|
|
||||||
if filename.endswith('.txt'):
|
|
||||||
with open(os.path.join(data_dir, filename), 'r', encoding='utf-8') as f:
|
|
||||||
content = f.read()
|
|
||||||
# Разделяем по разделителю
|
|
||||||
parts = content.split(TEST_SEPARATOR, 1)
|
|
||||||
if len(parts) == 2:
|
|
||||||
test_data.append({
|
|
||||||
'name': filename.replace('.txt', ''),
|
|
||||||
'prompt': self.universal_prompt.format(text=parts[0]),
|
|
||||||
'expected': parts[1]
|
|
||||||
})
|
|
||||||
|
|
||||||
return test_data
|
|
||||||
|
|
||||||
def evaluate(self, model_response: str, expected: str) -> float:
|
|
||||||
"""
|
|
||||||
Оценка качества перевода.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
model_response: Ответ от модели
|
|
||||||
expected: Ожидаемый ответ
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Метрика качества (0-1)
|
|
||||||
"""
|
|
||||||
# Простая оценка на основе совпадения токенов
|
|
||||||
# В реальном проекте можно использовать более сложные метрики
|
|
||||||
model_tokens = set(model_response.lower().split())
|
|
||||||
expected_tokens = set(expected.lower().split())
|
|
||||||
|
|
||||||
if len(expected_tokens) == 0:
|
|
||||||
return 0.0
|
|
||||||
|
|
||||||
intersection = model_tokens.intersection(expected_tokens)
|
|
||||||
precision = len(intersection) / len(model_tokens) if model_tokens else 0.0
|
|
||||||
recall = len(intersection) / len(expected_tokens) if expected_tokens else 0.0
|
|
||||||
|
|
||||||
# F1-score
|
|
||||||
if (precision + recall) == 0:
|
|
||||||
return 0.0
|
|
||||||
f1 = 2 * (precision * recall) / (precision + recall)
|
|
||||||
|
|
||||||
return round(f1, 3)
|
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ from models.ollama_client import OllamaClient
|
|||||||
from benchmarks.translation import TranslationBenchmark
|
from benchmarks.translation import TranslationBenchmark
|
||||||
from benchmarks.summarization import SummarizationBenchmark
|
from benchmarks.summarization import SummarizationBenchmark
|
||||||
from benchmarks.codegen import CodegenBenchmark
|
from benchmarks.codegen import CodegenBenchmark
|
||||||
|
from benchmarks.custom import CustomBenchmark
|
||||||
from utils.report import ReportGenerator
|
from utils.report import ReportGenerator
|
||||||
|
|
||||||
def setup_logging(verbose: bool = False):
|
def setup_logging(verbose: bool = False):
|
||||||
@@ -34,7 +35,8 @@ def run_benchmarks(ollama_client: OllamaClient, model_name: str, benchmarks: Lis
|
|||||||
benchmark_classes = {
|
benchmark_classes = {
|
||||||
'translation': TranslationBenchmark,
|
'translation': TranslationBenchmark,
|
||||||
'summarization': SummarizationBenchmark,
|
'summarization': SummarizationBenchmark,
|
||||||
'codegen': CodegenBenchmark
|
'codegen': CodegenBenchmark,
|
||||||
|
'custom': CustomBenchmark
|
||||||
}
|
}
|
||||||
|
|
||||||
results = []
|
results = []
|
||||||
@@ -58,7 +60,7 @@ def main():
|
|||||||
parser.add_argument('-u', '--ollama-url', default='http://localhost:11434', help='URL подключения к Ollama серверу')
|
parser.add_argument('-u', '--ollama-url', default='http://localhost:11434', help='URL подключения к Ollama серверу')
|
||||||
parser.add_argument('-c', '--context-size', type=int, default=32000, help='Размер контекста для модели (по умолчанию 32000)')
|
parser.add_argument('-c', '--context-size', type=int, default=32000, help='Размер контекста для модели (по умолчанию 32000)')
|
||||||
parser.add_argument('-b', '--benchmarks', nargs='+', default=['translation', 'summarization', 'codegen'],
|
parser.add_argument('-b', '--benchmarks', nargs='+', default=['translation', 'summarization', 'codegen'],
|
||||||
help='Список бенчмарков для выполнения (translation, summarization, codegen)')
|
help='Список бенчмарков для выполнения (translation, summarization, codegen, custom)')
|
||||||
parser.add_argument('-o', '--output', default='results', help='Директория для сохранения результатов')
|
parser.add_argument('-o', '--output', default='results', help='Директория для сохранения результатов')
|
||||||
parser.add_argument('-v', '--verbose', action='store_true', help='Подробный режим вывода')
|
parser.add_argument('-v', '--verbose', action='store_true', help='Подробный режим вывода')
|
||||||
|
|
||||||
|
|||||||
3
tests/custom/10times_hello.txt
Normal file
3
tests/custom/10times_hello.txt
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
Напиши только 10 раз слово "привет" через пробел
|
||||||
|
==============
|
||||||
|
привет привет привет привет привет привет привет привет привет привет
|
||||||
Reference in New Issue
Block a user