{-# LANGUAGE CPP, Rank2Types #-}
module Data.Functor.Foldable.TH
( makeBaseFunctor
, makeBaseFunctorWith
, BaseRules
, baseRules
, baseRulesType
, baseRulesCon
, baseRulesField
) where
import Control.Applicative as A
import Control.Monad
import Data.Traversable as T
import Data.Functor.Identity
import Language.Haskell.TH
import Language.Haskell.TH.Datatype as TH.Abs
import Language.Haskell.TH.Syntax (mkNameG_tc, mkNameG_v)
import Data.Char (GeneralCategory (..), generalCategory)
import Data.Orphans ()
#ifndef CURRENT_PACKAGE_KEY
import Data.Version (showVersion)
import Paths_recursion_schemes (version)
#endif
-- | Generate the base functor declarations for the datatype named by the
-- argument, using the default naming scheme 'baseRules'.
makeBaseFunctor :: Name -> DecsQ
makeBaseFunctor name = makeBaseFunctorWith baseRules name
-- | Like 'makeBaseFunctor', but with caller-supplied naming rules for the
-- generated type, constructors and record fields.
makeBaseFunctorWith :: BaseRules -> Name -> DecsQ
makeBaseFunctorWith rules name = do
  -- Reify the datatype (via th-abstraction) and hand it to the generator.
  info <- reifyDatatype name
  makePrimForDI rules info
-- | Naming rules applied when generating a base functor. Each field maps a
-- name from the original datatype to the name used in the generated one.
data BaseRules = BaseRules
  { _baseRulesType  :: Name -> Name
    -- ^ Renames the datatype itself.
  , _baseRulesCon   :: Name -> Name
    -- ^ Renames each data constructor.
  , _baseRulesField :: Name -> Name
    -- ^ Renames each record field selector.
  }
-- | Default naming rules: the type, its constructors and its record fields
-- are all renamed with 'toFName'. That function appends @F@ to ordinary names
-- and @$@ to operator-like names.
baseRules :: BaseRules
baseRules = BaseRules toFName toFName toFName
-- | Van Laarhoven lens focusing on the type-renaming rule of a 'BaseRules'.
baseRulesType :: Functor f => ((Name -> Name) -> f (Name -> Name)) -> BaseRules -> f BaseRules
baseRulesType f rules = fmap setRule (f (_baseRulesType rules))
  where
    setRule rule = rules { _baseRulesType = rule }
-- | Van Laarhoven lens focusing on the constructor-renaming rule of a
-- 'BaseRules'.
baseRulesCon :: Functor f => ((Name -> Name) -> f (Name -> Name)) -> BaseRules -> f BaseRules
baseRulesCon f rules = fmap setRule (f (_baseRulesCon rules))
  where
    setRule rule = rules { _baseRulesCon = rule }
-- | Van Laarhoven lens focusing on the record-field-renaming rule of a
-- 'BaseRules'.
baseRulesField :: Functor f => ((Name -> Name) -> f (Name -> Name)) -> BaseRules -> f BaseRules
baseRulesField f rules = fmap setRule (f (_baseRulesField rules))
  where
    setRule rule = rules { _baseRulesField = rule }
-- | Default renaming: take the unqualified base of the name and append @$@
-- if it consists only of symbol characters (an operator), or @F@ otherwise.
toFName :: Name -> Name
toFName name = mkName (base ++ suffix)
  where
    base = nameBase name
    suffix
      | isInfixName base = "$"
      | otherwise        = "F"
-- | True when every character of the string is a symbol character, i.e.
-- the name is written as an operator.
isInfixName :: String -> Bool
isInfixName str = and [isSymbolChar ch | ch <- str]
-- | Generate base-functor declarations from reified datatype information.
--
-- Fails in 'Q' (rather than crashing with 'error') when:
--
-- * the datatype is a data or newtype family instance, which is not
--   supported; or
-- * a datatype parameter is neither a plain type variable nor a
--   kind-annotated type variable.
makePrimForDI :: BaseRules -> DatatypeInfo -> DecsQ
makePrimForDI rules
  (DatatypeInfo { datatypeName    = tyName
                , datatypeVars    = vars
                , datatypeCons    = cons
                , datatypeVariant = variant }) = do
    when isDataFamInstance $
      fail "makeBaseFunctor: Data families are currently not supported."
    tyVarBndrs <- traverse toTyVarBndr vars
    makePrimForDI' rules (variant == Newtype) tyName tyVarBndrs cons
  where
    isDataFamInstance = case variant of
      DataInstance    -> True
      NewtypeInstance -> True
      Datatype        -> False
      Newtype         -> False

    -- Recover a binder from each datatype parameter. Anything other than
    -- @a@ or @(a :: k)@ is reported as a Q failure naming the culprit,
    -- instead of an opaque pure 'error'.
    toTyVarBndr :: Type -> Q TyVarBndr
    toTyVarBndr (VarT n)          = A.pure (PlainTV n)
    toTyVarBndr (SigT (VarT n) k) = A.pure (KindedTV n k)
    toTyVarBndr ty                = fail $
      "makeBaseFunctor: unsupported datatype parameter " ++ pprint ty
-- | Emit the four declarations that make up a base functor:
--
-- 1. the functor itself, named with '_baseRulesType'. It takes the original
--    parameters plus a fresh trailing parameter that replaces every
--    occurrence of the original (fully applied) type in constructor fields,
--    and derives @Functor@, @Foldable@ and @Traversable@;
-- 2. a @Base@ type instance mapping the original type to the new functor;
-- 3. a @Recursive@ instance whose @project@ relabels each constructor;
-- 4. a @Corecursive@ instance whose @embed@ relabels each constructor back.
--
-- A newtype with exactly one constructor yields a @newtype@; everything
-- else yields a @data@ declaration.
makePrimForDI' :: BaseRules -> Bool -> Name -> [TyVarBndr]
               -> [ConstructorInfo] -> DecsQ
makePrimForDI' rules isNewtype tyName vars cons = do
  rName <- newName "r"
  -- Expand type synonyms in the fields first, so that recursive occurrences
  -- hidden behind a synonym still match in 'substType'.
  resolvedCons <- traverse (conTypeTraversal resolveTypeSynonyms) cons
  projDec  <- FunD projectValName <$> mkMorphism id renameCon resolvedCons
  embedDec <- FunD embedValName   <$> mkMorphism renameCon id resolvedCons
  let bndrsF = vars ++ [PlainTV rName]
      consF  = map (toFCon (VarT rName)) resolvedCons
  A.pure
    [ mkDataDec bndrsF consF
    , baseDec
    , mkInstance recursiveTypeName projDec
    , mkInstance corecursiveTypeName embedDec
    ]
  where
    renameCon = _baseRulesCon rules
    tyNameF   = _baseRulesType rules tyName
    paramTys  = map VarT (typeVars vars)
    -- The original type applied to all of its parameters.
    selfTy    = conAppsT tyName paramTys

    -- Rename a constructor and its fields, and replace recursive
    -- occurrences of 'selfTy' by the given type.
    toFCon recTy =
      toCon
        . conNameMap renameCon
        . conFieldNameMap (_baseRulesField rules)
        . conTypeMap (substType selfTy recTy)

    mkDataDec bndrsF consF = case consF of
#if MIN_VERSION_template_haskell(2,11,0)
      [conF] | isNewtype -> NewtypeD [] tyNameF bndrsF Nothing conF derivs
      _                  -> DataD [] tyNameF bndrsF Nothing consF derivs
#else
      [conF] | isNewtype -> NewtypeD [] tyNameF bndrsF conF derivs
      _                  -> DataD [] tyNameF bndrsF consF derivs
#endif

    derivedClasses = [functorTypeName, foldableTypeName, traversableTypeName]

    derivs =
#if MIN_VERSION_template_haskell(2,12,0)
      [DerivClause Nothing (map ConT derivedClasses)]
#elif MIN_VERSION_template_haskell(2,11,0)
      map ConT derivedClasses
#else
      derivedClasses
#endif

    baseDec =
#if MIN_VERSION_template_haskell(2,9,0)
      TySynInstD baseTypeName (TySynEqn [selfTy] (conAppsT tyNameF paramTys))
#else
      TySynInstD baseTypeName [selfTy] (conAppsT tyNameF paramTys)
#endif

    -- An instance of the given class for 'selfTy' with a single method.
    mkInstance cls method =
#if MIN_VERSION_template_haskell(2,11,0)
      InstanceD Nothing [] (ConT cls `AppT` selfTy) [method]
#else
      InstanceD [] (ConT cls `AppT` selfTy) [method]
#endif
-- | Build one function clause per constructor. For a constructor @C@ with
-- @n@ fields, the clause is
--
-- > (nFrom C) x1 .. xn = (nTo C) x1 .. xn
--
-- The @xi@ are fresh names.
mkMorphism
  :: (Name -> Name)   -- ^ renaming applied to the constructor in the pattern
  -> (Name -> Name)   -- ^ renaming applied to the constructor in the body
  -> [ConstructorInfo]
  -> Q [Clause]
mkMorphism nFrom nTo = traverse mkClause
  where
    mkClause ci = do
      let conName = constructorName ci
          arity   = length (constructorFields ci)
      xs <- replicateM arity (newName "x")
      let pat  = ConP (nFrom conName) (map VarP xs)
          body = foldl AppE (ConE (nTo conName)) (map VarE xs)
      pure (Clause [pat] (NormalB body) [])
-- | Focus on the name of a constructor.
conNameTraversal :: Traversal' ConstructorInfo Name
conNameTraversal f ci =
  (\n -> ci { constructorName = n }) <$> f (constructorName ci)
-- | Focus on the record field names of a constructor. Normal and infix
-- constructors have no field names, so nothing is visited for them.
conFieldNameTraversal :: Traversal' ConstructorInfo Name
conFieldNameTraversal f ci = case constructorVariant ci of
  NormalConstructor    -> pure ci
  InfixConstructor     -> pure ci
  RecordConstructor fs ->
    (\fs' -> ci { constructorVariant = RecordConstructor fs' }) <$> traverse f fs
-- | Focus on each field type of a constructor, in order.
conTypeTraversal :: Traversal' ConstructorInfo Type
conTypeTraversal f ci =
  (\tys -> ci { constructorFields = tys }) <$> traverse f (constructorFields ci)
-- | Rename a constructor.
conNameMap :: (Name -> Name) -> ConstructorInfo -> ConstructorInfo
conNameMap f ci = ci { constructorName = f (constructorName ci) }
-- | Rename every record field of a constructor; non-record constructors are
-- returned unchanged.
conFieldNameMap :: (Name -> Name) -> ConstructorInfo -> ConstructorInfo
conFieldNameMap f ci = over conFieldNameTraversal f ci
-- | Apply a function to every field type of a constructor.
conTypeMap :: (Type -> Type) -> ConstructorInfo -> ConstructorInfo
conTypeMap f ci = ci { constructorFields = map f (constructorFields ci) }
-- | Local van Laarhoven lens type: focuses on exactly one @a@ inside an @s@.
type Lens' s a = forall g. Functor g => (a -> g a) -> s -> g s
-- | Local van Laarhoven traversal type: focuses on zero or more @a@s inside
-- an @s@.
type Traversal' s a = forall g. Applicative g => (a -> g a) -> s -> g s
-- | Build a 'Lens'' from a getter and a setter.
lens :: (s -> a) -> (s -> a -> s) -> Lens' s a
lens getter setter f whole = fmap (setter whole) (f (getter whole))
{-# INLINE lens #-}
-- | Apply a pure function to every target of a traversal.
over :: Traversal' s a -> (a -> a) -> s -> s
over optic g whole = runIdentity (optic (\x -> Identity (g x)) whole)
{-# INLINE over #-}
-- | The names bound by a list of type-variable binders, in order.
typeVars :: [TyVarBndr] -> [Name]
typeVars bndrs = [tvName bndr | bndr <- bndrs]
-- | Apply a type constructor to a list of argument types, left to right:
-- @conAppsT T [a, b]@ is @T a b@.
conAppsT :: Name -> [Type] -> Type
conAppsT conName = go (ConT conName)
  where
    go acc []       = acc
    go acc (t : ts) = go (AppT acc t) ts
-- | @substType a b t@ replaces every subterm of @t@ that is structurally
-- equal to @a@ with @b@.
--
-- The traversal descends through type applications, the body of a
-- @forall@, the type in a kind signature and (on template-haskell >= 2.11)
-- infix and parenthesised types. Kinds and contexts are left untouched.
substType
  :: Type  -- ^ type to look for
  -> Type  -- ^ replacement
  -> Type  -- ^ type to rewrite
  -> Type
substType a b = go
  where
    go x
      | x == a    = b
      | otherwise = case x of
          AppT l r         -> AppT (go l) (go r)
          ForallT xs ctx t -> ForallT xs ctx (go t)
          SigT t k         -> SigT (go t) k
#if MIN_VERSION_template_haskell(2,11,0)
          InfixT l n r     -> InfixT (go l) n (go r)
          UInfixT l n r    -> UInfixT (go l) n (go r)
          ParensT t        -> ParensT (go t)
#endif
          _                -> x
-- | Convert a (renamed, rewritten) th-abstraction 'ConstructorInfo' back into
-- a Template Haskell 'Con' for the generated base functor.
--
-- Constructors that bind their own type variables or carry a context
-- (existential / GADT-style constructors) are rejected with 'error'.
toCon :: ConstructorInfo -> Con
toCon (ConstructorInfo { constructorName       = name
                       , constructorVars       = vars
                       , constructorContext    = ctxt
                       , constructorFields     = ftys
                       , constructorStrictness = fstricts
                       , constructorVariant    = variant })
  | not (null vars && null ctxt)
  = error "makeBaseFunctor: GADTs are not currently supported."
  | otherwise
  = let bangs = map toBang fstricts
    in case variant of
         NormalConstructor        -> NormalC name (zip bangs ftys)
         RecordConstructor fnames -> RecC name (zip3 fnames bangs ftys)
         -- Match the two operands explicitly rather than with irrefutable
         -- list patterns, so malformed input fails with a clear message.
         InfixConstructor         ->
           case (bangs, ftys) of
             ([bang1, bang2], [fty1, fty2]) ->
               InfixC (bang1, fty1) name (bang2, fty2)
             _ ->
               error $ "makeBaseFunctor: infix constructor "
                    ++ nameBase name
                    ++ " does not have exactly two fields."
  where
#if MIN_VERSION_template_haskell(2,11,0)
    toBang (FieldStrictness upkd strct) =
      Bang (toSourceUnpackedness upkd) (toSourceStrictness strct)
      where
        toSourceUnpackedness :: Unpackedness -> SourceUnpackedness
        toSourceUnpackedness UnspecifiedUnpackedness = NoSourceUnpackedness
        toSourceUnpackedness NoUnpack                = SourceNoUnpack
        toSourceUnpackedness Unpack                  = SourceUnpack

        toSourceStrictness :: Strictness -> SourceStrictness
        toSourceStrictness UnspecifiedStrictness = NoSourceStrictness
        toSourceStrictness Lazy                  = SourceLazy
        toSourceStrictness TH.Abs.Strict         = SourceStrict
#else
    -- The old encoding only has IsStrict / NotStrict / Unpacked. A strict
    -- field stays strict whatever its (non-UNPACK) unpackedness annotation;
    -- previously NOUNPACK + strict silently became NotStrict.
    toBang (FieldStrictness Unpack Strict) = Unpacked
    toBang (FieldStrictness _      Strict) = IsStrict
    toBang FieldStrictness{}               = NotStrict
#endif
-- | Whether a character may appear in an operator name. The test is based
-- on its Unicode general category, excluding the reserved punctuation from
-- 'isPuncChar', quote characters and the underscore.
isSymbolChar :: Char -> Bool
isSymbolChar c
  | isPuncChar c = False
  | otherwise    = case generalCategory c of
      OtherPunctuation     -> c `notElem` "'\""
      ConnectorPunctuation -> c /= '_'
      cat                  ->
        cat `elem` [ MathSymbol, CurrencySymbol, ModifierSymbol
                   , OtherSymbol, DashPunctuation ]
-- | Reserved punctuation characters that can never be part of an operator.
isPuncChar :: Char -> Bool
isPuncChar = (`elem` ",;()[]{}`")
-- | Package identifier used to build global 'Name's that point into this
-- package. It comes from the @CURRENT_PACKAGE_KEY@ macro when that is
-- defined, and is otherwise rebuilt from the package version.
rsPackageKey :: String
#ifdef CURRENT_PACKAGE_KEY
rsPackageKey = CURRENT_PACKAGE_KEY
#else
rsPackageKey = concat ["recursion-schemes-", showVersion version]
#endif
-- | Global type-constructor name defined in the given module of this package.
mkRsName_tc :: String -> String -> Name
mkRsName_tc modName occName = mkNameG_tc rsPackageKey modName occName
-- | Global value name defined in the given module of this package.
mkRsName_v :: String -> String -> Name
mkRsName_v modName occName = mkNameG_v rsPackageKey modName occName
-- | Type-level names from "Data.Functor.Foldable" referenced by generated
-- code.
baseTypeName, recursiveTypeName, corecursiveTypeName :: Name
baseTypeName        = mkRsName_tc foldableModule "Base"
recursiveTypeName   = mkRsName_tc foldableModule "Recursive"
corecursiveTypeName = mkRsName_tc foldableModule "Corecursive"

-- | Class-method names from "Data.Functor.Foldable" referenced by generated
-- code.
projectValName, embedValName :: Name
projectValName = mkRsName_v foldableModule "project"
embedValName   = mkRsName_v foldableModule "embed"

-- | Module in this package that defines the names above.
foldableModule :: String
foldableModule = "Data.Functor.Foldable"
-- | Names of the @base@ classes that the generated base functor derives.
functorTypeName, foldableTypeName, traversableTypeName :: Name
functorTypeName     = baseTcName "GHC.Base" "Functor"
foldableTypeName    = baseTcName "Data.Foldable" "Foldable"
traversableTypeName = baseTcName "Data.Traversable" "Traversable"

-- | Global type-constructor name defined in the given module of @base@.
baseTcName :: String -> String -> Name
baseTcName = mkNameG_tc "base"