;; -*- mode: scheme; coding: utf-8 -*-
;; SPDX-License-Identifier: AGPL-3.0-or-later
;; Loko Scheme - an R6RS Scheme compiler
;; Copyright © 2021 Göran Weinholt

;; This program is free software: you can redistribute it and/or modify
;; it under the terms of the GNU Affero General Public License as published by
;; the Free Software Foundation, either version 3 of the License, or
;; (at your option) any later version.

;; This program is distributed in the hope that it will be useful,
;; but WITHOUT ANY WARRANTY; without even the implied warranty of
;; MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
;; GNU Affero General Public License for more details.

;; You should have received a copy of the GNU Affero General Public License
;; along with this program.  If not, see <https://www.gnu.org/licenses/>.
#!r6rs

;;; Global (inter/intra-library) integrations

;; This runs immediately before the code generator and has access to
;; all code that makes up the final program.

(library (loko compiler global)
  (export
    pass-prune-globals
    pass-direct-calls)
  (import
    (loko compiler recordize)
    (rename (loko runtime utils) (map-in-order map))
    (except (rnrs) map)
    (only (psyntax compat) pretty-print))

;; Delete global variables if the name does not appear in any
;; $global-ref and is missing in primlocs (the reverse lookup table
;; locs-rev is used).
;; TODO: Right now this uses multiple passes, but maybe it's possible
;; instead to trace all reachable code?
(define (pass-prune-globals codes locs-ht locs-rev)
  ;; Remove exported bindings, and $global-set! stores, for global
  ;; locations that nothing reads. Repeats until a pass removes
  ;; nothing more.
  ;;
  ;; codes    - list of code records that make up the whole program.
  ;; locs-ht  - hashtable from primitive name to its location.
  ;; locs-rev - reverse lookup table keyed by location; a location
  ;;            present in it is never pruned.
  ;;
  ;; Returns the list of pruned codes.
  (define who 'pass-prune-globals)
  ;; Set to #t whenever `pass` removes something. Removed code may have
  ;; held the last reference to another global, so another iteration
  ;; is needed before reaching the fixpoint.
  (define changes #f)
  ;; Build a hashtable of every location used in $global-ref (or named
  ;; by a primref), and clear variable-referenced? on every bound
  ;; variable. `pass` sets that flag again for each variable it sees a
  ;; ref to.
  (define referenced (make-eq-hashtable))
  (define (pass-find-refs x)
    (cond ((bind? x)
           (for-each (lambda (x) (set-variable-referenced?! x #f)) (bind-lhs* x))
           (for-each pass-find-refs (bind-rhs* x))
           (pass-find-refs (bind-body x)))
          ((fix? x)
           (for-each (lambda (x) (set-variable-referenced?! x #f)) (fix-lhs* x))
           (for-each pass-find-refs (fix-rhs* x))
           (pass-find-refs (fix-body x)))
          ((proc? x)
           (for-each (lambda (x)
                       (pass-find-refs (proccase-body x)))
                     (proc-cases x)))
          ((seq? x)
           (pass-find-refs (seq-e0 x))
           (pass-find-refs (seq-e1 x)))
          ((mutate? x)
           (pass-find-refs (mutate-expr x)))
          ((test? x)
           (pass-find-refs (test-expr x))
           (pass-find-refs (test-then x))
           (pass-find-refs (test-else x)))
          ((funcall? x)
           (let ((op (funcall-operator x))
                 (operands (funcall-operand* x)))
             (pass-find-refs op)
             (for-each pass-find-refs operands)
             (when (and (primref? op) (eq? (primref-name op) '$global-ref))
               (hashtable-set! referenced (const-value (car operands)) #t))))
          ((mv-call? x)
           (pass-find-refs (mv-call-producer x))
           (pass-find-refs (mv-call-consumer x)))
          ((mv-let? x)
           (pass-find-refs (mv-let-expr x))
           (pass-find-refs (mv-let-body x)))
          ((mv-values? x)
           (for-each pass-find-refs (mv-values-expr* x)))
          ((const? x))
          ((ref? x))
          ((primref? x)
           ;; A primref to something with a location keeps that
           ;; location alive.
           (cond ((hashtable-ref locs-ht (primref-name x) #f) =>
                  (lambda (loc) (hashtable-set! referenced loc #t))))
           x)
          ((closure? x)
           (pass-find-refs (closure-code x)))
          ((goto? x))
          ((tagbody? x)
           (pass-find-refs (tagbody-body x)))
          ((infer? x)
           (pass-find-refs (infer-expr x)))
          ((labels? x)
           (pass-find-refs (labels-body x)))
          (else
           (error who "Unknown type" x))))

  ;; A location is unused when no $global-ref/primref named it and it
  ;; is absent from the reverse lookup table.
  (define (global-name-unused? loc)
    (and (not (hashtable-ref referenced loc #f))
         (not (hashtable-ref locs-rev loc #f))))

  ;; Filter the bindings that have an export-name, which initially
  ;; come from library-letrec*. After analyzing the whole program it
  ;; is possible to remove any binding that is never used.
  ;;
  ;; Returns three values: the kept lhs*, the kept rhs*, and the new
  ;; body. A removed binding whose rhs may have side effects has its
  ;; rhs sequenced in front of the body. The body returned by the
  ;; recursive call (body^) already contains the rhs of any later
  ;; removed bindings, so it must be the one passed on; otherwise those
  ;; side effects would be lost.
  (define (filter-bindings lhs* rhs* body)
    (if (null? lhs*)
        (values lhs* rhs* body)
        (let-values ([(lhs*^ rhs*^ body^)
                      (filter-bindings (cdr lhs*) (cdr rhs*) body)])
          (let* ((lhs (car lhs*))
                 (rhs (car rhs*))
                 (export-name (variable-export-name lhs)))
            (cond ((and export-name (global-name-unused? export-name)
                        (not (variable-referenced? lhs)))
                   (set! changes #t)
                   (set-variable-export-name! lhs #f)
                   (cond ((closure? rhs)
                          ;; A #f label makes the labels case in `pass`
                          ;; drop this proc.
                          (set-proc-label! (closure-code rhs) #f)
                          (values lhs*^ rhs*^ body^))
                         ((or (const? rhs)
                              (and (infer? rhs) (const? (infer-expr rhs))))
                          (values lhs*^ rhs*^ body^))
                         (else
                          ;; Keep the rhs for its possible side effects.
                          (values lhs*^ rhs*^ (make-seq rhs body^)))))
                  (else
                   (values (cons lhs lhs*^) (cons rhs rhs*^) body^)))))))

  ;; Check if an expression that appears as the second argument of
  ;; $global-set! is free from side effects.
  (define (side-effect-free? x)
    (cond ((ref? x) #t)
          ((infer? x) (side-effect-free? (infer-expr x)))
          ((closure? x) #t)
          ((const? x) #t)
          ((and (funcall? x) (primref? (funcall-operator x)))
           (and (memq (primref-name (funcall-operator x))
                      '(record-mutator
                        record-accessor
                        record-predicate
                        record-constructor
                        make-record-constructor-descriptor))
                (for-all side-effect-free? (funcall-operand* x))))
          (else #f)))

  ;; Procs of the closures seen while rewriting the current code; the
  ;; labels case rebuilds its proc list from this.
  (define labels '())
  (define (pass x)
    (cond ((bind? x)
           ;; Filter bindings, but process the body & right-hand sides
           ;; first to update variable-referenced?.
           (let-values ([(lhs* rhs* body)
                         (filter-bindings (bind-lhs* x) (map pass (bind-rhs* x))
                                          (pass (bind-body x)))])
             (make-bind lhs* rhs* body)))
          ((fix? x)
           (let-values ([(lhs* rhs* body)
                         (filter-bindings (fix-lhs* x) (map pass (fix-rhs* x))
                                          (pass (fix-body x)))])
             (make-fix lhs* rhs* body)))
          ((proc? x)
           (make-proc (proc-label x)
                      (proc-end-label x)
                      (map (lambda (x)
                             (make-proccase (proccase-info x)
                                            (pass (proccase-body x))))
                           (proc-cases x))
                      (proc-free x)
                      (proc-name x)
                      (proc-source x)))
          ((seq? x)
           (make-seq (pass (seq-e0 x))
                     (pass (seq-e1 x))))
          ((mutate? x)
           (make-mutate (mutate-name x)
                        (pass (mutate-expr x))))
          ((test? x)
           (make-test (pass (test-expr x))
                      (pass (test-then x))
                      (pass (test-else x))))
          ((funcall? x)
           (let ((op (funcall-operator x))
                 (operands (funcall-operand* x)))
             (cond ((and (primref? op)
                         (eq? (primref-name op) '$global-set!)
                         (global-name-unused? (const-value (car operands))))
                    ;; The store is dead. Its value expression may have
                    ;; held the last $global-ref of some other global,
                    ;; so another iteration is needed to reach the
                    ;; fixpoint.
                    (set! changes #t)
                    (let ((rhs (cadr operands)))
                      (if (side-effect-free? rhs)
                          (make-funcall (make-primref 'void) '() #f (funcall-source x))
                          (pass rhs))))
                   (else
                    (make-funcall (pass op)
                                  (map pass operands)
                                  (funcall-label x)
                                  (funcall-source x))))))
          ((mv-call? x)
           (make-mv-call (pass (mv-call-producer x))
                         (pass (mv-call-consumer x))
                         (mv-call-source x)))
          ((mv-let? x)
           (make-mv-let (pass (mv-let-expr x))
                        (mv-let-lhs* x)
                        (pass (mv-let-body x))
                        (mv-let-source x)))
          ((mv-values? x)
           (make-mv-values (map pass (mv-values-expr* x))
                           (mv-values-source x)))
          ((const? x) x)
          ((ref? x)
           (set-variable-referenced?! (ref-name x) #t)
           x)
          ((primref? x) x)
          ((goto? x) x)
          ((tagbody? x)
           (make-tagbody (tagbody-label x)
                         (pass (tagbody-body x))
                         (tagbody-source x)))
          ((infer? x)
           (make-infer (pass (infer-expr x))
                       (infer-facts x)))
          ((closure? x)
           (let ((code (pass (closure-code x))))
             (set! labels (cons code labels))
             (make-closure code (closure-free* x))))
          ((labels? x)
           (let ((body (pass (labels-body x))))
             (let ((proc* labels))
               ;; The whole proc can be pruned if the associated export
               ;; binding is pruned.
               (make-labels (labels-top-level-name x)
                            (filter proc-label proc*)
                            body))))
          (else
           (error who "Unknown type" x))))

  ;; Analyze & prune until fixpoint
  (for-each pass-find-refs codes)
  (let ((codes (map (lambda (x)
                      (set! labels '())
                      (pass x))
                    codes)))
    (if (not changes)
        codes
        (pass-prune-globals codes locs-ht locs-rev))))

(define (label-for-call proc arg-len)
  ;; Choose the entry label for a direct call to PROC with ARG-LEN
  ;; arguments. The cases are tried in order. A case with a proper
  ;; formals list accepts exactly (length formals) arguments. Otherwise
  ;; the last formal is the rest argument, and at least
  ;; (length formals) - 1 arguments are accepted. Returns the label of
  ;; the first accepting case. If no case accepts, returns PROC's
  ;; generic label.
  (define (accepts? info)
    (let ((n (length (caseinfo-formals info))))
      (if (caseinfo-proper? info)
          (fx=? arg-len n)
          (fx>=? arg-len (fx- n 1)))))
  (let next-case ((cases (proc-cases proc)))
    (cond ((null? cases)
           (proc-label proc))
          ((accepts? (proccase-info (car cases)))
           (caseinfo-label (proccase-info (car cases))))
          (else
           (next-case (cdr cases))))))

;;; Look up the labels for calls

(define (pass-direct-calls codes locs-ht)
  ;; Resolve direct-call labels and lower references to primitives.
  ;;
  ;; codes   - list of code records that make up the whole program.
  ;; locs-ht - hashtable from primitive name to its global location.
  ;;
  ;; Returns a new list of codes in which:
  ;;  * Some calls get a label chosen by label-for-call. These are calls
  ;;    whose operator is known to be bound to a closure: either an
  ;;    unmutated local variable, or a primitive whose location is
  ;;    exported by some binding.
  ;;  * Primrefs outside operator position become
  ;;    ($global-ref 'loc). It is an error if a primitive has no
  ;;    location.
  (define who 'pass-direct-calls)
  ;; Find all exported variables from a labels record (originally from
  ;; library-letrec*). This is used to find the labels for function
  ;; calls to globals.
  (define exports (make-eq-hashtable))
  (define (handle-exported-variable lhs rhs)
    (cond ((variable-export-name lhs) =>
           (lambda (export-name)
             (hashtable-set! exports export-name rhs)))))
  ;; Walks only the binding structure. It does not descend into proc,
  ;; closure, funcall, tagbody or infer nodes, nor into the branches of
  ;; a test.
  (define (pass-find-exports x)
    (cond ((bind? x)
           (for-each handle-exported-variable (bind-lhs* x) (bind-rhs* x))
           (pass-find-exports (bind-body x)))
          ((fix? x)
           (for-each handle-exported-variable (fix-lhs* x) (fix-rhs* x))
           (pass-find-exports (fix-body x)))
          ((proc? x))
          ((seq? x)
           (pass-find-exports (seq-e0 x))
           (pass-find-exports (seq-e1 x)))
          ((mutate? x)
           ;; Need to know in the next step if the variable in a
           ;; tagbody is mutated.
           (set-variable-mutated?! (mutate-name x) #t))
          ((test? x)
           (pass-find-exports (test-expr x)))
          ((funcall? x))
          ((mv-call? x)
           (pass-find-exports (mv-call-producer x))
           (pass-find-exports (mv-call-consumer x)))
          ((mv-let? x)
           (pass-find-exports (mv-let-expr x))
           (pass-find-exports (mv-let-body x)))
          ((mv-values? x)
           (for-each pass-find-exports (mv-values-expr* x)))
          ((const? x))
          ((ref? x))
          ((primref? x))
          ((closure? x))
          ((goto? x))
          ((tagbody? x))
          ((infer? x))
          ((labels? x)
           (pass-find-exports (labels-body x)))
          (else
           (error who "Unknown type" x))))

  ;; ENV is an alist from variable records to the (unprocessed)
  ;; expressions they are bound to by an enclosing bind or fix.
  ;; The keys are the variable records themselves, not their names.
  ;; So a ref can only ever match the binding that introduced its own
  ;; variable, never a different variable that shares its name.
  ;; Variables not in ENV (proc formals, mv-let variables) just fail
  ;; the lookup.
  (define (extend-env lhs* rhs* env)
    (append (map cons lhs* rhs*) env))

  ;; Walk the code while keeping track of the bindings that are in
  ;; scope. Primref operators are left as they are; references to
  ;; primitives elsewhere become $global-ref; calls to known closures
  ;; get their labels looked up.
  (define labels '())
  (define (pass x env)
    (cond ((bind? x)
           ;; The rhs* are evaluated outside the scope of the lhs*.
           (make-bind (bind-lhs* x)
                      (map (lambda (x) (pass x env)) (bind-rhs* x))
                      (pass (bind-body x)
                            (extend-env (bind-lhs* x) (bind-rhs* x) env))))
          ((fix? x)
           ;; The rhs* of a fix are in the scope of its own lhs*.
           (let ((env (extend-env (fix-lhs* x) (fix-rhs* x) env)))
             (make-fix (fix-lhs* x)
                       (map (lambda (x) (pass x env)) (fix-rhs* x))
                       (pass (fix-body x) env))))
          ((proc? x)
           (make-proc (proc-label x)
                      (proc-end-label x)
                      (map (lambda (x)
                             (make-proccase (proccase-info x)
                                            (pass (proccase-body x) env)))
                           (proc-cases x))
                      (proc-free x)
                      (proc-name x)
                      (proc-source x)))
          ((seq? x)
           (make-seq (pass (seq-e0 x) env)
                     (pass (seq-e1 x) env)))
          ((mutate? x)
           (make-mutate (mutate-name x)
                        (pass (mutate-expr x) env)))
          ((test? x)
           (make-test (pass (test-expr x) env)
                      (pass (test-then x) env)
                      (pass (test-else x) env)))
          ((funcall? x)
           (let ((op (funcall-operator x))
                 (operands (funcall-operand* x)))
             (let ((op-binding
                    (cond
                      ;; Maybe the operator is a reference to an
                      ;; unmutated variable bound in the local
                      ;; environment (looked up by identity).
                      ((and (ref? op) (not (variable-mutated? (ref-name op))))
                       (cond ((assq (ref-name op) env) => cdr)
                             (else #f)))
                      ;; Or maybe it's an exported procedure?
                      ((and (primref? op) (hashtable-ref locs-ht (primref-name op) #f))
                       => (lambda (loc)
                            (hashtable-ref exports loc #f)))
                      (else #f))))
               (let ((label (if (closure? op-binding)
                                (label-for-call (closure-code op-binding)
                                                (length operands))
                                (funcall-label x))))
                 (let ((op (if (primref? op) op (pass op env))))
                   (make-funcall op
                                 (map (lambda (x) (pass x env)) operands)
                                 label
                                 (funcall-source x)))))))
          ((mv-call? x)
           (make-mv-call (pass (mv-call-producer x) env)
                         (pass (mv-call-consumer x) env)
                         (mv-call-source x)))
          ((mv-let? x)
           (make-mv-let (pass (mv-let-expr x) env)
                        (mv-let-lhs* x)
                        (pass (mv-let-body x) env)
                        (mv-let-source x)))
          ((mv-values? x)
           (make-mv-values (map (lambda (x) (pass x env)) (mv-values-expr* x))
                           (mv-values-source x)))
          ((const? x) x)
          ((ref? x) x)
          ((primref? x)
           (cond ((hashtable-ref locs-ht (primref-name x) #f) =>
                  (lambda (loc)
                    (make-funcall (make-primref '$global-ref)
                                  (list (make-const loc #f))
                                  #f #f)))
                 (else
                  (error who "No closure defined for primitive" (primref-name x)))))
          ((goto? x) x)
          ((tagbody? x)
           (make-tagbody (tagbody-label x)
                         (pass (tagbody-body x) env)
                         (tagbody-source x)))
          ((infer? x)
           (make-infer (pass (infer-expr x) env)
                       (infer-facts x)))
          ((closure? x)
           (let ((code (pass (closure-code x) env)))
             (set! labels (cons code labels))
             (make-closure code (closure-free* x))))
          ((labels? x)
           ;; The proc list is rebuilt from the closures seen in the
           ;; body.
           (let ((body (pass (labels-body x) env)))
             (let ((proc* labels))
               (make-labels (labels-top-level-name x)
                            proc* body))))
          (else
           (error who "Unknown type" x))))

  (for-each pass-find-exports codes)

  (map (lambda (x)
         (set! labels '())
         (pass x '()))
       codes)))
