
Commit 421a637

refactor: split [evaluator] into evaluator (cloud) and evaluator-local (HuggingFace local) with lazy imports
Parent: 3c164a9

4 files changed

Lines changed: 74 additions & 57 deletions


pyproject.toml

Lines changed: 8 additions & 2 deletions
@@ -1,6 +1,6 @@
 [project]
 name = "xfmr-zem"
-version = "0.3.3"
+version = "0.3.4"
 description = "Zem: Unified Data Pipeline Framework (ZenML + NeMo Curator + DataJuicer) for multi-domain processing"
 readme = "README.md"
 requires-python = ">=3.10,<3.13"
@@ -119,10 +119,16 @@ voice = [
 # ── LLM / Evaluation ──────────────────────────────────────────────────────────
 evaluator = [
     "opik>=1.10.9",
+    "litellm>=1.0.0",
+]
+
+# evaluator-local: run tests/evals on a local (HuggingFace) model
+evaluator-local = [
+    "opik>=1.10.9",
+    "litellm>=1.0.0",
     "transformers>=4.40.0",
     "torch>=2.1.0",
     "accelerate>=0.25.0",
-    "litellm>=1.0.0",
 ]
 
 # ── Web UI ─────────────────────────────────────────────────────────────────────
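
With the split, the cloud-only judge stack no longer pulls in torch and transformers. The two install paths, matching the hint baked into the ImportError messages below:

pip install 'xfmr-zem[evaluator]'        # cloud judging: opik + litellm only
pip install 'xfmr-zem[evaluator-local]'  # adds transformers, torch, accelerate for local HF models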

src/xfmr_zem/servers/evaluator/factory/eval_engines/local.py

Lines changed: 39 additions & 44 deletions
@@ -1,41 +1,53 @@
-import torch
 import json
 import re
 from typing import Any, Dict, List, Optional
-from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
 from opik import track
 from pydantic import BaseModel
 from loguru import logger
 import sys
+
 logger.remove()
 logger.add(sys.stderr, level="INFO")
-# Mock OpikBaseModel
 
 from opik.evaluation.models import OpikBaseModel
 
+
 class OpikHFModel(OpikBaseModel):
     def __init__(self, model_id: str, **kwargs):
+        # Lazy import: only imported when actually needed (evaluator-local extra)
+        try:
+            import torch
+            from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline as hf_pipeline
+            self._torch = torch
+            self._AutoTokenizer = AutoTokenizer
+            self._AutoModelForCausalLM = AutoModelForCausalLM
+            self._hf_pipeline = hf_pipeline
+        except ImportError:
+            raise ImportError(
+                "Missing dependencies for the local model. "
+                "Install with: pip install 'xfmr-zem[evaluator-local]'"
+            )
+
         super().__init__(model_name=model_id)
         self.model_id = model_id
         self.max_new_tokens = kwargs.get("max_new_tokens", 512)
-        # Low temperature so the model focuses on the grading logic instead of getting creative
-        self.temperature = kwargs.get("temperature", 0.01)
+        self.temperature = kwargs.get("temperature", 0.01)
         self.device = kwargs.get("device", "auto")
         self._load_model()
-
+
     @track(name="load_hf_model")
     def _load_model(self):
-        self.tokenizer = AutoTokenizer.from_pretrained(self.model_id)
+        self.tokenizer = self._AutoTokenizer.from_pretrained(self.model_id)
         if self.tokenizer.pad_token is None:
             self.tokenizer.pad_token = self.tokenizer.eos_token
-
-        self.model = AutoModelForCausalLM.from_pretrained(
+
+        self.model = self._AutoModelForCausalLM.from_pretrained(
             self.model_id,
-            torch_dtype=torch.float16,
+            torch_dtype=self._torch.float16,
             device_map=self.device
         )
-
-        self.generator = pipeline(
+
+        self.generator = self._hf_pipeline(
             "text-generation",
             model=self.model,
             tokenizer=self.tokenizer,
@@ -45,89 +57,72 @@ def _load_model(self):
     def _extract_json_string(self, text: str) -> str:
         """Extract the JSON portion from the output."""
         text = text.strip()
-        # Regex to find ```json { ... } ```
         match = re.search(r"```(?:json)?\s*(\{.*?\})\s*```", text, re.DOTALL)
-        if match: return match.group(1)
-
-        # Fallback: find { }
+        if match:
+            return match.group(1)
         start, end = text.find("{"), text.rfind("}")
-        if start != -1 and end != -1: return text[start : end + 1]
+        if start != -1 and end != -1:
+            return text[start: end + 1]
         return text
 
     @logger.catch(reraise=True)
     @track(name="hf_generate_string")
     def generate_string(self, input: str, response_format: Any = None, **kwargs: Any) -> str:
-        # 1. LOG INPUT
-        # Observe what the incoming input actually is (does it already include a schema?)
         logger.info(f"\n--- [GENERATE START] ---\n")
 
-        # 2. GENERATION
-        # No longer modify the input; just set up the run parameters
         params = {
             "max_new_tokens": kwargs.get("max_new_tokens", self.max_new_tokens),
-            "temperature": kwargs.get("temperature", 0.01 if response_format else self.temperature),  # Low temp when JSON is needed
+            "temperature": kwargs.get("temperature", 0.01 if response_format else self.temperature),
             "do_sample": True,
             "return_full_text": False
         }
 
         try:
-            # Call the model to generate text
             response = self.generator(input, **params)
             raw_text = response[0]["generated_text"].strip()
-
-            # If no special format is requested, return immediately
+
             if not response_format:
                 return raw_text
 
-            # 3. CHECK FORMAT (VALIDATION)
             logger.info("--- [VALIDATING JSON] ---")
-
-            # Step A: extract the JSON from the text (strip markdown noise)
             json_str = self._extract_json_string(raw_text)
-
-            # Step B: parse & validate
-            data = json.loads(json_str)  # Try parsing as plain JSON
-
-            # With a Pydantic model, validate the types strictly
+            data = json.loads(json_str)
+
             if isinstance(response_format, type) and issubclass(response_format, BaseModel):
                 logger.info(f"Validating against Pydantic Model: {response_format.__name__}")
                 validated_obj = response_format.model_validate(data)
-
-                # Success!
                 final_json = validated_obj.model_dump_json()
                 logger.info("VALIDATION SUCCESS ✅")
                 return final_json
-
-            # If it is just an ordinary dict schema
+
             logger.info("VALIDATION SUCCESS (Dict) ✅")
             return json.dumps(data, ensure_ascii=False)
 
         except json.JSONDecodeError as e:
             logger.error(f"❌ JSON PARSE ERROR: {e}\nBad String: {json_str}")
-            # Return the error as JSON so Opik records it instead of crashing the program
             return f'{{"error": "JSONDecodeError", "details": "{str(e)}", "raw_output": "{raw_text}"}}'
-
+
         except Exception as e:
             logger.error(f"❌ GENERATION/VALIDATION ERROR: {e}")
             return f'{{"error": "RuntimeError", "details": "{str(e)}"}}'
 
     @track(name="hf_generate_provider_response")
     def generate_provider_response(self, messages: List[Dict[str, Any]], **kwargs: Any) -> Any:
-        # Flatten the messages into a basic prompt string
         prompt = "\n".join([f"{m.get('role','').title()}: {m.get('content','')}" for m in messages])
         prompt += "\nAssistant:"
-
+
         generated_text = self.generate_string(
-            prompt,
-            response_format=kwargs.pop("response_format", None),
+            prompt,
+            response_format=kwargs.pop("response_format", None),
             **kwargs
         )
-
+
         return {
             "choices": [{"message": {"role": "assistant", "content": generated_text}}],
             "model": self.model_id
         }
 
+
 class OpikLocalFactory:
     @staticmethod
     def create_model(provider: str, model_id: str, **kwargs) -> Any:
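
For orientation, a minimal usage sketch of OpikHFModel as defined above; the model id and ScoreSchema are hypothetical stand-ins, and it assumes the evaluator-local extra is installed:

from pydantic import BaseModel
from xfmr_zem.servers.evaluator.factory.eval_engines.local import OpikHFModel

class ScoreSchema(BaseModel):  # hypothetical judge-output schema
    score: float
    reason: str

judge = OpikHFModel("Qwen/Qwen2.5-1.5B-Instruct", max_new_tokens=256)  # any causal LM id
raw = judge.generate_string("Say hi.")  # no response_format: raw text comes back
graded = judge.generate_string(
    "Rate this answer from 0 to 1 and explain; reply as JSON with keys score and reason.",
    response_format=ScoreSchema,  # triggers extract -> json.loads -> Pydantic validation
)
print(graded)  # validated JSON string, or an {"error": ...} payload on failure

Because failures are returned as JSON payloads rather than raised, Opik traces record them and an evaluation run keeps going instead of crashing.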

src/xfmr_zem/servers/evaluator/factory/models.py

Lines changed: 16 additions & 6 deletions
@@ -1,10 +1,19 @@
 from typing import Any, Optional
-from transformers import AutoModelForCausalLM, AutoTokenizer
-import torch
 import opik
 
+
 class HuggingFaceLM:
     def __init__(self, model_id: str):
+        # Lazy import: only imported when actually needed (evaluator-local extra)
+        try:
+            from transformers import AutoModelForCausalLM, AutoTokenizer
+            import torch
+        except ImportError:
+            raise ImportError(
+                "Missing dependencies for the local model. "
+                "Install with: pip install 'xfmr-zem[evaluator-local]'"
+            )
+
         self.model_id = model_id
         print(f"Loading HuggingFace Model: {model_id}")
         self.model = AutoModelForCausalLM.from_pretrained(
@@ -13,7 +22,7 @@ def __init__(self, model_id: str):
             device_map="auto"
         )
         self.tokenizer = AutoTokenizer.from_pretrained(
-            model_id, 
+            model_id,
             torch_dtype="auto",
             device_map="auto"
         )
@@ -31,27 +40,28 @@ def generate(self, input_text: str, system_prompt: Optional[str] = None) -> str:
             {"role": "system", "content": system_prompt},
             {"role": "user", "content": input_text}
         ]
-
+
         text = self.tokenizer.apply_chat_template(
             messages,
             tokenize=False,
             add_generation_prompt=True
         )
-
+
         model_inputs = self.tokenizer([text], return_tensors="pt").to(self.model.device)
 
         generated_ids = self.model.generate(
             **model_inputs,
             max_new_tokens=512
         )
-
+
         generated_ids = [
            output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
         ]
 
         response = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
         return response
 
+
 class ModelFactory:
     @staticmethod
     def get_model(engine_type: str, model_id: str) -> Any:
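
A minimal usage sketch of HuggingFaceLM; the model id is a hypothetical example of any chat-tuned causal LM with a chat template, and the evaluator-local extra must be installed:

from xfmr_zem.servers.evaluator.factory.models import HuggingFaceLM

lm = HuggingFaceLM("Qwen/Qwen2.5-0.5B-Instruct")  # hypothetical model id
answer = lm.generate(
    "Summarize what the evaluator-local extra adds, in one sentence.",
    system_prompt="You are a concise assistant.",
)
print(answer)  # decoded completion with the prompt tokens stripped off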

uv.lock

Lines changed: 11 additions & 5 deletions
Some generated files are not rendered by default.
