Skip to content

Commit

Permalink
Allow accessing attributes of embedded host objects.
Browse files Browse the repository at this point in the history
  • Loading branch information
whitequark committed Aug 27, 2015
1 parent 422208a commit cb22526
Show file tree
Hide file tree
Showing 4 changed files with 151 additions and 40 deletions.
121 changes: 106 additions & 15 deletions artiq/compiler/embedding.py
Expand Up @@ -6,12 +6,13 @@
"""

import os, re, linecache, inspect
from collections import OrderedDict
from collections import OrderedDict, defaultdict

from pythonparser import ast, source, diagnostic, parse_buffer

from . import types, builtins, asttyped, prelude
from .transforms import ASTTypedRewriter, Inferencer, IntMonomorphizer
from .validators import MonomorphismValidator


class ObjectMap:
Expand All @@ -34,10 +35,11 @@ def retrieve(self, obj_key):
return self.forward_map[obj_key]

class ASTSynthesizer:
def __init__(self, type_map, expanded_from=None):
def __init__(self, type_map, value_map, expanded_from=None):
self.source = ""
self.source_buffer = source.Buffer(self.source, "<synthesized>")
self.type_map, self.expanded_from = type_map, expanded_from
self.type_map, self.value_map = type_map, value_map
self.expanded_from = expanded_from

def finalize(self):
self.source_buffer.source = self.source
Expand Down Expand Up @@ -82,6 +84,11 @@ def quote(self, value):
begin_loc=begin_loc, end_loc=end_loc,
loc=begin_loc.join(end_loc))
else:
quote_loc = self._add('`')
repr_loc = self._add(repr(value))
unquote_loc = self._add('`')
loc = quote_loc.join(unquote_loc)

if isinstance(value, type):
typ = value
else:
Expand All @@ -98,16 +105,14 @@ def quote(self, value):

self.type_map[typ] = instance_type, constructor_type

quote_loc = self._add('`')
repr_loc = self._add(repr(value))
unquote_loc = self._add('`')

if isinstance(value, type):
self.value_map[constructor_type].append((value, loc))
return asttyped.QuoteT(value=value, type=constructor_type,
loc=quote_loc.join(unquote_loc))
loc=loc)
else:
self.value_map[instance_type].append((value, loc))
return asttyped.QuoteT(value=value, type=instance_type,
loc=quote_loc.join(unquote_loc))
loc=loc)

def call(self, function_node, args, kwargs):
"""
Expand Down Expand Up @@ -151,14 +156,16 @@ def call(self, function_node, args, kwargs):
loc=name_loc.join(end_loc))

class StitchingASTTypedRewriter(ASTTypedRewriter):
def __init__(self, engine, prelude, globals, host_environment, quote_function, type_map):
def __init__(self, engine, prelude, globals, host_environment, quote_function,
type_map, value_map):
super().__init__(engine, prelude)
self.globals = globals
self.env_stack.append(self.globals)

self.host_environment = host_environment
self.quote_function = quote_function
self.type_map = type_map
self.value_map = value_map

def visit_Name(self, node):
typ = super()._try_find_name(node.id)
Expand All @@ -180,7 +187,9 @@ def visit_Name(self, node):

else:
# It's just a value. Quote it.
synthesizer = ASTSynthesizer(expanded_from=node.loc, type_map=self.type_map)
synthesizer = ASTSynthesizer(expanded_from=node.loc,
type_map=self.type_map,
value_map=self.value_map)
node = synthesizer.quote(value)
synthesizer.finalize()
return node
Expand All @@ -190,6 +199,83 @@ def visit_Name(self, node):
node.loc)
self.engine.process(diag)

class StitchingInferencer(Inferencer):
    """
    Type inferencer that also infers the types of attributes of host objects
    embedded (quoted) into the kernel.

    Host Python objects carry no ARTIQ types. Whenever an attribute of such an
    object's type is accessed, every host object quoted with that type (as
    recorded in ``value_map``) is interrogated. The attribute's type is then
    derived by quoting the attribute's value.
    """

    def __init__(self, engine, type_map, value_map):
        """
        :param engine: diagnostic engine that receives errors and notes.
        :param type_map: passed on to the :class:`ASTSynthesizer` instances
            created here to quote attribute values.
        :param value_map: maps an ARTIQ type to the list of
            ``(host_value, loc)`` pairs quoted with that type.
        """
        super().__init__(engine)
        self.type_map, self.value_map = type_map, value_map

    def visit_AttributeT(self, node):
        self.generic_visit(node)
        object_type = node.value.type.find()

        # The inferencer can only observe types, not values; however,
        # when we work with host objects, we have to get the values
        # somewhere, since host interpreter does not have types.
        # Since we have categorized every host object we quoted according to
        # its type, we now interrogate every host object we have to ensure
        # that we can successfully serialize the value of the attribute we
        # are now adding at the code generation stage.
        #
        # FIXME: We perform exhaustive checks of every known host object every
        # time an attribute access is visited, which is potentially quadratic.
        # This is done because it is simpler than performing the checks only when:
        #   * a previously unknown attribute is encountered,
        #   * a previously unknown host object is encountered;
        # which would be the optimal solution.
        #
        # Iterate over a snapshot of the known host objects. Quoting an
        # attribute value below appends to self.value_map. If that value has
        # the same type as the object being inspected (e.g. a self-referencing
        # object, `obj.me = obj`), iterating the live list would never
        # terminate. Objects discovered while quoting are checked the next
        # time an attribute of this type is visited.
        # .get() also avoids inserting an empty entry into value_map
        # (a defaultdict) for every type that has no host objects.
        known_objects = list(self.value_map.get(object_type, ()))

        # Diagnostics raised while inferring the type of a quoted attribute
        # value are forwarded to the real engine, annotated with the
        # attribute access that triggered them.
        def proxy_diagnostic(diag):
            note = diagnostic.Diagnostic("note",
                "expanded from here while trying to infer a type for an"
                " attribute '{attr}' of a host object",
                {"attr": node.attr},
                node.loc)
            diag.notes.append(note)

            self.engine.process(diag)

        proxy_engine = diagnostic.Engine()
        proxy_engine.process = proxy_diagnostic

        for object_value, object_loc in known_objects:
            if not hasattr(object_value, node.attr):
                note = diagnostic.Diagnostic("note",
                    "attribute accessed here", {},
                    node.loc)
                diag = diagnostic.Diagnostic("error",
                    "host object does not have an attribute '{attr}'",
                    {"attr": node.attr},
                    object_loc, notes=[note])
                self.engine.process(diag)
                # The attribute's type cannot be determined; stop here
                # without running the base inferencer's attribute check.
                return

            # Figure out what ARTIQ type does the value of the attribute have.
            # We do this by quoting it, as if to serialize. This has some
            # overhead (i.e. synthesizing a source buffer), but has the advantage
            # of having the host-to-ARTIQ mapping code in only one place and
            # also immediately getting proper diagnostics on type errors.
            synthesizer = ASTSynthesizer(type_map=self.type_map,
                                         value_map=self.value_map)
            # Named attr_node to avoid shadowing the module-level `ast` import.
            attr_node = synthesizer.quote(getattr(object_value, node.attr))
            synthesizer.finalize()

            Inferencer(engine=proxy_engine).visit(attr_node)
            IntMonomorphizer(engine=proxy_engine).visit(attr_node)
            MonomorphismValidator(engine=proxy_engine).visit(attr_node)

            if node.attr not in object_type.attributes:
                # We just figured out what the type should be. Add it.
                object_type.attributes[node.attr] = attr_node.type
            elif object_type.attributes[node.attr] != attr_node.type:
                # Does this conflict with an earlier guess? The error is
                # reported at the host object; name the attribute and point
                # at the access so the user can tell which one conflicts.
                printer = types.TypePrinter()
                note = diagnostic.Diagnostic("note",
                    "attribute accessed here", {},
                    node.loc)
                diag = diagnostic.Diagnostic("error",
                    "host object has an attribute '{attr}' of type {typea}, which is"
                    " different from previously inferred type {typeb}",
                    {"attr": node.attr,
                     "typea": printer.name(attr_node.type),
                     "typeb": printer.name(object_type.attributes[node.attr])},
                    object_loc, notes=[note])
                self.engine.process(diag)

        super().visit_AttributeT(node)

class Stitcher:
def __init__(self, engine=None):
if engine is None:
Expand All @@ -206,9 +292,11 @@ def __init__(self, engine=None):

self.object_map = ObjectMap()
self.type_map = {}
self.value_map = defaultdict(lambda: [])

def finalize(self):
inferencer = Inferencer(engine=self.engine)
inferencer = StitchingInferencer(engine=self.engine,
type_map=self.type_map, value_map=self.value_map)

# Iterate inference to fixed point.
self.inference_finished = False
Expand Down Expand Up @@ -262,7 +350,8 @@ def _quote_embedded_function(self, function):
asttyped_rewriter = StitchingASTTypedRewriter(
engine=self.engine, prelude=self.prelude,
globals=self.globals, host_environment=host_environment,
quote_function=self._quote_function, type_map=self.type_map)
quote_function=self._quote_function,
type_map=self.type_map, value_map=self.value_map)
return asttyped_rewriter.visit(function_node)

def _function_loc(self, function):
Expand Down Expand Up @@ -324,7 +413,8 @@ def _type_of_param(self, function, loc, param, is_syscall):
# This is tricky, because the default value might not have
# a well-defined type in APython.
# In this case, we bail out, but mention why we do it.
synthesizer = ASTSynthesizer(type_map=self.type_map)
synthesizer = ASTSynthesizer(type_map=self.type_map,
value_map=self.value_map)
ast = synthesizer.quote(param.default)
synthesizer.finalize()

Expand Down Expand Up @@ -442,7 +532,8 @@ def stitch_call(self, function, args, kwargs):

# We synthesize source code for the initial call so that
# diagnostics would have something meaningful to display to the user.
synthesizer = ASTSynthesizer(type_map=self.type_map)
synthesizer = ASTSynthesizer(type_map=self.type_map,
value_map=self.value_map)
call_node = synthesizer.call(function_node, args, kwargs)
synthesizer.finalize()
self.typedtree.append(call_node)
12 changes: 11 additions & 1 deletion artiq/compiler/transforms/inferencer.py
Expand Up @@ -130,10 +130,20 @@ def makenotes(printer, typea, typeb, loca, locb):
self._unify(node.type, attr_type,
node.loc, None)
else:
if node.attr_loc.source_buffer == node.value.loc.source_buffer:
highlights, notes = [node.value.loc], []
else:
# This happens when the object being accessed is embedded
# from the host program.
note = diagnostic.Diagnostic("note",
"object being accessed", {},
node.value.loc)
highlights, notes = [], [note]

diag = diagnostic.Diagnostic("error",
"type {type} does not have an attribute '{attr}'",
{"type": types.TypePrinter().name(object_type), "attr": node.attr},
node.attr_loc, [node.value.loc])
node.attr_loc, highlights, notes)
self.engine.process(diag)

def _unify_iterable(self, element, collection):
Expand Down
55 changes: 31 additions & 24 deletions artiq/compiler/transforms/llvm_ir_generator.py
Expand Up @@ -278,6 +278,27 @@ def llty_of_type(self, typ, bare=False, for_return=False):
else:
return llty.as_pointer()

def llstr_of_str(self, value, name=None,
                 linkage="private", unnamed_addr=True):
    """
    Return a pointer (bitcast to ``llptr``) to a constant global holding
    ``value``.

    :param value: either a :class:`str`, which is UTF-8 encoded with a NUL
        terminator appended (it must not already contain NUL), or
        :class:`bytes`, which is emitted verbatim. For bytes the caller
        supplies any terminator.
    :param name: global symbol name. When omitted, a fresh unique name based
        on ``"str"`` is used. If a global with this name already exists in
        the module, it is reused unchanged.
    :param linkage: linkage applied to a newly created global.
    :param unnamed_addr: ``unnamed_addr`` flag applied to a newly created
        global.
    """
    if not isinstance(value, str):
        data = value
    else:
        assert "\0" not in value
        data = (value + "\0").encode("utf-8")

    global_name = self.llmodule.get_unique_name("str") if name is None else name

    # Reuse an existing global of the same name (this lets callers that pass
    # an explicit name share one definition); otherwise create it.
    llglobal = self.llmodule.get_global(global_name)
    if llglobal is None:
        array_type = ll.ArrayType(lli8, len(data))
        llglobal = ll.GlobalVariable(self.llmodule, array_type, global_name)
        llglobal.global_constant = True
        llglobal.initializer = ll.Constant(array_type, bytearray(data))
        llglobal.linkage = linkage
        llglobal.unnamed_addr = unnamed_addr

    return llglobal.bitcast(llptr)

def llconst_of_const(self, const):
llty = self.llty_of_type(const.type)
if const.value is None:
Expand All @@ -289,33 +310,19 @@ def llconst_of_const(self, const):
elif isinstance(const.value, (int, float)):
return ll.Constant(llty, const.value)
elif isinstance(const.value, (str, bytes)):
if isinstance(const.value, str):
assert "\0" not in const.value
as_bytes = (const.value + "\0").encode("utf-8")
else:
as_bytes = const.value

if ir.is_exn_typeinfo(const.type):
# Exception typeinfo; should be merged with identical others
name = "__artiq_exn_" + const.value
linkage = "linkonce"
unnamed_addr = False
else:
# Just a string
name = self.llmodule.get_unique_name("str")
name = None
linkage = "private"
unnamed_addr = True

llconst = self.llmodule.get_global(name)
if llconst is None:
llstrty = ll.ArrayType(lli8, len(as_bytes))
llconst = ll.GlobalVariable(self.llmodule, llstrty, name)
llconst.global_constant = True
llconst.initializer = ll.Constant(llstrty, bytearray(as_bytes))
llconst.linkage = linkage
llconst.unnamed_addr = unnamed_addr

return llconst.bitcast(llptr)
return self.llstr_of_str(const.value, name=name,
linkage=linkage, unnamed_addr=unnamed_addr)
else:
assert False

Expand Down Expand Up @@ -856,7 +863,7 @@ def ret_error_handler(typ):
tag += self._rpc_tag(fun_type.ret, ret_error_handler)
tag += b"\x00"

lltag = self.llconst_of_const(ir.Constant(tag + b"\x00", builtins.TStr()))
lltag = self.llstr_of_str(tag)

llstackptr = self.llbuilder.call(self.llbuiltin("llvm.stacksave"), [])

Expand Down Expand Up @@ -981,24 +988,24 @@ def _quote(self, value, typ, path):
global_name = "object.{}".format(objectid)
else:
llfields.append(self._quote(getattr(value, attr), typ.attributes[attr],
path + [attr]))
lambda: path() + [attr]))

llvalue = ll.Constant.literal_struct(llfields)
elif builtins.is_none(typ):
assert value is None
return self.llconst_of_const(value)
return ll.Constant.literal_struct([])
elif builtins.is_bool(typ):
assert value in (True, False)
return self.llconst_of_const(value)
return ll.Constant(lli1, value)
elif builtins.is_int(typ):
assert isinstance(value, int)
return self.llconst_of_const(value)
return ll.Constant(ll.IntType(builtins.get_int_width(typ)), value)
elif builtins.is_float(typ):
assert isinstance(value, float)
return self.llconst_of_const(value)
return ll.Constant(lldouble, value)
elif builtins.is_str(typ):
assert isinstance(value, (str, bytes))
return self.llconst_of_const(value)
return self.llstr_of_str(value)
else:
assert False

Expand Down
3 changes: 3 additions & 0 deletions artiq/compiler/types.py
Expand Up @@ -320,6 +320,9 @@ def __eq__(self, other):
def __ne__(self, other):
return not (self == other)

def __hash__(self):
    # Hash on the name alone. Instances are used as dict keys (e.g. the
    # stitcher's value_map is indexed by ARTIQ type).
    # NOTE(review): presumably matches __eq__ above comparing by name; confirm.
    type_name = self.name
    return hash(type_name)

class TBuiltinFunction(TBuiltin):
"""
A type of a builtin function.
Expand Down

0 comments on commit cb22526

Please sign in to comment.