Skip to content

Commit 5a3c9b8

Browse files
committed
fix(argilla): refactor create_dataset to fix indexing logic in Argilla 2.x
1 parent dbd5120 commit 5a3c9b8

1 file changed

Lines changed: 48 additions & 32 deletions

File tree

src/xfmr_zem/servers/argilla/server.py

Lines changed: 48 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -73,52 +73,68 @@ def create_dataset(
7373
) -> Dict[str, Any]:
7474
"""
7575
Tạo hoặc lấy dataset trên Argilla server.
76-
77-
Args:
78-
data: Không dùng, cho phép Zem auto-chain từ step trước
79-
name: Tên dataset
80-
workspace: Workspace chứa dataset
81-
fields: List cấu hình fields. Mỗi field là dict:
82-
{"name": "text", "type": "text"} hoặc {"name": "chat", "type": "chat"}
83-
questions: List cấu hình questions. Ví dụ:
84-
{"name": "label", "type": "label", "labels": ["pos", "neg"]}
85-
{"name": "rating", "type": "rating", "values": [1,2,3,4,5]}
86-
{"name": "comment", "type": "text"}
87-
guidelines: Hướng dẫn cho annotators
88-
api_url: URL Argilla server (override env/config)
89-
api_key: API key (override env/config)
90-
91-
Returns:
92-
{"status": "created"|"exists", "dataset_name": str, "workspace": str}
76+
Fix lỗi logic indexing trong built-in server 0.3.7.
9377
"""
78+
import argilla as rg
9479
client = _get_client(api_url, api_key)
9580

96-
# Defaults nếu không truyền
81+
# 1. Kiểm tra tồn tại với logic đúng cho Argilla 2.x
82+
try:
83+
existing = client.datasets(name=name, workspace=workspace)
84+
if existing:
85+
logger.info(f"Dataset '{name}' đã tồn tại.")
86+
return {"status": "exists", "dataset_name": name, "workspace": workspace}
87+
except Exception as e:
88+
logger.debug(f"Lỗi khi check dataset tồn tại: {e}")
89+
90+
# 2. Xây dựng config nếu chưa có
9791
if fields is None:
9892
fields = [{"name": "text", "type": "text"}]
9993
if questions is None:
10094
questions = [
10195
{"name": "label", "type": "label", "labels": ["positive", "negative", "neutral"]}
10296
]
10397

104-
try:
105-
existing = client.datasets(name=name, workspace=workspace)
106-
if existing:
107-
logger.info(f"Dataset '{name}' đã tồn tại.")
108-
return {"status": "exists", "dataset_name": name, "workspace": workspace}
109-
except Exception:
110-
pass
98+
# 3. Tạo mới (sử dụng logic fixed tương tự DatasetFactory)
99+
built_fields = []
100+
for f in fields:
101+
ftype = f.get("type", "text").lower()
102+
fname = f["name"]
103+
ftitle = f.get("title", fname)
104+
built_fields.append(rg.TextField(name=fname, title=ftitle))
105+
106+
built_questions = []
107+
for q in questions:
108+
qtype = q.get("type", "label").lower()
109+
qname = q["name"]
110+
qtitle = q.get("title", qname)
111+
if qtype == "label":
112+
built_questions.append(rg.LabelQuestion(
113+
name=qname, title=qtitle, labels=q.get("labels", []),
114+
required=q.get("required", True)
115+
))
116+
# Có thể thêm các loại question khác nếu cần
117+
118+
settings = rg.Settings(
119+
fields=built_fields,
120+
questions=built_questions,
121+
guidelines=guidelines,
122+
)
111123

112-
DatasetFactory.create_or_get(
113-
client=client,
124+
dataset = rg.Dataset(
114125
name=name,
115126
workspace=workspace,
116-
fields=fields,
117-
questions=questions,
118-
guidelines=guidelines,
127+
settings=settings,
128+
client=client,
119129
)
120-
121-
return {"status": "created", "dataset_name": name, "workspace": workspace}
130+
dataset.create()
131+
logger.info(f"Đã tạo dataset '{name}'")
132+
133+
return {
134+
"status": "created",
135+
"dataset_name": name,
136+
"workspace": workspace
137+
}
122138

123139

124140
# ─────────────────────────────────────────────────────────────────────────────

0 commit comments

Comments
 (0)