From 38c3097bcf2d3c92a180c5d328cfb15ef80f0b95 Mon Sep 17 00:00:00 2001 From: Eugen Wissner Date: Sun, 20 Sep 2020 06:59:27 +0200 Subject: Validate fragments are input types --- src/Language/GraphQL/Validate/Rules.hs | 30 +++++++++++++++++++++++++++++- 1 file changed, 29 insertions(+), 1 deletion(-) (limited to 'src/Language/GraphQL/Validate/Rules.hs') diff --git a/src/Language/GraphQL/Validate/Rules.hs b/src/Language/GraphQL/Validate/Rules.hs index 3af6145..645a62e 100644 --- a/src/Language/GraphQL/Validate/Rules.hs +++ b/src/Language/GraphQL/Validate/Rules.hs @@ -23,9 +23,10 @@ module Language.GraphQL.Validate.Rules , uniqueFragmentNamesRule , uniqueOperationNamesRule , uniqueVariableNamesRule + , variablesAreInputTypesRule ) where -import Control.Monad (foldM) +import Control.Monad ((>=>), foldM) import Control.Monad.Trans.Class (MonadTrans(..)) import Control.Monad.Trans.Reader (ReaderT, asks) import Control.Monad.Trans.State (StateT, evalStateT, gets, modify) @@ -67,6 +68,7 @@ specifiedRules = , uniqueDirectiveNamesRule -- Variables. , uniqueVariableNamesRule + , variablesAreInputTypesRule ] -- | Definition must be OperationDefinition or FragmentDefinition. @@ -505,3 +507,29 @@ uniqueVariableNamesRule = VariablesRule where extract (VariableDefinition variableName _ _ location) = (variableName, location) + +-- | Variables can only be input types. Objects, unions and interfaces cannot be +-- used as inputs. +variablesAreInputTypesRule :: forall m. Rule m +variablesAreInputTypesRule = VariablesRule + $ (traverse check . Seq.fromList) >=> lift + where + check (VariableDefinition name typeName _ location) + = asks types + >>= lift + . maybe (makeError name typeName location) (const mempty) + . lookupInputType typeName + makeError name typeName location = pure $ Error + { message = concat + [ "Variable \"$" + , Text.unpack name + , "\" cannot be non-input type \"" + , Text.unpack $ getTypeName typeName + , "\"." + ] + , locations = [location] + } + getTypeName (TypeNamed name) = name + getTypeName (TypeList name) = getTypeName name + getTypeName (TypeNonNull (NonNullTypeNamed nonNull)) = nonNull + getTypeName (TypeNonNull (NonNullTypeList nonNull)) = getTypeName nonNull -- cgit v1.2.3