import json
import re
import sys
from typing import Any, Dict, List, Optional

from opik import track
from pydantic import BaseModel
from loguru import logger

logger.remove()
logger.add(sys.stderr, level="INFO")

from opik.evaluation.models import OpikBaseModel


class OpikHFModel(OpikBaseModel):
    def __init__(self, model_id: str, **kwargs):
        # Lazy import: only pull these in when the local evaluator is actually used
        # (they ship behind the optional "evaluator-local" extra)
        try:
            import torch
            from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline as hf_pipeline
            self._torch = torch
            self._AutoTokenizer = AutoTokenizer
            self._AutoModelForCausalLM = AutoModelForCausalLM
            self._hf_pipeline = hf_pipeline
        except ImportError:
            raise ImportError(
                "Missing dependencies for the local model. "
                "Install them with: pip install 'xfmr-zem[evaluator-local]'"
            )

        super().__init__(model_name=model_id)
        self.model_id = model_id
        self.max_new_tokens = kwargs.get("max_new_tokens", 512)
        # Keep the temperature low so the judge sticks to the scoring logic instead of free-form generation
        self.temperature = kwargs.get("temperature", 0.01)
        self.device = kwargs.get("device", "auto")
        self._load_model()
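
    # Illustrative construction (the model id and kwargs below are placeholders, not project defaults):
    #   judge = OpikHFModel("Qwen/Qwen2.5-1.5B-Instruct", max_new_tokens=256, device="cuda:0")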

    @track(name="load_hf_model")
    def _load_model(self):
        self.tokenizer = self._AutoTokenizer.from_pretrained(self.model_id)
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        self.model = self._AutoModelForCausalLM.from_pretrained(
            self.model_id,
            torch_dtype=self._torch.float16,
            device_map=self.device
        )

        self.generator = self._hf_pipeline(
            "text-generation",
            model=self.model,
            tokenizer=self.tokenizer,
            # ... (remaining pipeline arguments elided in this excerpt)
        )
    def _extract_json_string(self, text: str) -> str:
        """Extract the JSON portion from the model output."""
        text = text.strip()
        # Look for a fenced ```json { ... } ``` block first
        match = re.search(r"```(?:json)?\s*(\{.*?\})\s*```", text, re.DOTALL)
        if match:
            return match.group(1)
        # Fallback: take the outermost { ... } span
        start, end = text.find("{"), text.rfind("}")
        if start != -1 and end != -1:
            return text[start:end + 1]
        return text
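
    # Illustrative behaviour of _extract_json_string (example strings, not repo data):
    #   '```json\n{"score": 1}\n```'          -> '{"score": 1}'
    #   'Sure, here it is: {"score": 1} Bye'  -> '{"score": 1}'
    #   'no braces at all'                    -> 'no braces at all'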

    @logger.catch(reraise=True)
    @track(name="hf_generate_string")
    def generate_string(self, input: str, response_format: Any = None, **kwargs: Any) -> str:
        # Log the call boundary so the raw input (with or without a schema) shows up in the trace
        logger.info(f"\n--- [GENERATE START] ---\n")

        # Do not rewrite the input here; only set up the generation parameters
        params = {
            "max_new_tokens": kwargs.get("max_new_tokens", self.max_new_tokens),
            # Near-greedy temperature whenever structured JSON output is requested
            "temperature": kwargs.get("temperature", 0.01 if response_format else self.temperature),
            "do_sample": True,
            "return_full_text": False
        }

        try:
            # Ask the model to generate text
            response = self.generator(input, **params)
            raw_text = response[0]["generated_text"].strip()

            # No structured format requested: return the raw text as-is
            if not response_format:
                return raw_text

            logger.info("--- [VALIDATING JSON] ---")
            # Step A: extract the JSON payload from the text (strip markdown noise)
            json_str = self._extract_json_string(raw_text)
            # Step B: parse, then validate
            data = json.loads(json_str)

            # With a Pydantic model, validate the field types strictly
            if isinstance(response_format, type) and issubclass(response_format, BaseModel):
                logger.info(f"Validating against Pydantic Model: {response_format.__name__}")
                validated_obj = response_format.model_validate(data)
                final_json = validated_obj.model_dump_json()
                logger.info("VALIDATION SUCCESS ✅")
                return final_json

            # Plain dict schema: return the parsed JSON unchanged
            logger.info("VALIDATION SUCCESS (Dict) ✅")
            return json.dumps(data, ensure_ascii=False)

        except json.JSONDecodeError as e:
            logger.error(f"❌ JSON PARSE ERROR: {e}\nBad String: {json_str}")
            # Return the error as JSON so Opik records it instead of crashing the run
            return f'{{"error": "JSONDecodeError", "details": "{str(e)}", "raw_output": "{raw_text}"}}'

        except Exception as e:
            logger.error(f"❌ GENERATION/VALIDATION ERROR: {e}")
            return f'{{"error": "RuntimeError", "details": "{str(e)}"}}'
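
    # Usage sketch (illustrative; the model id, prompt, and schema below are placeholders,
    # not values defined in this repo):
    #   class ScoreResult(BaseModel):
    #       score: int
    #       reason: str
    #   judge = OpikHFModel("Qwen/Qwen2.5-1.5B-Instruct")
    #   raw = judge.generate_string("Rate the answer from 0-10. Reply with JSON.",
    #                               response_format=ScoreResult)
    #   # raw is a validated JSON string such as '{"score": 8, "reason": "..."}',
    #   # or an {"error": ...} JSON string if parsing/validation failed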

    @track(name="hf_generate_provider_response")
    def generate_provider_response(self, messages: List[Dict[str, Any]], **kwargs: Any) -> Any:
        # Flatten the chat messages into a simple role-prefixed prompt string
        prompt = "\n".join([f"{m.get('role', '').title()}: {m.get('content', '')}" for m in messages])
        prompt += "\nAssistant:"

        generated_text = self.generate_string(
            prompt,
            response_format=kwargs.pop("response_format", None),
            **kwargs
        )

        # Return an OpenAI-style chat-completion payload
        return {
            "choices": [{"message": {"role": "assistant", "content": generated_text}}],
            "model": self.model_id
        }
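
    # Usage sketch (illustrative; model id and messages are placeholders):
    #   judge = OpikHFModel("Qwen/Qwen2.5-1.5B-Instruct")
    #   out = judge.generate_provider_response(messages=[
    #       {"role": "system", "content": "You are a strict grader."},
    #       {"role": "user", "content": "Score the answer from 0-10 and reply with JSON."},
    #   ])
    #   print(out["choices"][0]["message"]["content"])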


class OpikLocalFactory:
    @staticmethod
    def create_model(provider: str, model_id: str, **kwargs) -> Any: