ai-benchmark/src/benchmarks/base.py

101 lines
3.0 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import logging
import time
from typing import Dict, Any, List
from abc import ABC, abstractmethod
from models.ollama_client import OllamaClient
class Benchmark(ABC):
"""Базовый класс для всех бенчмарков."""
def __init__(self, name: str):
"""
Инициализация бенчмарка.
Args:
name: Название бенчмарка
"""
self.name = name
self.logger = logging.getLogger(__name__)
@abstractmethod
def load_test_data(self) -> List[Dict[str, Any]]:
"""
Загрузка тестовых данных.
Returns:
Список тестовых случаев
"""
pass
@abstractmethod
def evaluate(self, model_response: str, expected: str) -> float:
"""
Оценка качества ответа модели.
Args:
model_response: Ответ от модели
expected: Ожидаемый ответ
Returns:
Метрика качества (0-1)
"""
pass
def run(self, ollama_client: OllamaClient, model_name: str) -> Dict[str, Any]:
"""
Запуск бенчмарка.
Args:
ollama_client: Клиент для работы с Ollama
model_name: Название модели
Returns:
Результаты бенчмарка
"""
test_cases = self.load_test_data()
results = []
for i, test_case in enumerate(test_cases, 1):
try:
self.logger.info(f"Running test case {i}/{len(test_cases)} for {self.name}")
# Замер времени
start_time = time.time()
# Получение ответа от модели
prompt = test_case['prompt']
model_response = ollama_client.generate(
model=model_name,
prompt=prompt,
options={'temperature': 0.7}
)
# Замер времени
latency = time.time() - start_time
# Оценка качества
score = self.evaluate(model_response, test_case['expected'])
results.append({
'test_case': test_case['name'],
'prompt': prompt,
'expected': test_case['expected'],
'model_response': model_response,
'score': score,
'latency': latency
})
except Exception as e:
self.logger.error(f"Error in test case {i}: {e}")
results.append({
'test_case': test_case['name'],
'error': str(e)
})
return {
'benchmark_name': self.name,
'total_tests': len(test_cases),
'successful_tests': len([r for r in results if 'score' in r]),
'results': results
}