@@ -190,43 +190,35 @@ def _find_name(self, name, loc):
190
190
"name '{name}' is not bound to anything" , {"name" :name }, loc )
191
191
self .engine .process (diag )
192
192
193
# Visitors that replace node with a typed node
#
def visit_Module(self, node):
    """Replace a ``Module`` node with a typed ``ModuleT`` node and visit it.

    A ``LocalExtractor`` sharing this inferencer's ``env_stack`` and
    ``engine`` is run over the module first. Its ``typing_env`` and its
    ``global_`` attribute (passed as ``globals_in_scope``) are attached to
    the new ``ModuleT``, together with the original ``body`` and ``loc``.

    Returns whatever ``self.visit`` returns for the new ``ModuleT`` node.
    """
    local_extractor = LocalExtractor(env_stack=self.env_stack, engine=self.engine)
    local_extractor.visit(node)

    typed_module = asttyped.ModuleT(
        typing_env=local_extractor.typing_env,
        globals_in_scope=local_extractor.global_,
        body=node.body,
        loc=node.loc,
    )
    # Re-dispatch on the typed node so its own visitor handles the children.
    return self.visit(typed_module)
199
203
200
def visit_FunctionDef(self, node):
    """Replace a ``FunctionDef`` node with a typed ``FunctionDefT`` and visit it.

    A ``LocalExtractor`` sharing this inferencer's ``env_stack`` and
    ``engine`` is run over the function first. Its ``typing_env`` and
    ``global_`` attribute (as ``globals_in_scope``) are stored on the new
    node. Every syntactic field and location of the original node is copied
    over unchanged.

    Returns whatever ``self.visit`` returns for the new ``FunctionDefT``.
    """
    local_extractor = LocalExtractor(env_stack=self.env_stack, engine=self.engine)
    local_extractor.visit(node)

    typed_function = asttyped.FunctionDefT(
        typing_env=local_extractor.typing_env,
        globals_in_scope=local_extractor.global_,
        # presumably a fresh type variable standing in for the not-yet-known
        # return type — TODO confirm against types.TVar
        return_type=types.TVar(),
        name=node.name,
        args=node.args,
        returns=node.returns,
        body=node.body,
        decorator_list=node.decorator_list,
        keyword_loc=node.keyword_loc,
        name_loc=node.name_loc,
        arrow_loc=node.arrow_loc,
        colon_loc=node.colon_loc,
        at_locs=node.at_locs,
        loc=node.loc,
    )
    # Re-dispatch on the typed node so its own visitor handles the body.
    return self.visit(typed_function)
218
def visit_arg(self, node):
    """Replace an ``arg`` node with a typed ``argT`` node.

    The argument's type is looked up by name via ``self._find_name``. The
    annotation is visited recursively, and the name and locations are copied
    from the original node.
    """
    arg_type = self._find_name(node.arg, node.loc)
    visited_annotation = self.visit(node.annotation)
    return asttyped.argT(
        type=arg_type,
        arg=node.arg,
        annotation=visited_annotation,
        arg_loc=node.arg_loc,
        colon_loc=node.colon_loc,
        loc=node.loc,
    )
230
222
231
223
def visit_Num (self , node ):
232
224
if isinstance (node .n , int ):
@@ -346,6 +338,26 @@ def visit_UnaryOpT(self, node):
346
338
self .engine .process (diag )
347
339
return node
348
340
341
def visit_ModuleT(self, node):
    """Visit the children of a typed module with its typing environment in scope.

    ``node.typing_env`` is pushed onto ``self.env_stack`` for the duration of
    ``self.generic_visit(node)`` and popped afterwards. The pop happens in a
    ``finally`` clause, so the stack is restored even if visiting the body
    raises. Otherwise an aborted traversal would leave a stale scope behind.

    Returns the node produced by ``self.generic_visit``.
    """
    self.env_stack.append(node.typing_env)
    try:
        return self.generic_visit(node)
    finally:
        self.env_stack.pop()
349
+
350
def visit_FunctionDefT(self, node):
    """Visit the children of a typed function with its scope made current.

    For the duration of ``self.generic_visit(node)``:
      * ``node.typing_env`` is pushed onto ``self.env_stack``;
      * ``self.function`` is set to ``node``.

    Both are restored in a ``finally`` clause, so an exception raised while
    visiting the body cannot leave the stack or ``self.function`` pointing at
    this function.

    Returns the node produced by ``self.generic_visit``.
    """
    # Read the previous value before mutating anything, so a failure here
    # leaves no partial state behind.
    old_function = self.function
    self.env_stack.append(node.typing_env)
    self.function = node
    try:
        return self.generic_visit(node)
    finally:
        self.function = old_function
        self.env_stack.pop()
360
+
349
361
def visit_Assign (self , node ):
350
362
node = self .generic_visit (node )
351
363
if len (node .targets ) > 1 :
@@ -455,7 +467,7 @@ def process_diagnostic(diag):
455
467
456
468
buf = source .Buffer ("" .join (fileinput .input ()), os .path .basename (fileinput .filename ()))
457
469
parsed , comments = parse_buffer (buf , engine = engine )
458
- typed = Inferencer (engine = engine ).visit_root (parsed )
470
+ typed = Inferencer (engine = engine ).visit (parsed )
459
471
printer = Printer (buf )
460
472
printer .visit (typed )
461
473
for comment in comments :
0 commit comments