------------------------------------------------------------------------------
-- | Author:  Michael Hanus
--   Version: November 2025
--
-- This module contains operations to eliminate definitions and uses
-- of `newtype` in a FlatCurry program.
-- 
-- If there is a declaration of the form
-- 
--     newtype NTYPE a1...an = NTCONS te
-- 
-- in a Curry program, the following transformations are performed:
-- 
-- * Replace `newtype` declaration by `data` declaration.
-- * Every type application `(NTYPE t1...tn)` is replaced by
--   `{a1 |-> t1,..., an |-> tn}(te)`
--   except for occurrences in instance definitions, i.e., operations
--   named by `_inst#...`.
-- * A constructor application `(NTCONS e)` is replaced by `e`.
-- * A partial constructor application `(NTCONS)` is replaced by
--   `(Prelude.id)`.
-- * A case expression `(f)case x of { NTCONS y -> e }` is replaced by
--   `{y |-> x}(e)` if `x` is a variable, otherwise by `let y = x in e`.
--
------------------------------------------------------------------------------

module FlatCurry.ElimNewtype
  ( elimNewtypeInProg, elimNewtypeInProgWithImports )
 where

import Data.List          ( isPrefixOf )

import FlatCurry.AddTypes
import FlatCurry.Files    ( readFlatCurryInt )
import FlatCurry.Goodies  ( progImports, progTypes )
import FlatCurry.Types

-- | Eliminates all `newtype` definitions/uses in a FlatCurry program.
--   The FlatCurry interfaces of all modules imported by the program are
--   read first so that newtypes defined there are eliminated as well.
elimNewtypeInProg ::
     Prog    -- ^ program to be transformed
  -> IO Prog -- ^ returns transformed program without using newtype declarations
elimNewtypeInProg prog =
  mapM readFlatCurryInt (progImports prog) >>= \ imports ->
  return (elimNewtypeInProgWithImports imports prog)

-- | Eliminates all `newtype` definitions/uses in a FlatCurry program.
--   The first argument contains the interfaces of the imported modules.
--   If neither the program nor these interfaces define any newtype,
--   the program is returned unchanged (and no type annotation is computed).
elimNewtypeInProgWithImports ::
     [Prog]  -- ^ interfaces/programs of all imported modules
  -> Prog    -- ^ program to be transformed
  -> Prog    -- ^ transformed program without using newtype declarations
elimNewtypeInProgWithImports impprogs prog = case ntinfos of
  [] -> prog
  _  -> fromAProg (elimNewtypeInAProg ntinfos typedprog)
 where
  ntinfos   = newtypesOfProg (concatMap progTypes (prog : impprogs))
  typedprog = addTypesInProgWithImports impprogs prog

-- Eliminates all `newtype` definitions/uses in a type-annotated
-- FlatCurry program: `newtype` declarations are turned into `data`
-- declarations and every function declaration is transformed by
-- `elimNewtypeInFunc`. Module name, imports and operators are kept.
elimNewtypeInAProg :: [NewtypeInfo] -> AProg TypeInfo -> AProg TypeInfo
elimNewtypeInAProg ntinfos (AProg modname imports types funcs opdecls) =
  AProg modname imports newtypes newfuncs opdecls
 where
  newtypes = map replaceNewtypeDecl types
  newfuncs = map (elimNewtypeInFunc ntinfos) funcs

-- Transforms a `newtype` declaration into a `data` declaration with a single
-- unary constructor. Other type declarations are returned unchanged.
-- (The rules are disjoint on purpose: overlapping rules would be
-- non-deterministic in Curry.)
replaceNewtypeDecl :: TypeDecl -> TypeDecl
replaceNewtypeDecl (TypeNew tname tvis targs (NewCons cname cvis argtype)) =
  Type tname tvis targs [Cons cname 1 cvis [argtype]]
replaceNewtypeDecl tdecl@(Type    _ _ _ _) = tdecl
replaceNewtypeDecl tdecl@(TypeSyn _ _ _ _) = tdecl

-- Eliminates all `newtype` uses in a type-annotated function declaration.
-- External operations and class instance operations (see `isClassInstanceOp`)
-- are returned unchanged. For all other operations:
-- * newtype applications are expanded in the function type, in the type
--   annotations of all subexpressions, in the types of let-bound variables,
--   and in explicit type annotations (`ATyped`);
-- * applications of newtype constructors are removed;
-- * cases on newtype constructors are eliminated.
elimNewtypeInFunc :: [NewtypeInfo] -> AFuncDecl TypeInfo -> AFuncDecl TypeInfo
elimNewtypeInFunc _   fd@(AFunc _  _  _   _     (AExternal _))    = fd
elimNewtypeInFunc nti fd@(AFunc qf ar vis ftype (ARule args rhs)) =
  if isClassInstanceOp qf
    then fd
    else AFunc qf ar vis (elimType ftype) (ARule args (elimExp rhs))
 where
  -- Expands all newtype applications occurring in a type expression.
  elimType = elimTypeWith []

  -- Expands newtype applications except for applications of the newtypes
  -- in `exps`, which are currently being expanded. This guarantees
  -- termination for recursive newtypes (e.g., `newtype R = R (R -> Int)`),
  -- whose recursive occurrences are kept.
  elimTypeWith exps te = case te of
    TVar _             -> te
    FuncType t1 t2     -> FuncType (elimTypeWith exps t1)
                                   (elimTypeWith exps t2)
    TCons tc tes       -> elimTCons exps tc (map (elimTypeWith exps) tes)
    ForallType tvs fte -> ForallType tvs (elimTypeWith exps fte)

  -- Replaces `(tc t1...tn)` by `{a1 |-> t1,...,an |-> tn}(te)` if `tc` is
  -- a newtype `newtype tc a1...an = NTCONS te` not currently expanded.
  -- The arguments `t1...tn` are already expanded. The body `te` is expanded
  -- before the substitution since it may itself contain newtype
  -- applications (nested newtypes), which must not survive in the result.
  elimTCons exps tc tes = case lookup tc nti of
    Just (tvs,_,ntexp) | not (tc `elem` exps)
      -> substTVarsInTExp (zip tvs tes) (elimTypeWith (tc : exps) ntexp)
    _ -> TCons tc tes

  -- Eliminates newtypes in the type annotation of a subexpression.
  elimTInfo ti = typeInfo (map (\(v,t) -> (v, elimType t)) (tiTypedVars ti))
                          (elimType (tiType ti))

  -- Transforms an expression bottom-up, including its type annotations.
  -- NOTE(review): the variables of `AFree` and the rule arguments `args`
  -- are kept unchanged; if they carry types, these may still contain
  -- newtypes -- confirm against `FlatCurry.AddTypes`.
  elimExp exp = case exp of
    AVar  _ v        -> AVar ti v
    ALit  _ l        -> ALit ti l
    AComb _ ct qn es -> elimComb ti ct qn (map elimExp es)
    ALet  _ bs e     -> ALet ti (map elimBinding bs) (elimExp e)
    AFree _ vs e     -> AFree ti vs (elimExp e)
    AOr   _ e1 e2    -> AOr ti (elimExp e1) (elimExp e2)
    ACase _ ct ce bs -> elimCase ti ct (elimExp ce) (map elimBranch bs)
    ATyped _ e t     -> ATyped ti (elimExp e) (elimType t)
   where ti = elimTInfo (annOfAExpr exp)

  -- The declared type of a let-bound variable must agree with the
  -- (eliminated) type of its binding, hence it is eliminated, too.
  elimBinding (v,t,be) = (v, elimType t, elimExp be)

  elimBranch (ABranch pt be) = ABranch pt (elimExp be)

  -- Removes applications of newtype constructors: a full application
  -- `(NTCONS e)` is replaced by `e` and a partial application `(NTCONS)`
  -- by `Prelude.id`.
  elimComb ti ct qn es = case (ct, es) of
    (ConsCall,       [e]) | isNewCons qn -> e
    (ConsPartCall 1, [] ) | isNewCons qn
      -> AComb ti (FuncPartCall 1) ("Prelude","id") []
    _ -> AComb ti ct qn es

  -- Eliminates a case on a newtype constructor. If the discriminating
  -- argument is a variable, the pattern variable is renamed into it.
  -- Otherwise the case is replaced by a let expression binding the pattern
  -- variable to the argument (this requires the argument's type).
  elimCase ti ct ce bs = case bs of
    [ABranch (Pattern qn [v]) be] | isNewCons qn
      -> case ce of AVar _ cv -> substVarInExp v cv be
                    _         -> ALet ti [(v, tiType (annOfAExpr ce), ce)] be
    _ -> ACase ti ct ce bs

  -- Qualified names of all newtype constructors.
  newConsNames = map (\ (_,(_,nc,_)) -> nc) nti

  isNewCons qn = qn `elem` newConsNames

-- Applies a type substitution (first argument), given as an association
-- list from type variable indices to type expressions, to a type expression.
-- Type variables outside the domain of the substitution are kept; the
-- variables bound by a `ForallType` are not treated specially.
substTVarsInTExp :: [(TVarIndex,TypeExpr)] -> TypeExpr -> TypeExpr
substTVarsInTExp tvtexps te = apply te
 where
  apply (TVar tv)            = case lookup tv tvtexps of
                                 Just t  -> t
                                 Nothing -> TVar tv
  apply (FuncType dom ran)   = FuncType (apply dom) (apply ran)
  apply (TCons qn targs)     = TCons qn (map apply targs)
  apply (ForallType vs body) = ForallType vs (apply body)

-- Replaces a variable by another variable in an expression, i.e.,
-- `substVarInExp x y e = {x |-> y}(e)`.
-- Only `AVar` occurrences are renamed; type annotations, binders of
-- let/free expressions and case patterns are left untouched.
substVarInExp :: VarIndex -> VarIndex -> AExpr TypeInfo -> AExpr TypeInfo
substVarInExp x y expr = rename expr
 where
  rename (AVar ann v)             = AVar ann (if v == x then y else v)
  rename lit@(ALit _ _)           = lit
  rename (AComb ann ct qn args)   = AComb ann ct qn (map rename args)
  rename (ALet ann binds body)    = ALet ann (map renameBind binds)
                                             (rename body)
  rename (AFree ann vs body)      = AFree ann vs (rename body)
  rename (AOr ann l r)            = AOr ann (rename l) (rename r)
  rename (ACase ann ct scrut brs) = ACase ann ct (rename scrut)
                                                 (map renameBranch brs)
  rename (ATyped ann body t)      = ATyped ann (rename body) t

  renameBind (v, vtype, bexp) = (v, vtype, rename bexp)

  renameBranch (ABranch pat bexp) = ABranch pat (rename bexp)


-- Information about a `newtype` declaration: the qualified name of the
-- newtype paired with its type variable indices, the qualified name of
-- its constructor, and the argument type of this constructor.
type NewtypeInfo = (QName, ([TVarIndex], QName, TypeExpr))

-- Extracts the information about all `newtype` definitions occurring
-- in a list of type declarations (in the order of the declarations).
newtypesOfProg :: [TypeDecl] -> [NewtypeInfo]
newtypesOfProg = concatMap newtypeInfo
 where
  newtypeInfo tdecl = case tdecl of
    TypeNew tname _ targs (NewCons cname _ argtype)
      -> [(tname, (map fst targs, cname, argtype))]
    _ -> []

-- Is a type declaration a `newtype` declaration?
-- NOTE(review): appears to be unused in this module and is not exported.
isNewtypeDecl :: TypeDecl -> Bool
isNewtypeDecl (TypeNew _ _ _ _) = True
isNewtypeDecl (Type    _ _ _ _) = False
isNewtypeDecl (TypeSyn _ _ _ _) = False

-- Is the operation a class instance operation, i.e., does its
-- unqualified name start with `_inst#`?
isClassInstanceOp :: QName -> Bool
isClassInstanceOp qn = isPrefixOf "_inst#" (snd qn)

-----------------------------------------------------------------------
