Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion dace/libraries/onnx/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def from_onnx_proto(cls, onnx_proto):
def from_json(cls, json, context=None):

constructor_args = {
name: prop.from_json(json[name] if name in json else prop.default)
name: prop.from_json(json[name] if name in json else prop.default, context=context)
for name, prop in cls.__properties__.items()
}
return cls(**constructor_args)
Expand Down
61 changes: 38 additions & 23 deletions dace/properties.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import dace.subsets as sbs
import dace
import dace.serialize
from packaging.version import parse as parse_version
from dace import symbolic
from dace.symbolic import pystr_to_symbolic
from dace.dtypes import DebugInfo, typeclass
Expand Down Expand Up @@ -41,6 +42,20 @@ def _coerce_symbolic_property_value(value):
return pystr_to_symbolic(value, simplify=False)


def _symbolic_deserializer(value: str, context=None) -> symbolic.SymbolicType:
"""
A backwards compatibility deserializer for symbolic properties. If the version of the
context is less than ``2.0.0a4``, it will use the old ``pystr_to_symbolic`` deserializer.
Otherwise, it will use the new ``symbolic.deserialize_symbolic`` function.
"""
version = (context or {}).get("version", None)
if version is None:
raise TypeError("Context must contain version information for symbolic deserialization")
if version is None or parse_version(version) < parse_version("2.0.0a4"):
return pystr_to_symbolic(value, simplify=False)
return symbolic.deserialize_symbolic(value)


###############################################################################
# External interface to guarantee correct usage
###############################################################################
Expand Down Expand Up @@ -515,18 +530,18 @@ def from_string(self, s):
else:
return list(s)

def from_json(self, data, sdfg=None):
def from_json(self, data, context=None):
if data is None:
return data
if not isinstance(data, list):
raise TypeError('ListProperty expects a list input, got %s' % data)
if _is_symbolic_type(self.element_type):
return [symbolic.deserialize_symbolic(elem) for elem in data]
return [_symbolic_deserializer(elem, context=context) for elem in data]
if _is_symbolic_converter(self.element_type):
return [symbolic.deserialize_symbolic(elem) for elem in data]
return [_symbolic_deserializer(elem, context=context) for elem in data]
# If element knows how to convert itself, let it
if hasattr(self.element_type, "from_json"):
return [self.element_type.from_json(elem) for elem in data]
return [self.element_type.from_json(elem, context) for elem in data]
# Type-checks (casts) to the element type
return list(map(self.element_type, data))

Expand Down Expand Up @@ -554,12 +569,12 @@ def to_json(self, hist):
return None
return [elem.to_json() if elem is not None else None for elem in hist]

def from_json(self, data, sdfg=None):
def from_json(self, data, context=None):
if data is None:
return data
if not isinstance(data, list):
raise TypeError('TransformationHistProperty expects a list input, got %s' % data)
return [dace.serialize.from_json(elem) for elem in data]
return [dace.serialize.from_json(elem, context=context) for elem in data]


class DictProperty(Property):
Expand Down Expand Up @@ -657,7 +672,7 @@ def to_json(self, d):
def from_string(s):
return dict(s)

def from_json(self, data, sdfg=None):
def from_json(self, data, context=None):
if data is None:
return data
if not isinstance(data, dict):
Expand All @@ -669,17 +684,17 @@ def from_json(self, data, sdfg=None):

def _convert_key(key):
if _is_symbolic_type(self.key_type):
return symbolic.deserialize_symbolic(key)
return _symbolic_deserializer(key, context)
if _is_symbolic_converter(self.key_type):
return symbolic.deserialize_symbolic(key)
return self.key_type.from_json(key, sdfg) if key_json else self.key_type(key)
return _symbolic_deserializer(key, context)
return self.key_type.from_json(key, context) if key_json else self.key_type(key)

def _convert_value(value):
if _is_symbolic_type(self.value_type):
return symbolic.deserialize_symbolic(value)
return _symbolic_deserializer(value, context)
if _is_symbolic_converter(self.value_type):
return symbolic.deserialize_symbolic(value)
return self.value_type.from_json(value, sdfg) if value_json else self.value_type(value)
return _symbolic_deserializer(value, context)
return self.value_type.from_json(value, context) if value_json else self.value_type(value)

return {_convert_key(k): _convert_value(v) for k, v in data.items()}

Expand Down Expand Up @@ -1053,7 +1068,7 @@ def to_json(self):
return ret

@staticmethod
def from_json(tmp, sdfg=None):
def from_json(tmp, context=None):
if tmp is None:
return None
if isinstance(tmp, CodeBlock):
Expand Down Expand Up @@ -1204,8 +1219,8 @@ def to_json(self, val):
except AttributeError:
return SubsetProperty.to_string(val)

def from_json(self, val, sdfg=None):
return dace.serialize.from_json(val)
def from_json(self, val, context=None):
return dace.serialize.from_json(val, context)


class SymbolicProperty(Property):
Expand Down Expand Up @@ -1238,10 +1253,10 @@ def to_json(self, val):
return None
return symbolic.serialize_symbolic(val)

def from_json(self, val, sdfg=None):
def from_json(self, val, context=None):
if val is None:
return None
return symbolic.deserialize_symbolic(val)
return _symbolic_deserializer(val, context=context)


class DataProperty(Property):
Expand Down Expand Up @@ -1342,10 +1357,10 @@ def to_json(self, obj):
return None
return [symbolic.serialize_symbolic(o) for o in obj]

def from_json(self, d, sdfg=None):
def from_json(self, d, context=None):
if d is None:
return None
return tuple([symbolic.deserialize_symbolic(m) for m in d])
return tuple([_symbolic_deserializer(m, context=context) for m in d])

def __set__(self, obj, val):
if isinstance(val, list):
Expand Down Expand Up @@ -1419,7 +1434,7 @@ def from_json(obj, context=None):
return TypeClassProperty.from_string(obj)
elif isinstance(obj, dict):
# Let the deserializer handle this
return dace.serialize.from_json(obj)
return dace.serialize.from_json(obj, context=context)
else:
raise TypeError("Cannot parse type from: {}".format(obj))

Expand Down Expand Up @@ -1460,7 +1475,7 @@ def from_json(obj, context=None):
return NestedDataClassProperty.from_string(obj)
elif isinstance(obj, dict):
# Let the deserializer handle this
return dace.serialize.from_json(obj)
return dace.serialize.from_json(obj, context=context)
else:
raise TypeError("Cannot parse type from: {}".format(obj))

Expand Down Expand Up @@ -1493,7 +1508,7 @@ def to_json(self, obj):
return None
return obj.dict()

def from_json(self, d, sdfg=None):
def from_json(self, d, context=None):
if d is None:
return None
return self.dtype.parse_obj(d)
10 changes: 7 additions & 3 deletions dace/sdfg/sdfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -688,6 +688,7 @@ def _strip_transformation_history(json_obj: Any):
@classmethod
def from_json(cls, json_obj, context=None):
context = context or {'sdfg': None}
context['version'] = json_obj.get('dace_version', context.get('version'))
_type = json_obj['type']
if _type != cls.__name__:
raise TypeError("Class type mismatch")
Expand All @@ -699,13 +700,16 @@ def from_json(cls, json_obj, context=None):
edges = json_obj['edges']

if 'constants_prop' in attrs:
constants_prop = dace.serialize.loads(dace.serialize.dumps(attrs['constants_prop']))
constants_prop = dace.serialize.loads(dace.serialize.dumps(attrs['constants_prop']), context=context)
else:
constants_prop = None

ret = SDFG(name=attrs['name'], constants=constants_prop, parent=context['sdfg'])

dace.serialize.set_properties_from_json(ret, json_obj, ignore_properties={'constants_prop', 'name', 'hash'})
dace.serialize.set_properties_from_json(ret,
json_obj,
context=context,
ignore_properties={'constants_prop', 'name', 'hash'})

nodelist = []
for n in nodes:
Expand All @@ -717,7 +721,7 @@ def from_json(cls, json_obj, context=None):
nodelist.append(block)

for e in edges:
e = dace.serialize.from_json(e)
e = dace.serialize.from_json(e, context=context)
ret.add_edge(nodelist[int(e.src)], nodelist[int(e.dst)], e.data)

if 'start_block' in json_obj:
Expand Down
17 changes: 9 additions & 8 deletions dace/sdfg/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -1294,7 +1294,7 @@ def from_json(cls, json_obj, context=None):

ret = cls(label=json_obj['label'], sdfg=context['sdfg'])

dace.serialize.set_properties_from_json(ret, json_obj)
dace.serialize.set_properties_from_json(ret, json_obj, context=context)

return ret

Expand Down Expand Up @@ -1536,14 +1536,14 @@ def _open_scope(scope_entry: Optional[nd.Node], authority: Dict[str, dtypes.type
return ret

@classmethod
def from_json(cls, json_obj, context={'sdfg': None}, pre_ret=None):
def from_json(cls, json_obj, context=None, pre_ret=None):
""" Loads the node properties, label and type into a dict.

:param json_obj: The object containing information about this node.
NOTE: This may not be a string!
:return: An SDFGState instance constructed from the passed data
"""

context = context or {'sdfg': None}
_type = json_obj['type']
if _type != cls.__name__:
raise Exception("Class type mismatch")
Expand All @@ -1558,7 +1558,8 @@ def from_json(cls, json_obj, context={'sdfg': None}, pre_ret=None):
rec_ci = {
'sdfg': context['sdfg'],
'sdfg_state': ret,
'callback': context['callback'] if 'callback' in context else None
'callback': context.get('callback'),
'version': context.get('version'),
}
serialize.set_properties_from_json(ret, json_obj, rec_ci)

Expand Down Expand Up @@ -3126,7 +3127,7 @@ def from_json(cls, json_obj, context=None):

ret = cls(label=json_obj['label'], sdfg=context['sdfg'])

dace.serialize.set_properties_from_json(ret, json_obj)
dace.serialize.set_properties_from_json(ret, json_obj, context=context)

nodelist = []
for n in nodes:
Expand All @@ -3138,7 +3139,7 @@ def from_json(cls, json_obj, context=None):
nodelist.append(block)

for e in edges:
e = dace.serialize.from_json(e)
e = dace.serialize.from_json(e, context=context)
ret.add_edge(nodelist[int(e.src)], nodelist[int(e.dst)], e.data)

if 'start_block' in json_obj:
Expand Down Expand Up @@ -3958,11 +3959,11 @@ def from_json(cls, json_obj, context=None):

ret = cls(label=json_obj['label'], sdfg=context['sdfg'])

dace.serialize.set_properties_from_json(ret, json_obj)
dace.serialize.set_properties_from_json(ret, json_obj, context=context)

for condition, region in json_obj['branches']:
if condition is not None:
ret.add_branch(CodeBlock.from_json(condition), ControlFlowRegion.from_json(region, context))
ret.add_branch(CodeBlock.from_json(condition, context), ControlFlowRegion.from_json(region, context))
else:
ret.add_branch(None, ControlFlowRegion.from_json(region, context))
return ret
Expand Down
5 changes: 3 additions & 2 deletions dace/subsets.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,6 +340,7 @@ def a2s(obj):

@staticmethod
def from_json(obj, context=None):
from dace.properties import _symbolic_deserializer # Avoid circular import
if not isinstance(obj, dict):
raise TypeError("Expected dict, got {}".format(type(obj)))
if obj['type'] != 'Range':
Expand All @@ -350,8 +351,8 @@ def from_json(obj, context=None):
tuples = []

for r in ranges:
tuples.append((symbolic.deserialize_symbolic(r['start']), symbolic.deserialize_symbolic(r['end']),
symbolic.deserialize_symbolic(r['step']), symbolic.deserialize_symbolic(r['tile'])))
tuples.append((_symbolic_deserializer(r['start'], context), _symbolic_deserializer(r['end'], context),
_symbolic_deserializer(r['step'], context), _symbolic_deserializer(r['tile'], context)))

return Range(tuples)

Expand Down
8 changes: 6 additions & 2 deletions dace/transformation/interstate/loop_unroll.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import copy
from typing import List, Optional, Union

from dace import sdfg as sd, symbolic, serialize
from dace import sdfg as sd, symbolic, serialize, version
from dace.properties import Property, make_properties
from dace.sdfg import InterstateEdge, utils as sdutil
from dace.sdfg.nodes import NestedSDFG
Expand Down Expand Up @@ -141,7 +141,11 @@ def instantiate_loop_iteration(self,

for block in loop.nodes():
# Using to/from JSON is faster for copying blocks than deep copying.
new_block = serialize.from_json(serialize.to_json(block), context={'sdfg': graph.sdfg})
new_block = serialize.from_json(serialize.to_json(block),
context={
'sdfg': graph.sdfg,
'version': version.__version__,
})
assert block not in block_map
block_map[block] = new_block
# The JSON copy is created with SDFG context, so replacement can run before insertion.
Expand Down
2 changes: 1 addition & 1 deletion dace/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '2.0.0a3'
__version__ = '2.0.0a4'
Loading