Add notes on ownership
[scheme.git] / typecheck.scm
1 (load "ast.scm")
2
3 (define (abs? t)
4   (and (list? t) (eq? (car t) 'abs)))
5
6 (define (tvar? t)
7   (and (not (list? t)) (not (concrete? t)) (symbol? t)))
8
9 (define (concrete? t)
10   (case t
11     ('int #t)
12     ('bool #t)
13     ('void #t)
14     (else #f)))
15
16 (define (pretty-type t)
17   (cond ((abs? t)
18          (string-append
19           (if (abs? (cadr t))
20               (string-append "(" (pretty-type (cadr t)) ")")
21               (pretty-type (cadr t)))
22           " -> "
23           (pretty-type (caddr t))))
24         (else (symbol->string t))))
25
26                                         ; ('a, ('b, 'a))
27 (define (env-lookup env n)
28   (if (null? env) (error #f "empty env")                        ; it's a type equality
29       (if (eq? (caar env) n)
30           (cdar env)
31           (env-lookup (cdr env) n))))
32
33 (define (env-insert env n t)
34   (cons (cons n t) env))
35
36 (define abs-arg cadr)
37
38 (define cur-tvar 0)
39 (define (fresh-tvar)
40   (begin
41     (set! cur-tvar (+ cur-tvar 1))
42     (string->symbol
43      (string-append "t" (number->string (- cur-tvar 1))))))
44
45 (define (last xs)
46   (if (null? (cdr xs))
47       (car xs)
48       (last (cdr xs))))
49                                 
50 (define (normalize prog) ; (+ a b) -> ((+ a) b)
51   (case (ast-type prog)
52     ('lambda 
53                                         ; (lambda (x y) (+ x y)) -> (lambda (x) (lambda (y) (+ x y)))
54         (if (> (length (lambda-args prog)) 1)
55             (list 'lambda (list (car (lambda-args prog)))
56                   (normalize (list 'lambda (cdr (lambda-args prog)) (caddr prog))))
57             (list 'lambda (lambda-args prog) (normalize (caddr prog)))))
58     ('app
59      (if (null? (cddr prog))
60          `(,(normalize (car prog)) ,(normalize (cadr prog))) ; (f a)
61          (normalize `(,(list (normalize (car prog)) (normalize (cadr prog)))
62                       ,@(cddr prog))))) ; (f a b)
63     ('let
64         (append (list 'let
65                       (map (lambda (x) `(,(car x) ,(normalize (cadr x))))
66                            (let-bindings prog)))
67                 (map normalize (let-body prog))))
68     (else (ast-traverse normalize prog))))
69
70 (define (builtin-type x)
71   (case x
72     ('+ '(abs int (abs int int)))
73     ('- '(abs int (abs int int)))
74     ('* '(abs int (abs int int)))
75     ('! '(abs bool bool))
76     ('= '(abs int (abs int bool)))
77     ('bool->int '(abs bool int))
78     ('print '(abs string void))
79     (else #f)))
80
81 ; we typecheck the lambda calculus only (only single arg lambdas)
82 (define (typecheck prog)
83   (define (check env x)
84     ;; (display "check: ")
85     ;; (display x)
86     ;; (display "\n\t")
87     ;; (display env)
88     ;; (newline)
89     (let
90         ((res
91           (case (ast-type x)
92             ('int-literal (list '() 'int))
93             ('bool-literal (list '() 'bool))
94             ('string-literal (list '() 'string))
95             ('builtin (list '() (builtin-type x)))
96
97             ('if
98              (let* ((cond-type-res (check env (cadr x)))
99                     (then-type-res (check env (caddr x)))
100                     (else-type-res (check env (cadddr x)))
101                     (then-eq-else-cs (~ (cadr then-type-res)
102                                         (cadr else-type-res)))
103                     (cs (consolidate
104                          (car then-type-res)
105                          (consolidate (car else-type-res)
106                                       then-eq-else-cs)))
107                     (return-type (substitute cs (cadr then-type-res))))
108                (when (not (eqv? (cadr cond-type-res) 'bool))
109                  (error #f "if condition isn't bool"))
110                (list cs return-type)))
111             
112             ('var (list '() (env-lookup env x)))
113             ('let
114                                         ; takes in the current environment and a scc
115                                         ; returns new environment with scc's types added in
116               (let* ([components (reverse (sccs (graph (let-bindings x))))]
117                      [process-component
118                       (lambda (acc comps)
119                         (let*
120                                         ; create a new env with tvars for each component
121                                         ; e.g. scc of (x y)
122                                         ; scc-env = ((x . t0) (y . t1))
123                             ([scc-env
124                               (fold-left
125                                (lambda (acc c)
126                                  (env-insert acc c (fresh-tvar)))
127                                acc comps)]
128                                         ; typecheck each component
129                              [type-results
130                               (map
131                                (lambda (c)
132                                  (let ([body (cadr (assoc c (let-bindings x)))])
133                                    (check scc-env body)))
134                                comps)]
135                                         ; collect all the constraints in the scc
136                              [cs
137                               (fold-left
138                                (lambda (acc res c)
139                                  (consolidate
140                                   acc
141                                   (consolidate (car res)
142                                         ; unify with tvars from scc-env
143                                         ; result ~ tvar
144                                                (~ (cadr res) (env-lookup scc-env c)))))
145                                '() type-results comps)]
146                                         ; substitute *only* the bindings in this scc
147                              [new-env
148                               (map (lambda (x)
149                                      (if (memv (car x) comps)
150                                          (cons (car x) (substitute cs (cdr x)))
151                                          x))
152                                    scc-env)])
153                           new-env))]
154                      [new-env (fold-left process-component env components)])
155                 (check new-env (last (let-body x)))))
156             
157             ('lambda
158                 (let* [(new-env (env-insert env (lambda-arg x) (fresh-tvar)))
159
160                        (body-type-res (check new-env (lambda-body x)))
161                        (cs (car body-type-res))
162                        (subd-env (substitute-env (car body-type-res) new-env))
163                        (arg-type (env-lookup subd-env (lambda-arg x)))
164                        (resolved-arg-type (substitute cs arg-type))]
165                   ;; (display "lambda:\n\t")
166                   ;; (display prog)
167                   ;; (display "\n\t")
168                   ;; (display cs)
169                   ;; (display "\n\t")
170                   ;; (display resolved-arg-type)
171                   ;; (newline)
172                   (list (car body-type-res)
173                         (list 'abs
174                               resolved-arg-type
175                               (cadr body-type-res)))))
176             
177             ('app ; (f a)
178              (if (eqv? (car x) (cadr x))
179                                         ; recursive function (f f)
180                  (let* [(func-type (env-lookup env (car x)))
181                         (return-type (fresh-tvar))
182                         (other-func-type `(abs ,func-type ,return-type))
183                         (cs (~ func-type other-func-type))
184                         (resolved-return-type (substitute cs return-type))]
185                    (list cs resolved-return-type))
186
187                                         ; regular function
188                  (let* ((arg-type-res (check env (cadr x)))
189                         (arg-type (cadr arg-type-res))
190                         (func-type-res (check env (car x)))
191                         (func-type (cadr func-type-res))
192                         
193                                         ; f ~ a -> t0
194                         (func-c (~
195                                  func-type
196                                  (list 'abs
197                                        arg-type
198                                        (fresh-tvar))))
199                         (cs (consolidate
200                              (consolidate func-c (car arg-type-res))
201                              (car func-type-res)))
202                         
203                         (resolved-func-type (substitute cs func-type))
204                         (resolved-return-type (caddr resolved-func-type)))
205                    ;; (display "app:\n")
206                    ;; (display cs)
207                    ;; (display "\n")
208                    ;; (display func-type)
209                    ;; (display "\n")
210                    ;; (display resolved-func-type)
211                    ;; (display "\n")
212                    ;; (display arg-type-res)
213                    ;; (display "\n")
214                    (if (abs? resolved-func-type)
215                        (let ((return-type (substitute cs (caddr resolved-func-type))))
216                          (list cs return-type))
217                        (error #f "not a function"))))))))
218       ;; (display "result of ")
219       ;; (display x)
220       ;; (display ":\n\t")
221       ;; (display (pretty-type (cadr res)))
222       ;; (display "\n\t[")
223       ;; (display (car res))
224       ;; (display "]\n")
225       res))
226   (cadr (check '() (normalize prog))))
227
228                                         ; returns a list of pairs of constraints
229 (define (~ a b)
230   (let ([res (unify? a b)])
231     (if res
232         res
233         (error #f
234                (format "couldn't unify ~a ~~ ~a" a b)))))
235
236 (define (unify? a b)
237   (cond [(eq? a b) '()]
238         [(or (tvar? a) (tvar? b)) (list (list a b))]
239         [(and (abs? a) (abs? b))
240          (let* [(arg-cs (unify? (cadr a) (cadr b)))
241                 (body-cs (unify? (substitute arg-cs (caddr a))
242                                  (substitute arg-cs (caddr b))))]
243            (consolidate arg-cs body-cs))]
244         [else #f]))
245
246                                         ; TODO: what's the most appropriate substitution?
247                                         ; should all constraints just be limited to a pair?
248                                         ; this is currently horrific and i don't know what im doing.
249                                         ; should probably use ast-find here or during consolidation
250                                         ; to detect substitutions more than one layer deep
251                                         ; e.g. (abs t1 int) ~ (abs bool int)
252                                         ; substituting these constraints with t1 should resolve t1 with bool
253 (define (substitute cs t)
254                                         ; gets the first concrete type
255                                         ; otherwise returns the last type variable
256
257                                         ; removes t itself from cs, to prevent infinite recursion
258   (define cs-without-t
259     (map (lambda (c)
260            (filter (lambda (x) (not (eqv? t x))) c))
261          cs))
262
263   (define (get-concrete c)
264     (let [(last (null? (cdr c)))]
265       (if (not (tvar? (car c)))
266           (if (abs? (car c))
267               (substitute cs-without-t (car c))
268               (car c))
269           (if last
270               (car c)
271               (get-concrete (cdr c))))))
272   
273   (cond
274    ((abs? t) (list 'abs
275                    (substitute cs (cadr t))
276                    (substitute cs (caddr t))))
277    (else
278     (fold-left
279      (lambda (t c)
280        (if (member t c)
281            (get-concrete c)
282            t))
283      t cs))))
284
285 (define (substitute-env cs env)
286   (map (lambda (x) (cons (car x) (substitute cs (cdr x)))) env))
287
288 (define (consolidate x y)
289   (define (merge a b)
290     (cond ((null? a) b)
291           ((null? b) a)
292           (else (if (member (car b) a)
293                     (merge a (cdr b))
294                     (cons (car b) (merge a (cdr b)))))))
295   (define (overlap? a b)
296     (if (or (null? a) (null? b))
297         #f
298         (if (fold-left (lambda (acc v)
299                          (or acc (eq? v (car a))))
300                        #f b)
301             #t
302             (overlap? (cdr a) b))))
303
304   (cond ((null? y) x)
305         ((null? x) y)
306         (else
307          (let* ((a (car y))
308                 (merged (fold-left
309                          (lambda (acc b)
310                            (if acc
311                                acc
312                                (if (overlap? a b)
313                                    (cons (merge a b) b)
314                                    #f)))
315                          #f x))
316                 (removed (if merged
317                              (filter (lambda (b) (not (eq? b (cdr merged)))) x)
318                              x)))
319            (if merged
320                (consolidate removed (cons (car merged) (cdr y)))
321                (consolidate (cons a x) (cdr y)))))))
322
323                                         ; a1 -> a2 ~ a3 -> a4;
324                                         ; a1 -> a2 !~ bool -> bool
325                                         ; basically can the tvars be renamed
326 (define (types-equal? x y)
327   (let ([cs (unify? x y)])
328     (if (not cs) #f
329         (let*
330             ([test-kind
331               (lambda (acc c)
332                 (if (tvar? c) acc #f))]
333              [test (lambda (acc c)
334                      (and acc
335                           (fold-left test-kind #t c) ; check only tvar substitutions
336                           (<= (length c) 2)))])      ; check maximum 2 subs per equality group
337           (fold-left test #t cs)))))
338
339                                         ; input: a list of binds ((x . y) (y . 3))
340                                         ; returns: pair of verts, edges ((x y) . (x . y))
341 (define (graph bs)
342   (define (go bs orig-bs)
343     (define (find-refs prog)
344       (ast-collect
345        (lambda (x)
346          (case (ast-type x)
347                                         ; only count a reference if its a binding
348            ['var (if (assoc x orig-bs) (list x) '())]
349            [else '()]))
350        prog))
351     (if (null? bs)
352         '(() . ())
353         (let* [(bind (car bs))
354
355                (vert (car bind))
356                (refs (find-refs (cdr bind)))
357                (edges (map (lambda (x) (cons vert x))
358                            refs))
359
360                (rest (if (null? (cdr bs))
361                          (cons '() '())
362                          (go (cdr bs) orig-bs)))
363                (total-verts (cons vert (car rest)))
364                (total-edges (append edges (cdr rest)))]
365           (cons total-verts total-edges))))
366   (go bs bs))
367
368 (define (successors graph v)
369   (define (go v E)
370     (if (null? E)
371         '()
372         (if (eqv? v (caar E))
373             (cons (cdar E) (go v (cdr E)))
374             (go v (cdr E)))))
375   (go v (cdr graph)))
376
377                                         ; takes in a graph (pair of vertices, edges)
378                                         ; returns a list of strongly connected components
379
380                                         ; ((x y w) . ((x . y) (x . w) (w . x))
381
382                                         ; =>
383                                         ; .->x->y
384                                         ; |  |
385                                         ; |  v
386                                         ; .--w
387
388                                         ; ((x w) (y))
389
390                                         ; this uses tarjan's algorithm, to get reverse
391                                         ; topological sorting for free
392 (define (sccs graph)
393   
394   (let* ([indices (make-hash-table)]
395          [lowlinks (make-hash-table)]
396          [on-stack (make-hash-table)]
397          [current 0]
398          [stack '()]
399          [result '()])
400
401     (define (index v)
402       (get-hash-table indices v #f))
403     (define (lowlink v)
404       (get-hash-table lowlinks v #f))
405
406     (letrec
407         ([strong-connect
408           (lambda (v)
409             (begin
410               (put-hash-table! indices v current)
411               (put-hash-table! lowlinks v current)
412               (set! current (+ current 1))
413               (push! stack v)
414               (put-hash-table! on-stack v #t)
415
416               (for-each
417                (lambda (w)
418                  (if (not (hashtable-contains? indices w))
419                                         ; successor w has not been visited, recurse
420                      (begin
421                        (strong-connect w)
422                        (put-hash-table! lowlinks
423                                         v
424                                         (min (lowlink v) (lowlink w))))
425                                         ; successor w has been visited
426                      (when (get-hash-table on-stack w #f)
427                        (put-hash-table! lowlinks v (min (lowlink v) (index w))))))
428                (successors graph v))
429
430               (when (= (index v) (lowlink v))
431                 (let ([scc
432                        (let new-scc ()
433                          (let ([w (pop! stack)])
434                            (put-hash-table! on-stack w #f)
435                            (if (eqv? w v)
436                                (list w)
437                                (cons w (new-scc)))))])
438                   (set! result (cons scc result))))))])
439       (for-each
440        (lambda (v)
441          (when (not (hashtable-contains? indices v)) ; v.index == -1
442            (strong-connect v)))
443        (car graph)))
444     result))
445