-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathmodel_metadata.py
More file actions
60 lines (49 loc) · 2.66 KB
/
model_metadata.py
File metadata and controls
60 lines (49 loc) · 2.66 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
# -*- coding: utf-8 -*-
import io
import json
import os
import sys

from dataiku.doctor.posttraining.model_information_handler import PredictionModelInformationHandler
def get_model_handler(model, version_id=None):
    """
    Return the information handler for one version of a saved model.

    model: a dku saved model, as returned by dataiku.Model(model_id)
    version_id: the version to load; when None, the model's active version is used
    """
    # Resolve the full "S-<project>-<model>-<version>" id first, then load from it.
    return _get_model_info_handler(_get_saved_model_version_id(model, version_id))
def _get_saved_model_version_id(model, version_id=None):
    """
    Build the "S-<projectKey>-<modelId>-<versionId>" identifier of a saved model version.

    model: object exposing get_definition(), returning a dict with the keys
        'projectKey', 'id' and (optionally) 'activeVersion'
    version_id: the version to use; when None, the definition's 'activeVersion' is used

    Raises ValueError when no version can be resolved (no version_id given and no
    active version in the definition, or an empty version id). Without this check the
    id would end in "-None" or "-" and only fail later when the version folder is opened.
    """
    model_def = model.get_definition()
    if version_id is None:
        version_id = model_def.get('activeVersion')
    if version_id is None or version_id == '':
        raise ValueError(
            "Cannot resolve a version for saved model '{0}': no version_id given and "
            "no active version in its definition".format(model_def.get('id')))
    return 'S-{0}-{1}-{2}'.format(model_def.get('projectKey'), model_def.get('id'), version_id)
def _get_model_info_handler(saved_model_version_id):
    """
    Load the PredictionModelInformationHandler of a saved model version from the data dir.

    saved_model_version_id: id of the form "S-<projectKey>-<modelId>-<versionId>"

    Reads <DIP_HOME>/saved_models/<projectKey>/<modelId>/versions/<versionId>/split/split.json
    and .../core_params.json, then builds the handler from them.

    Raises:
        ValueError: if the id does not split into exactly 4 "-"-separated parts starting with "S".
        KeyError: if the DIP_HOME environment variable is not set.
        Exception: any error raised while building the handler, re-raised with an
            explanatory message and the original traceback.
    """
    infos = saved_model_version_id.split("-")
    if len(infos) != 4 or infos[0] != "S":
        raise ValueError("Invalid saved model id")
    _, pkey, model_id, version_id = infos

    datadir_path = os.environ['DIP_HOME']
    version_folder = os.path.join(datadir_path, "saved_models", pkey, model_id, "versions", version_id)

    # Load split_desc and resolve its data paths against the split folder.
    split_folder = os.path.join(version_folder, "split")
    split_desc = _read_json(os.path.join(split_folder, "split.json"))
    for field_name in ("trainPath", "testPath", "fullPath"):
        if split_desc.get(field_name) is not None:
            split_desc[field_name] = os.path.join(split_folder, split_desc[field_name])

    core_params = _read_json(os.path.join(version_folder, "core_params.json"))

    try:
        return PredictionModelInformationHandler(split_desc, core_params, version_folder, version_folder)
    except Exception as e:
        # Always raises; the traceback is captured here while the exception is being handled.
        _raise_load_error(e, sys.exc_info()[2])


def _read_json(path):
    """
    Load and return the JSON document stored at path.

    The file is decoded explicitly as UTF-8 (io.open behaves the same on python 2 and 3),
    so non-ASCII content does not depend on the process locale's default encoding.
    """
    with io.open(path, encoding="utf-8") as json_file:
        return json.load(json_file)


def _raise_load_error(error, traceback):
    """
    Re-raise error (raised while building the model information handler) as an Exception
    with a user-facing message, keeping the original traceback.

    error: the caught exception
    traceback: its traceback, as given by sys.exc_info()[2]
    """
    # Imported lazily, as before: only needed on this failure path.
    from future.utils import raise_
    message = str(error)
    if "ordinal not in range(128)" in message:
        raise_(Exception, "The plugin only supports python3, cannot load a python2 model. Original error: {}".format(error), traceback)
    elif message == "non-string names in Numpy dtype unpickling":
        raise_(Exception, "The plugin is using a python2 code-env, cannot load a python3 model. Original error: {}".format(error), traceback)
    elif message == "Using saved models in python recipes is limited to models trained using the python engine":
        raise_(Exception, "The plugin does not support Clustering model.", traceback)
    else:
        raise_(Exception, "Fail to load saved model: {}".format(error), traceback)