@ -37,6 +37,7 @@ module Language.GraphQL.Validate.Rules
, uniqueInputFieldNamesRule
, uniqueOperationNamesRule
, uniqueVariableNamesRule
, variablesInAllowedPositionRule
, variablesAreInputTypesRule
) where
@ -45,13 +46,13 @@ import Control.Monad.Trans.Class (MonadTrans(..))
import Control.Monad.Trans.Reader ( ReaderT ( .. ) , ask , asks , mapReaderT )
import Control.Monad.Trans.State ( StateT , evalStateT , gets , modify )
import Data.Bifunctor ( first )
import Data.Foldable ( find , foldl' , toList )
import Data.Foldable ( find , fold , foldl' , toList )
import qualified Data.HashMap.Strict as HashMap
import Data.HashMap.Strict ( HashMap )
import Data.HashSet ( HashSet )
import qualified Data.HashSet as HashSet
import Data.List ( groupBy , sortBy , sortOn )
import Data.Maybe ( fromMaybe , isNothing , mapMaybe )
import Data.Maybe ( catMaybes , fromMaybe , isJust , isNothing , mapMaybe )
import Data.List.NonEmpty ( NonEmpty ( .. ) )
import Data.Ord ( comparing )
import Data.Sequence ( Seq ( .. ) , ( |> ) )
@ -108,6 +109,7 @@ specifiedRules =
, variablesAreInputTypesRule
, noUndefinedVariablesRule
, noUnusedVariablesRule
, variablesInAllowedPositionRule
]
-- | Definition must be OperationDefinition or FragmentDefinition.
@ -651,12 +653,9 @@ variableUsageDifference difference errorMessage = OperationDefinitionRule $ \cas
lift $ lift $ mapArguments arguments <> mapDirectives directives'
variableFilter ( Full . FragmentSpreadSelection spread )
| Full . FragmentSpread fragmentName _ _ <- spread = do
d efinitions <- lift $ asks ast
visited <- gets ( HashSet . member fragmentName )
modify ( HashSet . insert fragmentName )
case find ( isSpreadTarget fragmentName ) definitions of
Just ( viewFragment -> Just fragmentDefinition )
| not visited -> diveIntoSpread fragmentDefinition
nonVisitedFragmentD efinition <- visitFragmentDefinition fragmentName
case nonVisitedFragmentDefinition of
Just fragmentDefinition -> diveIntoSpread fragmentDefinition
_ -> lift $ lift mempty
diveIntoSpread ( Full . FragmentDefinition _ _ directives' selections _ )
= filterSelections' selections
@ -1286,3 +1285,199 @@ findSpreadTarget fragmentName = do
let Full . FragmentDefinition _ typeCondition _ _ _ = fragmentDefinition
in Just typeCondition
extractTypeCondition _ = Nothing
visitFragmentDefinition :: forall m
. Text
-> ValidationState m ( Maybe Full . FragmentDefinition )
visitFragmentDefinition fragmentName = do
definitions <- lift $ asks ast
visited <- gets ( HashSet . member fragmentName )
modify ( HashSet . insert fragmentName )
case find ( isSpreadTarget fragmentName ) definitions of
Just ( viewFragment -> Just fragmentDefinition )
| not visited -> pure $ Just fragmentDefinition
_ -> pure Nothing
-- | Variable usages must be compatible with the arguments they are passed to.
--
-- Validation failures occur when variables are used in the context of types
-- that are complete mismatches, or if a nullable type in a variable is passed
-- to a non‐ null argument type.
variablesInAllowedPositionRule :: forall m . Rule m
variablesInAllowedPositionRule = OperationDefinitionRule $ \ case
Full . OperationDefinition operationType _ variables _ selectionSet _ -> do
schema' <- asks schema
let root = go variables ( toList selectionSet ) . Type . CompositeObjectType
case operationType of
Full . Query -> root $ Schema . query schema'
Full . Mutation
| Just objectType <- Schema . mutation schema' -> root objectType
Full . Subscription
| Just objectType <- Schema . mutation schema' -> root objectType
_ -> lift mempty
_ -> lift mempty
where
go variables selections selectionType = mapReaderT ( foldr ( <> ) Seq . empty )
$ flip evalStateT HashSet . empty
$ visitSelectionSet variables selectionType
$ toList selections
visitSelectionSet :: Foldable t
=> [ Full . VariableDefinition ]
-> Type . CompositeType m
-> t Full . Selection
-> ValidationState m ( Seq Error )
visitSelectionSet variables selectionType selections =
foldM ( evaluateSelection variables selectionType ) mempty selections
evaluateSelection :: [ Full . VariableDefinition ]
-> Type . CompositeType m
-> Seq Error
-> Full . Selection
-> ValidationState m ( Seq Error )
evaluateSelection variables selectionType accumulator selection
| Full . FragmentSpreadSelection spread <- selection
, Full . FragmentSpread fragmentName _ _ <- spread = do
types' <- lift $ asks $ Schema . types . schema
nonVisitedFragmentDefinition <- visitFragmentDefinition fragmentName
case nonVisitedFragmentDefinition of
Just fragmentDefinition
| Full . FragmentDefinition _ typeCondition _ _ _ <- fragmentDefinition
, Just spreadType <- Type . lookupTypeCondition typeCondition types' -> do
a <- spreadVariables variables spread
b <- diveIntoSpread variables spreadType fragmentDefinition
pure $ accumulator <> a <> b
_ -> lift $ lift mempty
| Full . FieldSelection fieldSelection <- selection
, Full . Field _ fieldName _ _ subselections _ <- fieldSelection =
case Type . lookupCompositeField fieldName selectionType of
Just ( Out . Field _ typeField argumentTypes ) -> do
a <- fieldVariables variables argumentTypes fieldSelection
case Type . outToComposite typeField of
Just newParentType -> do
b <- foldM ( evaluateSelection variables newParentType ) accumulator subselections
pure $ accumulator <> a <> b
Nothing -> pure $ accumulator <> a
Nothing -> pure accumulator
| Full . InlineFragmentSelection inlineSelection <- selection
, Full . InlineFragment typeCondition _ subselections _ <- inlineSelection = do
types' <- lift $ asks $ Schema . types . schema
let inlineType = fromMaybe selectionType
$ typeCondition
>>= flip Type . lookupTypeCondition types'
a <- inlineVariables variables inlineSelection
b <- foldM ( evaluateSelection variables inlineType ) accumulator subselections
pure $ accumulator <> a <> b
inlineVariables variables inline
| Full . InlineFragment _ directives' _ _ <- inline =
mapDirectives variables directives'
fieldVariables :: [ Full . VariableDefinition ]
-> In . Arguments
-> Full . Field
-> ValidationState m ( Seq Error )
fieldVariables variables argumentTypes fieldSelection = do
let Full . Field _ _ arguments directives' _ _ = fieldSelection
argumentErrors <- mapArguments variables argumentTypes arguments
directiveErrors <- mapDirectives variables directives'
pure $ argumentErrors <> directiveErrors
spreadVariables variables ( Full . FragmentSpread _ directives' _ ) =
mapDirectives variables directives'
diveIntoSpread variables fieldType fragmentDefinition = do
let Full . FragmentDefinition _ _ directives' selections _ =
fragmentDefinition
selectionErrors <- visitSelectionSet variables fieldType selections
directiveErrors <- mapDirectives variables directives'
pure $ selectionErrors <> directiveErrors
findDirectiveVariables variables directive = do
let Full . Directive directiveName arguments _ = directive
directiveDefinitions <- lift $ asks $ Schema . directives . schema
case HashMap . lookup directiveName directiveDefinitions of
Just ( Schema . Directive _ _ directiveArguments ) ->
mapArguments variables directiveArguments arguments
Nothing -> pure mempty
mapArguments variables argumentTypes
= fmap ( Seq . fromList . catMaybes )
. traverse ( findArgumentVariables variables argumentTypes )
mapDirectives variables = fmap fold
<$> traverse ( findDirectiveVariables variables )
findArgumentVariables variables argumentTypes argument
| Full . Argument argumentName argumentValue _ <- argument
, Full . Node { node = Full . Variable variableName , .. } <- argumentValue
, Just expectedType <- HashMap . lookup argumentName argumentTypes
, findVariableDefinition' <- findVariableDefinition variableName
, Just variableDefinition <- find findVariableDefinition' variables =
isVariableUsageAllowed expectedType variableDefinition
| otherwise = pure Nothing
findVariableDefinition variableName variableDefinition =
let Full . VariableDefinition variableName' _ _ _ = variableDefinition
in variableName == variableName'
isVariableUsageAllowed ( In . Argument _ locationType locationDefaultValue ) variableDefinition @ ( Full . VariableDefinition _ variableType variableDefaultValue _ )
| Full . TypeNonNull _ <- variableType =
typesCompatibleOrError variableDefinition locationType
| Just nullableLocationType <- unwrapInType locationType
, hasNonNullVariableDefaultValue' <- hasNonNullVariableDefaultValue variableDefaultValue
, hasLocationDefaultValue <- isJust locationDefaultValue =
if hasNonNullVariableDefaultValue' || hasLocationDefaultValue
then typesCompatibleOrError variableDefinition nullableLocationType
else pure $ Just $ makeError variableDefinition locationType
| otherwise = typesCompatibleOrError variableDefinition locationType
typesCompatibleOrError variableDefinition locationType
| Full . VariableDefinition _ variableType _ _ <- variableDefinition
, areTypesCompatible variableType locationType = pure Nothing
| otherwise = pure $ Just $ makeError variableDefinition locationType
areTypesCompatible ( Full . TypeNonNull nonNullType ) ( unwrapInType -> Just nullableLocationType ) =
case nonNullType of
Full . NonNullTypeNamed n ->
areTypesCompatible ( Full . TypeNamed n ) nullableLocationType
Full . NonNullTypeList n ->
areTypesCompatible ( Full . TypeList n ) nullableLocationType
areTypesCompatible _ ( In . isNonNullType -> True ) = False
areTypesCompatible ( Full . TypeNonNull nonNullType ) locationType
| Full . NonNullTypeNamed namedType <- nonNullType =
areTypesCompatible ( Full . TypeNamed namedType ) locationType
| Full . NonNullTypeList namedType <- nonNullType =
areTypesCompatible ( Full . TypeList namedType ) locationType
areTypesCompatible variableType locationType
| Full . TypeList itemVariableType <- variableType
, In . ListType itemLocationType <- locationType =
areTypesCompatible itemVariableType itemLocationType
| areIdentical variableType locationType = True
| otherwise = False
areIdentical ( Full . TypeList typeList ) ( In . ListType itemLocationType ) =
areIdentical typeList itemLocationType
areIdentical ( Full . TypeNonNull nonNullType ) locationType
| Full . NonNullTypeList nonNullList <- nonNullType
, In . NonNullListType itemLocationType <- locationType =
areIdentical nonNullList itemLocationType
| Full . NonNullTypeNamed _ <- nonNullType
, In . ListBaseType _ <- locationType = False
| Full . NonNullTypeNamed nonNullList <- nonNullType
, In . isNonNullType locationType =
nonNullList == inputTypeName locationType
areIdentical ( Full . TypeNamed _ ) ( In . ListBaseType _ ) = False
areIdentical ( Full . TypeNamed typeNamed ) locationType
| not $ In . isNonNullType locationType =
typeNamed == inputTypeName locationType
areIdentical _ _ = False
hasNonNullVariableDefaultValue ( Just ( Full . Node Full . ConstNull _ ) ) = False
hasNonNullVariableDefaultValue Nothing = False
hasNonNullVariableDefaultValue _ = True
unwrapInType ( In . NonNullScalarType nonNullType ) =
Just $ In . NamedScalarType nonNullType
unwrapInType ( In . NonNullEnumType nonNullType ) =
Just $ In . NamedEnumType nonNullType
unwrapInType ( In . NonNullInputObjectType nonNullType ) =
Just $ In . NamedInputObjectType nonNullType
unwrapInType ( In . NonNullListType nonNullType ) =
Just $ In . ListType nonNullType
unwrapInType _ = Nothing
makeError ( Full . VariableDefinition variableName variableType _ location' ) expectedType = Error
{ message = concat
[ " Variable \ " $ "
, Text . unpack variableName
, " \ " of type \ " "
, show variableType
, " \ " used in position expecting type \ " "
, show expectedType
, " \ " . "
]
, locations = [ location' ]
}