1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
------------------------------------------------------------------------------
--- Library for representation of unification on first-order terms.
---
--- This library implements a unification algorithm using reference tables.
---
--- @author Michael Hanus, Jan-Hendrik Matthes, Jonas Oberschweiber,
---         Bjoern Peemoeller
--- @version February 2020
------------------------------------------------------------------------------

module Rewriting.Unification
  ( UnificationError (..)
  , showUnificationError, unify, unifiable
  ) where

import Data.Either (isRight)
import Data.List   (mapAccumL)
import qualified Data.Map as Map
import Rewriting.Substitution (Subst, emptySubst, extendSubst)
import Rewriting.Term (VarIdx, Term (..), TermEq, TermEqs)
import Rewriting.UnificationSpec (UnificationError (..), showUnificationError)

-- ---------------------------------------------------------------------------
-- Representation of internal data structures
-- ---------------------------------------------------------------------------

--- The term representation used internally by the unification algorithm.
--- `RTermVar` and `RTermCons` mirror the `TermVar` and `TermCons`
--- constructors of the `Term` data type. The extra `Ref` constructor
--- denotes a reference into a reference table, identified by a variable
--- index.
data RTerm f
  = Ref VarIdx              -- reference into the reference table
  | RTermVar VarIdx         -- variable term
  | RTermCons f [RTerm f]   -- constructor term with its argument terms
  deriving (Eq, Show)

--- A reference table storing the terms that `Ref` terms point to. It is
--- represented as a finite map from variable indices to `RTerm`s and is
--- parameterized over the kind of function symbols, e.g., strings.
--- An entry may itself be a `Ref`, in which case lookups have to follow the
--- chain (see `deref`).
type RefTable f = Map.Map VarIdx (RTerm f)

--- A single `RTerm` equation, represented as a pair consisting of the
--- left-hand side and the right-hand side. It is parameterized over the kind
--- of function symbols, e.g., strings.
type REq f = (RTerm f, RTerm f)

--- A list of `RTerm` equations, parameterized over the kind of function
--- symbols, e.g., strings.
type REqs f = [REq f]

-- ---------------------------------------------------------------------------
-- Definition of exported functions
-- ---------------------------------------------------------------------------

--- Unifies a list of term equations.
---
--- The equations are first converted into `RTerm` equations together with a
--- fresh reference table and then solved by `unify'`. If solving succeeds,
--- the resulting equations are converted back into a substitution.
---
--- @param eqs - the term equations to be unified
--- @return `Left` with a unification error or `Right` with a substitution
unify :: Eq f => TermEqs f -> Either (UnificationError f) (Subst f)
unify eqs = case unify' table [] reqs of
  Left err         -> Left err
  Right (rt', sub) -> Right (eqsToSubst rt' sub)
  where
    (table, reqs) = termEqsToREqs eqs

--- Checks whether a list of term equations can be unified, i.e., whether
--- `unify` yields a substitution instead of a unification error.
---
--- @param eqs - the term equations to be checked
--- @return `True` if `unify` succeeds on the equations
unifiable :: Eq f => TermEqs f -> Bool
unifiable eqs = isRight (unify eqs)

-- ---------------------------------------------------------------------------
-- Conversion to internal structure
-- ---------------------------------------------------------------------------

--- Converts a list of term equations into a list of `RTerm` equations.
--- The conversion starts with an empty reference table which is threaded
--- through all equations from left to right; the final table is returned
--- together with the converted equations.
termEqsToREqs :: TermEqs f -> (RefTable f, REqs f)
termEqsToREqs eqs = mapAccumL termEqToREq Map.empty eqs

--- Converts a single term equation into an `RTerm` equation. The left-hand
--- side is converted first; the reference table extended by it is then used
--- for converting the right-hand side. The table extended by both sides is
--- returned together with the converted equation.
termEqToREq :: RefTable f -> TermEq f -> (RefTable f, REq f)
termEqToREq rt (l, r) = (rtBoth, (l', r'))
  where
    (rtLeft, l') = termToRTerm rt l
    (rtBoth, r') = termToRTerm rtLeft r

--- Converts a term into an `RTerm`. Every variable term is registered in
--- the given reference table (initially pointing to itself as an `RTermVar`)
--- and replaced by a `Ref` in the resulting `RTerm`. The arguments of a
--- constructor term are converted from left to right, threading the
--- reference table through.
termToRTerm :: RefTable f -> Term f -> (RefTable f, RTerm f)
termToRTerm rt term = case term of
  TermVar v       -> (Map.insert v (RTermVar v) rt, Ref v)
  TermCons c args ->
    let (table, args') = mapAccumL termToRTerm rt args
     in (table, RTermCons c args')

-- ---------------------------------------------------------------------------
-- Conversion from internal structure
-- ---------------------------------------------------------------------------

--- Converts a list of `RTerm` equations into a substitution. Both sides of
--- an equation are considered after dereferencing: if the left-hand side is
--- a variable `v`, the mapping from `v` to the right-hand side is added;
--- otherwise, if the right-hand side is a variable `v`, the mapping from `v`
--- to the left-hand side is added. Equations without a variable on either
--- side contribute nothing. All terms put into the substitution are
--- converted back with `rTermToTerm`, i.e., all `Ref`s are dereferenced.
eqsToSubst :: RefTable f -> REqs f -> Subst f
eqsToSubst rt eqs = foldr addBinding emptySubst eqs
  where
    -- Extends the substitution built from the remaining equations by the
    -- binding (if any) described by the current equation.
    addBinding (l, r) sub = case deref rt l of
      RTermVar v -> extendSubst sub v (rTermToTerm rt r)
      l'         -> case deref rt r of
        RTermVar v -> extendSubst sub v (rTermToTerm rt l')
        _          -> sub

--- Converts an `RTerm` back into a term. References are resolved through
--- the given reference table (via `deref`) before conversion, and the
--- arguments of constructor terms are converted recursively, so the result
--- contains no references.
rTermToTerm :: RefTable f -> RTerm f -> Term f
rTermToTerm rt term = case term of
  Ref _            -> rTermToTerm rt (deref rt term)
  RTermVar v       -> TermVar v
  RTermCons c args -> TermCons c [rTermToTerm rt arg | arg <- args]

--- Dereferences an `RTerm`. A `Ref` is looked up in the given reference
--- table and the lookup is repeated as long as the entry found is itself a
--- `Ref`, so the result is never a `Ref`. `RTermVar` and `RTermCons` terms
--- are returned unchanged. Fails with a run-time error if a reference has
--- no entry in the table.
deref :: RefTable f -> RTerm f -> RTerm f
deref rt term = case term of
  Ref i -> maybe (error ("deref: " ++ show i)) (deref rt) (Map.lookup i rt)
  _     -> term

-- ---------------------------------------------------------------------------
-- Unification algorithm
-- ---------------------------------------------------------------------------

--- Internal unification function, the core of the algorithm.
---
--- Processes the pending equations one by one. Both sides of an equation
--- are dereferenced through the reference table before they are inspected.
--- Bindings found so far are accumulated in the second argument.
---
--- Note that the arguments of two constructor terms with the same symbol are
--- paired with `zip`, so both terms are assumed to have the same number of
--- arguments.
---
--- @param rt  - the current reference table
--- @param sub - the equations (bindings) collected so far
--- @param eqs - the equations that still have to be solved
--- @return a unification error, or the final reference table together with
---         the collected equations
unify' :: Eq f => RefTable f -> REqs f -> REqs f
       -> Either (UnificationError f) (RefTable f, REqs f)
unify' rt sub equations = case equations of
  -- Nothing left to solve.
  []            -> Right (rt, sub)
  (l, r) : rest -> case (deref rt l, deref rt r) of
    -- Variable against constructor term: bind the variable.
    (RTermVar v, t@(RTermCons _ _)) -> elim rt sub v t rest
    (t@(RTermCons _ _), RTermVar v) -> elim rt sub v t rest
    -- Two variables: drop the equation if they coincide, otherwise bind the
    -- first variable to the second one.
    (RTermVar v, t@(RTermVar v'))
      | v == v'   -> unify' rt sub rest
      | otherwise -> elim rt sub v t rest
    -- Two constructor terms: equate the arguments if the symbols coincide,
    -- otherwise report a clash.
    (s@(RTermCons c1 ts1), t@(RTermCons c2 ts2))
      | c1 == c2  -> unify' rt sub (zip ts1 ts2 ++ rest)
      | otherwise -> Left (Clash (rTermToTerm rt s) (rTermToTerm rt t))
    -- Only reached if a side is still a `Ref` after dereferencing: try the
    -- equation again.
    (s, t)        -> unify' rt sub ((s, t) : rest)

--- Binds a variable to a term and continues unification with the remaining
--- equations. The binding is recorded in the reference table (so that later
--- dereferencing of the variable yields the term) and added as an equation
--- to the list of collected equations. If the occurs check `dependsOn`
--- reports that the variable occurs in the term, an `OccurCheck` error is
--- returned instead.
---
--- @param rt  - the current reference table
--- @param sub - the equations (bindings) collected so far
--- @param v   - the variable to be bound
--- @param t   - the term the variable is bound to (must not be a `Ref`)
--- @param eqs - the equations that still have to be solved
elim :: Eq f => RefTable f -> REqs f -> VarIdx -> RTerm f -> REqs f
     -> Either (UnificationError f) (RefTable f, REqs f)
elim rt sub v t eqs =
  if dependsOn rt (RTermVar v) t
    then Left (OccurCheck v (rTermToTerm rt t))
    else case t of
           Ref _         -> error "elim"
           -- A variable is stored as a `Ref`, not as the `RTermVar` itself,
           -- so that later bindings of that variable are seen through it.
           RTermVar v'   -> bind (Ref v')
           RTermCons _ _ -> bind t
  where
    -- Records the binding of `v` in the table and in the collected
    -- equations, then solves the remaining equations.
    bind target = unify' (Map.insert v target rt)
                         ((RTermVar v, target) : sub)
                         eqs

--- Occurs check: checks whether the first term (in this module always an
--- `RTermVar`) occurs inside the second term. Two structurally equal terms
--- are not considered to depend on each other.
---
--- References are followed through the reference table, and the search
--- continues inside the referenced term. This is essential: a variable
--- that is already bound to a constructor term is only visible as a `Ref`,
--- so without descending into its binding, equations such as
--- `x = f(y), y = g(x)` would pass the check and leave a cyclic reference
--- table on which `rTermToTerm` does not terminate.
---
--- The traversal terminates as long as the reference table is acyclic,
--- which is the property this check is used to maintain.
---
--- @param rt - the reference table used to resolve references
--- @param l  - the term searched for
--- @param r  - the term searched in
--- @return `True` if `l` differs from `r` and occurs inside `r`
dependsOn :: Eq f => RefTable f -> RTerm f -> RTerm f -> Bool
dependsOn rt l r = l /= r && occursIn r
  where
    occursIn t = case t of
      -- Resolve the reference, then compare the target and search inside it.
      Ref _          -> let target = deref rt t
                         in target == l || occursIn target
      RTermVar _     -> l == t
      RTermCons _ ts -> any occursIn ts