Annotate ast with types for adt codegen
authorLuke Lau <luke_lau@icloud.com>
Tue, 6 Aug 2019 06:58:11 +0000 (07:58 +0100)
committerLuke Lau <luke_lau@icloud.com>
Tue, 6 Aug 2019 06:58:11 +0000 (07:58 +0100)
ast.scm
codegen.scm
tests.scm
typecheck.scm

diff --git a/ast.scm b/ast.scm
index 5064ca554bc8033bdbc0b483ac79249cd3312f73..4ca6eb1e1dc57100a344310ea1494ea2139236e0 100644 (file)
--- a/ast.scm
+++ b/ast.scm
@@ -46,6 +46,7 @@
                     (inner (lambda-body x)))]
     ['if (append (f x)
                 (flat-map inner (cdr x)))]
+    ['closure (flat-map inner (caddr x))]
     [else (f x)]))
 
 (define (ast-find p x)
     [else (p x)]))
 
 (define (let-bindings e)
-  (define (pattern-match x body)
-    (if (eqv? (ast-type x) 'var)
-       (list (cons x body))
-       (let* ([constructor (car x)]
+  (define (pattern-match binding body)
+    (if (eqv? (ast-type binding) 'var)
+       (list (cons binding body))
+       (let* ([constructor (car binding)]
               [destructor (lambda (i) (dtor-name constructor i))])
          (flat-map (lambda (y i)
-                     (pattern-match y (list (destructor i) body)))
-                   (cdr x)
-                   (range 0 (length (cdr x)))))))
+                     (pattern-match y `((,(destructor i) ,@body))))
+                   (cdr binding)
+                   (range 0 (length (cdr binding)))))))
   (flat-map (lambda (x) (pattern-match (car x) (cdr x))) (cadr e)))
 (define let-body cddr)
 
          program))
 
 (define (program-body program)
+  ; hack to have multi-expression bodies
   `(let ()
      ,@(filter (lambda (x) (eqv? (statement-type x) 'expr))
               program)))
 
+
+                                       ; (A ((foo (Int Bool))
+                                       ;     (bar (Bool)))
+
+(define data-layout cdr)
+
                                        ; gets both constructors and destructors
                                        ; (data A (foo Int Bool)
                                        ;         (bar Bool))
                                        ;        |
                                        ;        v
-                                       ; (foo . (abs Int (abs Bool A)))
-                                       ; (foo~0 . (abs A Int)
-                                       ; (foo~1 . (abs A Bool)
-                                       ; (bar . (abs Bool A)
-                                       ; (bar~0 . (abs A Bool)
+                                       ; (foo . (constructor . (abs Int (abs Bool A))))
+                                       ; (foo~0 . (0 . (abs A Int)))
+                                       ; (foo~1 . (1 . (abs A Bool)))
+                                       ; (bar . (constructor . (abs Bool A)))
+                                       ; (bar~0 . (0 . (abs A Bool)))
 
-(define (data-tors data-def)
+(define (data-tors data-layout)
   (define (constructor-type t products)
     (fold-right (lambda (x acc) `(abs ,x ,acc)) t products))
 
   (define (destructor ctor-name prod-type part-type index)
     (let ([name (dtor-name ctor-name index)])
-      (cons name `(abs ,prod-type ,part-type))))
+      (cons name (cons index `(abs ,prod-type ,part-type)))))
   
-  (let ([type-name (cadr data-def)]
-        [ctors (cddr data-def)])
+  (let ([type-name (car data-layout)]
+        [ctors (cdr data-layout)])
     (fold-right
      (lambda (ctor acc)       
        (let* ([ctor-name (car ctor)]
              [products (cdr ctor)]
              
-             [maker (cons ctor-name (constructor-type type-name products))]
+             [maker (cons ctor-name (cons 'constructor (constructor-type type-name products)))]
              
              [dtors (map (lambda (t i) (destructor ctor-name type-name t i))
                          products
         
         (cons maker (append dtors acc))))
      '()
-     ctrs)))
+     ctors)))
+
+                                       ; creates a type environment for a given adt definition
+(define (data-tors-env data-layout)
+  (map (lambda (x) (cons (car x) (cddr x))) (data-tors data-layout)))
 
 (define (dtor-name ctor-name index)
   (string->symbol
 (define lambda-args cadr)
 (define lambda-body caddr)
 
+(define (references prog)
+  (ast-collect
+   (lambda (x)
+     (case (ast-type x)
+       ['var (list x)]
+       [else '()]))
+   prog))
+
+(define (graph bs)
+  (define (go bs orig-bs)
+    (if (null? bs)
+       '(() . ())
+       (let* [(bind (car bs))
+
+              (vert (car bind))
+              (refs (filter ; only count a reference if its a binding
+                     (lambda (x) (assoc x orig-bs))
+                     (references (cdr bind))))
+              (edges (map (lambda (x) (cons vert x))
+                          refs))
+
+              (rest (if (null? (cdr bs))
+                        (cons '() '())
+                        (go (cdr bs) orig-bs)))
+              (total-verts (cons vert (car rest)))
+              (total-edges (append edges (cdr rest)))]
+         (cons total-verts total-edges))))
+  (go bs bs))
+
+(define (successors graph v)
+  (define (go v E)
+    (if (null? E)
+       '()
+       (if (eqv? v (caar E))
+           (cons (cdar E) (go v (cdr E)))
+           (go v (cdr E)))))
+  (go v (cdr graph)))
+
+                                       ; takes in a graph (pair of vertices, edges)
+                                       ; returns a list of strongly connected components
+
+                                       ; ((x y w) . ((x . y) (x . w) (w . x))
+
+                                       ; =>
+                                       ; .->x->y
+                                       ; |  |
+                                       ; |  v
+                                       ; .--w
+
+                                       ; ((x w) (y))
+
+                                       ; this uses tarjan's algorithm, to get reverse
+                                       ; topological sorting for free
+(define (sccs graph)
+  
+  (let* ([indices (make-hash-table)]
+        [lowlinks (make-hash-table)]
+        [on-stack (make-hash-table)]
+        [current 0]
+        [stack '()]
+        [result '()])
+
+    (define (index v)
+      (get-hash-table indices v #f))
+    (define (lowlink v)
+      (get-hash-table lowlinks v #f))
+
+    (letrec
+       ([strong-connect
+         (lambda (v)
+           (begin
+             (put-hash-table! indices v current)
+             (put-hash-table! lowlinks v current)
+             (set! current (+ current 1))
+             (push! stack v)
+             (put-hash-table! on-stack v #t)
+
+             (for-each
+              (lambda (w)
+                (if (not (hashtable-contains? indices w))
+                                       ; successor w has not been visited, recurse
+                    (begin
+                      (strong-connect w)
+                      (put-hash-table! lowlinks
+                                       v
+                                       (min (lowlink v) (lowlink w))))
+                                       ; successor w has been visited
+                    (when (get-hash-table on-stack w #f)
+                      (put-hash-table! lowlinks v (min (lowlink v) (index w))))))
+              (successors graph v))
+
+             (when (= (index v) (lowlink v))
+               (let ([scc
+                      (let new-scc ()
+                        (let ([w (pop! stack)])
+                          (put-hash-table! on-stack w #f)
+                          (if (eqv? w v)
+                              (list w)
+                              (cons w (new-scc)))))])
+                 (set! result (cons scc result))))))])
+      (for-each
+       (lambda (v)
+        (when (not (hashtable-contains? indices v)) ; v.index == -1
+          (strong-connect v)))
+       (car graph)))
+    result))
+
+
                                        ; utils
+
 (define (range s n)
   (if (= 0 n) '()
       (append (range s (- n 1))
index 10e005f8ad90e840aad396da4a0cdb5032163f6c..2d5f4ac2731800dba6e4041162bc7693be62f90a 100644 (file)
@@ -9,6 +9,35 @@
     (apply printf s)
     (display "\n")))
 
+(define wordsize 8)
+
+(define (type-size type env)
+
+  (define (adt-size adt)
+    (let ([sizes
+          (map (lambda (sum)
+                 (fold-left (lambda (acc x) (+ acc (type-size x)))
+                            wordsize ; one word needed to store tag
+                            (cdr sum)))
+               (cdr adt))])
+      (apply max sizes)))
+  
+  (case type
+    ['Int wordsize]
+    ['Bool wordsize]
+    [else
+     (let ([adt (assoc type (env-adts env))])
+       (if adt
+          (adt-size adt)
+          (error #f "unknown size" type)))]))
+
+                                       ; an environment consists of adt layouts in scope,
+                                       ; and any bound variables.
+                                       ; bound variables are an assoc list with their stack offset
+(define make-env list)
+(define env-adts car)
+(define env-bindings cadr)
+
 (define (codegen-add xs si env)
   (define (go ys)
     (if (null? ys)
     ('linux  (emit "mov $1, %rax"))) ; syscall 1 (write)
   (emit "syscall"))
 
-(define wordsize 8)
-
 (define (codegen-let bindings body si env)
 
                                        ; is this a closure that captures itself?
     (and (eqv? (ast-type expr) 'closure)
         (memv name (caddr expr))))
 
-  (let* ((stack-offsets (map (lambda (name x) ; assoc map of binding name to offset
+
+  ;; (define (emit-scc scc env)
+  ;;   ; acc is a pair of the env and list of touchups
+  ;;   (define (emit-binding acc binding)
+  ;;     (let ([binding-name (car binding)]
+  ;;       [binding-body (cadr binding)]
+
+  ;;       [other-bindings (filter
+  ;;                        (lambda (x) (not (eqv? binding-name x)))
+  ;;                        scc)]
+  ;;       [mutually-recursives
+  ;;        (filter
+  ;;         (lambda (other-binding)
+  ;;           (memv other-binding (references binding-body)))
+  ;;         other-bindings)]
+
+  ;;       [new-touchups (append touchups (cdr acc))])
+
+  ;;                                   ; TODO: assert that the only mutually recursives are closures
+  ;;   (for-each
+  ;;    (lambda (binding)
+  ;;      (when (not (eqv? (ast-type (cadr binding))
+       
+  ;;   (emit "asdf")
+  ;;   (cons new-env new-touchups)
+  ;;   ))
+
+  ;;   (fold-left emit-binding (cons env '()) scc))))
+  
+  (let* ([stack-offsets (map (lambda (name x) ; assoc map of binding name to offset
                               (cons name (- si (* x wordsize))))
                             (map car bindings)
-                            (range 0 (length bindings))))
-        (inner-si (- si (* (length bindings) wordsize)))
+                            (range 0 (length bindings)))]
+        [inner-si (- si (* (length bindings) wordsize))]
 
-        (get-offset (lambda (n) (cdr (assoc n stack-offsets))))
+        [get-offset (lambda (n) (cdr (assoc n stack-offsets)))]
         
         [inner-env
          (fold-left
           (lambda (env comps)
-            (let ([scc-env
+            (let* ([scc-binding-offsets
                     (fold-left
                      (lambda (acc name)
                        (cons (cons name (get-offset name))
                              acc))
-                    env
-                    comps)])
+                     (env-bindings env)
+                     comps)]
+                   [scc-env (make-env (env-adts env) scc-binding-offsets)])
               (for-each 
                (lambda (name)
                  (let ([expr (cadr (assoc name bindings))])
                                        ; codegen-closure realise this!
                        (codegen-expr expr
                                      inner-si
+                                     (make-env
+                                      (env-adts scc-env)
                                       (cons (cons name 'self-captive)
-                                           scc-env))
+                                            (env-bindings scc-env))))
                        (codegen-expr expr inner-si scc-env))
                    (emit "movq %rax, ~a(%rbp)" (get-offset name))))
                comps)
               scc-env))
-          env (reverse (sccs (graph bindings))))])
+          env
+          (reverse (sccs (graph bindings))))])
     
     (for-each (lambda (form)
                (codegen-expr form inner-si inner-env))
              body)))
 
 (define (codegen-var name si env)
-  (when (not (assoc name env))
-    (error #f (format "Variable ~a is not bound" name)))
-  (let ((offset (cdr (assoc name env))))
-    (emit "movq ~a(%rbp), %rax" offset)))
+  (let ([binding (assoc name (env-bindings env))])
+    (if (not binding)
+       (error #f (format "Variable ~a is not bound" name))
+       (emit "movq ~a(%rbp), %rax" (cdr binding)))))
 
 (define cur-lambda 0)
 (define (fresh-lambda)
                                        ; store the captured vars
     (for-each
      (lambda (var-name heap-offset)
-       (let ([stack-offset (cdr (assoc var-name env))])
+       (let ([stack-offset (cdr (assoc var-name (env-bindings env)))])
         (emit "### captive ~a" var-name)
         (if (eqv? stack-offset 'self-captive)
                                        ; captive refers to this closure:
                               (* (- wordsize) (+ 1 i)))
                             (range 0 (length params))))
 
-        (env (map cons params stack-offsets)))
+        [bindings (map cons params stack-offsets)]
+        [env (make-env '() bindings)])
     (emit "~a:" label)
 
     (display "## lambda captives: ")
     (codegen-expr else si env)
     (emit "~a:" exit-label)))
 
+(define (data-tor env e)
+  (and (list? e)
+       (assoc (car e) (flat-map data-tors (env-adts env)))))
+
+(define (codegen-data-tor e si env)
+
+  (define (codegen-destructor tor)
+    (codegen-expr (cadr e) si env)
+    (let ([index (cadr tor)]
+         [products 2]
+         [to-traverse (list-head products index)]
+         [offset (fold-left
+                  (lambda (acc t) (+ acc (type-size t)))
+                  wordsize ; skip tag in first word
+                  to-traverse)])
+      3
+      ))
+  
+  (let ([tor (data-tor env e)]
+       [constructor (eqv? 'constructor (cadr tor))])
+    (if constructor
+       (codegen-constructor tor)
+       (codegen-destructor tor))))
+
 (define (codegen-expr e si env)
   (emit "# ~a" e)
   (case (ast-type e)
        ('= (codegen-eq  (cadr e) (caddr e) si env))
        ('bool->int (codegen-expr (cadr e) si env))
        ('print (codegen-print (cadr e) si env))
-       (else (codegen-call (car e) (cdr e) si env))))
+       (else
+       (if (data-tor env e)
+           (codegen-data-tor e si env)
+           (codegen-call (car e) (cdr e) si env)))))
 
                                        ; this is a builtin being passed around as a variable
     ('builtin (emit "movq $~a, %rax" (builtin-id e)))
   (emit "~a:" (car s))
   (emit "\t.string \"~a\"" (cdr s)))
 
-;; (define (amd64-abi f)
-;;                                     ; preserve registers
-;;   (emit "push %rbp")
-;;   ;; (emit "push %rbx")
-;;   ;; (for-each (lambda (i)
-;;   ;;              (emit (string-append
-;;   ;;                     "push %r"
-;;   ;;                     (number->string i))))
-;;   ;;            '(12 13 14 15))
-
-;;   (emit "movq %rsp, %rbp")              ; set up the base pointer
-
-;;   (f) ; call stuff
-;;                                     ; restore preserved registers
-;;   ;; (for-each (lambda (i)
-;;   ;;              (emit (string-append
-;;   ;;                     "pop %r"
-;;   ;;                     (number->string i))))
-;;   ;;            '(15 14 13 12))
-;;   ;; (emit "pop %rbx")
-;;   (emit "pop %rbp")
-;;   (emit "ret"))
-
                                        ; 24(%rbp) mem arg 1
                                        ; 16(%rbp) mem arg 0          prev frame
                                        ; -----------------------
   (set! cur-lambda 0)
   (let* ([body (program-body program)]
 
+        [data-layouts (map data-layout (program-datas program))]
+        
         (extract-res-0 (extract-strings body))
         (strings (car extract-res-0))
         (extract-res-1 (extract-lambdas (cdr extract-res-0)))
 
     (emit "movq %rsp, %rbp")            ; set up the base pointer
     
-    (codegen-expr xform-prog (- wordsize) '())
+    (codegen-expr xform-prog (- wordsize) (make-env data-layouts '()))
 
                                        ; exit syscall
     (emit "mov %rax, %rdi")
index 30d23f61ed6375c786574a1d18056eb0278b398a..4e14fb6ab8ee3634b1fce5be5166d55fe8455d96 100644 (file)
--- a/tests.scm
+++ b/tests.scm
   (let ((str (read-file "/tmp/test-output.txt")))
     (test str output)))
 
-(test (data-tors '(data A
+(test (data-tors (data-layout '(data A
                         (foo Int Bool)
-                       (bar Bool)))
+                        (bar Bool))))
+      '((foo . (constructor . (abs Int (abs Bool A))))
+       (foo~0 . (0 . (abs A Int)))
+       (foo~1 . (1 . (abs A Bool)))
+       (bar . (constructor . (abs Bool A)))
+       (bar~0 . (0 . (abs A Bool)))))
+
+(test (data-tors-env
+       (data-layout '(data A
+                           (foo Int Bool)
+                           (bar Bool))))
        '((foo . (abs Int (abs Bool A)))
         (foo~0 . (abs A Int))
         (foo~1 . (abs A Bool))
                           (pow 4 2))))
            'Int)
 
+                                       ; ADTs
+
+
 (test-types
  (typecheck
   '((data A
       y)))
  'Int)
 
+
+                                       ; pattern matching
+(test (let-bindings '(let ([(foo x) a]) x))
+      '((x (foo~0 a))))
+
+(test (let-bindings '(let ([x (foo 42)] [(foo y) x]) x))
+      '((x (foo 42))
+       (y (foo~0 x))))
+
+                                       ; type annotations
+
+(test (annotate-types
+       '((let ([x 42]
+              [y (+ 1 x)])
+          (- y x))))
+
+      '((let ()
+         ((let ((x 42 : Int)
+                (y ((((+ : (abs Int (abs Int Int))) (1 : Int)) : (abs Int Int)) (x : Int)) : Int))
+            (((((- : (abs Int (abs Int Int))) (y : Int)) : (abs Int Int)) (x : Int)) : Int))))))
+
 (test-expr '(+ 1 2) 3)
 (test-expr '(bool->int (= 2 0)) 0)
 (test-expr '((lambda (x) ((lambda (y) (+ x y)) 42)) 100) 142)
              (pow 4 2))
           16)
 
-(test-prog-stdout '(let ([f (lambda (n)
-                             (if (= n 0)
-                                 0
-                                 (let ()
-                                   (print "a")
-                                   (g (- n 1)))))]
-                        [g (lambda (m)
-                             (let ()
-                               (print "b")
-                                (f (- m 1))))])
-                        (f 10)) "ababababab")
+                                       ; mutual recursion
+;; (test-prog-stdout '((let ([f (lambda (n)
+;;                           (if (= n 0)
+;;                               0
+;;                               (let ()
+;;                                 (print "a")
+;;                                 (g (- n 1)))))]
+;;                      [g (lambda (m)
+;;                           (let ()
+;;                             (print "b")
+;;                              (f (- m 1))))])
+;;                      (f 10))) "ababababab")
index 821035230affddf4d813074316f5aecaeaa2efc5..cfb30cc4dc40b690c126d08a5a10fa10688faa62 100644 (file)
     (else (error #f "Couldn't find type for builtin" x))))
 
 (define (check-let env x)
-                                       ; takes in the current environment and a scc
-                                       ; returns new environment with scc's types added in
-  (let* ([components (reverse (sccs (graph (let-bindings x))))]
-        [process-component
-         (lambda (acc comps)
+
+  ; acc is a pair of (env . annotated bindings)
+  (define (process-component acc comps)
     (let*
                                        ; create a new env with tvars for each component
                                        ; e.g. scc of (x y)
          (fold-left
           (lambda (acc c)
             (env-insert acc c (fresh-tvar)))
-                  acc comps)]
+          (car acc) comps)]
                                        ; typecheck each component
         [type-results
          (map
                 (if (memv (car x) comps)
                     (cons (car x) (substitute cs (cdr x)))
                     x))
-                      scc-env)])
-             new-env))]
-        [new-env (fold-left process-component env components)])
-    (check new-env (last (let-body x)))))
+              scc-env)]
+
+        [annotated-bindings (append (cdr acc) ; the previous annotated bindings
+                                    (map cons
+                                         comps
+                                         (map caddr type-results)))])
+      (cons new-env annotated-bindings)))
+                                       ; takes in the current environment and a scc
+                                       ; returns new environment with scc's types added in
+  (let* ([components (reverse (sccs (graph (let-bindings x))))]
+        [results (fold-left process-component (cons env '()) components)]
+        [new-env (car results)]
+        [annotated-bindings (cdr results)]
+
+        [body-results (map (lambda (body) (check new-env body)) (let-body x))]
+        [let-type (cadr (last body-results))]
+        [cs (fold-left (lambda (acc cs) (constraint-merge acc cs)) '() (map car body-results))]
+
+        [annotated `((let ,annotated-bindings ,@(map caddr body-results)))])
+    (list cs let-type annotated)))
+
 
+; returns a list (constraints type annotated)
 (define (check env x)
+  (define (make-result cs type)
+    (list cs type `(,x : ,type)))
   ;; (display "check: ")
   ;; (display x)
   ;; (display "\n\t")
   (let
       ((res
        (case (ast-type x)
-         ('int-literal (list '() 'Int))
-         ('bool-literal (list '() 'Bool))
-         ('string-literal (list '() 'String))
-         ('builtin (list '() (builtin-type x)))
+         ('int-literal (make-result '() 'Int))
+         ('bool-literal (make-result '() 'Bool))
+         ('string-literal (make-result '() 'String))
+         ('builtin (make-result '() (builtin-type x)))
 
          ('if
           (let* ((cond-type-res (check env (cadr x)))
                       (constraint-merge (~ (cadr cond-type-res) 'Bool)
                                         (constraint-merge (car else-type-res)
                                                           then-eq-else-cs))))
-                 (return-type (substitute cs (cadr then-type-res))))
-            (list cs return-type)))
+                 (return-type (substitute cs (cadr then-type-res)))          
+                 [annotated `((if ,(caddr cond-type-res)
+                                  ,(caddr then-type-res)
+                                  ,(caddr else-type-res)) : ,return-type)])
+            (list cs return-type annotated)))
          
-         ('var (list '() (env-lookup env x)))
+         ('var (make-result '() (env-lookup env x)))
          ('let (check-let env x))
 
          
          ('lambda
-             (let* [(new-env (env-insert env (lambda-arg x) (fresh-tvar)))
-
-                    (body-type-res (check new-env (lambda-body x)))
-                    (cs (car body-type-res))
-                    (subd-env (substitute-env (car body-type-res) new-env))
-                    (arg-type (env-lookup subd-env (lambda-arg x)))
-                    (resolved-arg-type (substitute cs arg-type))]
-               ;; (display "lambda:\n\t")
-               ;; (display prog)
-               ;; (display "\n\t")
-               ;; (display cs)
-               ;; (display "\n\t")
-               ;; (display (format "subd-env: ~a\n" subd-env))
-               ;; (display resolved-arg-type)
-               ;; (newline)
-               (list (car body-type-res)
-                     (list 'abs
-                           resolved-arg-type
-                           (cadr body-type-res)))))
+             (let* ([new-env (env-insert env (lambda-arg x) (fresh-tvar))]
+
+                    [body-type-res (check new-env (lambda-body x))]
+                    [cs (car body-type-res)]
+                    [subd-env (substitute-env (car body-type-res) new-env)]
+                    [arg-type (env-lookup subd-env (lambda-arg x))]
+                    [resolved-arg-type (substitute cs arg-type)]
+
+                    [lambda-type `(abs ,resolved-arg-type ,(cadr body-type-res))]
+
+                    ; TODO: do we need to annotate the lambda argument?
+                    [annotated `(lambda (,(lambda-arg x)) ,(caddr body-type-res))])
+               
+               (list (car body-type-res) ; constraints
+                     lambda-type  ; type
+                     annotated)))
+
          
          ('app ; (f a)
           (if (eqv? (car x) (cadr x))
                                        ; recursive function (f f)
-              (let* [(func-type (env-lookup env (car x)))
-                     (return-type (fresh-tvar))
-                     (other-func-type `(abs ,func-type ,return-type))
-                     (cs (~ func-type other-func-type))
-                     (resolved-return-type (substitute cs return-type))]
-                (list cs resolved-return-type)))
+              (let* ([func-type (env-lookup env (car x))]
+                     [return-type (fresh-tvar)]
+                     [other-func-type `(abs ,func-type ,return-type)]
+                     [cs (~ func-type other-func-type)]
+                     [resolved-return-type (substitute cs return-type)]
+
+                     [annotated `(((,(car x) : ,func-type)
+                                   (,(cadr x) : ,func-type)) : ,resolved-return-type)])
+                (list cs resolved-return-type annotated)))
 
                                        ; regular function
-              (let* ((arg-type-res (check env (cadr x)))
-                     (arg-type (cadr arg-type-res))
-                     (func-type-res (check env (car x)))
-                     (func-type (cadr func-type-res))
+          (let* ([arg-type-res (check env (cadr x))]
+                 [arg-type (cadr arg-type-res)]
+                 [func-type-res (check env (car x))]
+                 [func-type (cadr func-type-res)]
                  
                                        ; f ~ a -> t0
-                     (func-c (~
+                 [func-c (~
                           (substitute (car arg-type-res) func-type)
-                              `(abs ,arg-type ,(fresh-tvar))))
-                     (cs (constraint-merge
+                          `(abs ,arg-type ,(fresh-tvar)))]
+                 [cs (constraint-merge
                       (constraint-merge func-c (car arg-type-res))
-                          (car func-type-res)))
-                     
-                     (resolved-func-type (substitute cs func-type))
-                     (resolved-return-type (caddr resolved-func-type)))
-                ;; (display "app:\n")
-                ;; (display cs)
-                ;; (display "\n")
-                ;; (display func-type)
-                ;; (display "\n")
-                ;; (display resolved-func-type)
-                ;; (display "\n")
-                ;; (display arg-type-res)
-                ;; (display "\n")
+                      (car func-type-res))]
+                 
+                 [resolved-func-type (substitute cs func-type)]
+                 [resolved-return-type (caddr resolved-func-type)]
+
+                 [annotated `((,(caddr func-type-res)
+                               ,(caddr arg-type-res)) : ,resolved-return-type)])
+
             (if (abs? resolved-func-type)
                 (let ((return-type (substitute cs (caddr resolved-func-type))))
-                      (list cs return-type))
+                  (list cs return-type annotated))
                 (error #f "not a function")))))))
     ;; (display "result of ")
     ;; (display x)
     ;; (display "]\n")
     res))
 
+(define (init-adts-env prog)
+  (flat-map data-tors-env (map data-layout (program-datas prog))))
+
                                        ; we typecheck the lambda calculus only (only single arg lambdas)
 (define (typecheck prog)
+  (cadr (check (init-adts-env prog) (normalize (program-body prog)))))
 
-  (let ([init-env (flat-map data-tors (program-datas prog))])
-    (display init-env)
-    (newline)
-    (cadr (check init-env (normalize (program-body prog))))))
+(define (annotate-types prog)
+  (caddr (check (init-adts-env prog) (normalize (program-body prog)))))
 
   
                                        ; returns a list of constraints
 
                                        ; input: a list of binds ((x . y) (y . 3))
                                        ; returns: pair of verts, edges ((x y) . (x . y))
-(define (graph bs)
-  (define (go bs orig-bs)
-    (define (find-refs prog)
-      (ast-collect
-       (lambda (x)
-        (case (ast-type x)
-                                       ; only count a reference if its a binding
-          ['var (if (assoc x orig-bs) (list x) '())]
-          [else '()]))
-       prog))
-    (if (null? bs)
-       '(() . ())
-       (let* [(bind (car bs))
-
-              (vert (car bind))
-              (refs (find-refs (cdr bind)))
-              (edges (map (lambda (x) (cons vert x))
-                          refs))
-
-              (rest (if (null? (cdr bs))
-                        (cons '() '())
-                        (go (cdr bs) orig-bs)))
-              (total-verts (cons vert (car rest)))
-              (total-edges (append edges (cdr rest)))]
-         (cons total-verts total-edges))))
-  (go bs bs))
-
-(define (successors graph v)
-  (define (go v E)
-    (if (null? E)
-       '()
-       (if (eqv? v (caar E))
-           (cons (cdar E) (go v (cdr E)))
-           (go v (cdr E)))))
-  (go v (cdr graph)))
-
-                                       ; takes in a graph (pair of vertices, edges)
-                                       ; returns a list of strongly connected components
-
-                                       ; ((x y w) . ((x . y) (x . w) (w . x))
-
-                                       ; =>
-                                       ; .->x->y
-                                       ; |  |
-                                       ; |  v
-                                       ; .--w
-
-                                       ; ((x w) (y))
-
-                                       ; this uses tarjan's algorithm, to get reverse
-                                       ; topological sorting for free
-(define (sccs graph)
-  
-  (let* ([indices (make-hash-table)]
-        [lowlinks (make-hash-table)]
-        [on-stack (make-hash-table)]
-        [current 0]
-        [stack '()]
-        [result '()])
-
-    (define (index v)
-      (get-hash-table indices v #f))
-    (define (lowlink v)
-      (get-hash-table lowlinks v #f))
-
-    (letrec
-       ([strong-connect
-         (lambda (v)
-           (begin
-             (put-hash-table! indices v current)
-             (put-hash-table! lowlinks v current)
-             (set! current (+ current 1))
-             (push! stack v)
-             (put-hash-table! on-stack v #t)
-
-             (for-each
-              (lambda (w)
-                (if (not (hashtable-contains? indices w))
-                                       ; successor w has not been visited, recurse
-                    (begin
-                      (strong-connect w)
-                      (put-hash-table! lowlinks
-                                       v
-                                       (min (lowlink v) (lowlink w))))
-                                       ; successor w has been visited
-                    (when (get-hash-table on-stack w #f)
-                      (put-hash-table! lowlinks v (min (lowlink v) (index w))))))
-              (successors graph v))
-
-             (when (= (index v) (lowlink v))
-               (let ([scc
-                      (let new-scc ()
-                        (let ([w (pop! stack)])
-                          (put-hash-table! on-stack w #f)
-                          (if (eqv? w v)
-                              (list w)
-                              (cons w (new-scc)))))])
-                 (set! result (cons scc result))))))])
-      (for-each
-       (lambda (v)
-        (when (not (hashtable-contains? indices v)) ; v.index == -1
-          (strong-connect v)))
-       (car graph)))
-    result))