Validate variable usages are allowed in arguments

This commit is contained in:
Eugen Wissner 2020-12-26 06:31:56 +01:00
parent 5a6709030c
commit 22abf7ca58
4 changed files with 238 additions and 10 deletions

View File

@ -11,6 +11,7 @@ and this project adheres to
- `Validate.Rules`: - `Validate.Rules`:
- `overlappingFieldsCanBeMergedRule` - `overlappingFieldsCanBeMergedRule`
- `possibleFragmentSpreadsRule` - `possibleFragmentSpreadsRule`
- `variablesInAllowedPositionRule`
- `Type.Schema.implementations` contains a map from interfaces and objects to - `Type.Schema.implementations` contains a map from interfaces and objects to
interfaces they implement. interfaces they implement.
- Show instances for GraphQL type definitions in the `Type` modules. - Show instances for GraphQL type definitions in the `Type` modules.

View File

@ -37,6 +37,7 @@ module Language.GraphQL.Validate.Rules
, uniqueInputFieldNamesRule , uniqueInputFieldNamesRule
, uniqueOperationNamesRule , uniqueOperationNamesRule
, uniqueVariableNamesRule , uniqueVariableNamesRule
, variablesInAllowedPositionRule
, variablesAreInputTypesRule , variablesAreInputTypesRule
) where ) where
@ -45,13 +46,13 @@ import Control.Monad.Trans.Class (MonadTrans(..))
import Control.Monad.Trans.Reader (ReaderT(..), ask, asks, mapReaderT) import Control.Monad.Trans.Reader (ReaderT(..), ask, asks, mapReaderT)
import Control.Monad.Trans.State (StateT, evalStateT, gets, modify) import Control.Monad.Trans.State (StateT, evalStateT, gets, modify)
import Data.Bifunctor (first) import Data.Bifunctor (first)
import Data.Foldable (find, foldl', toList) import Data.Foldable (find, fold, foldl', toList)
import qualified Data.HashMap.Strict as HashMap import qualified Data.HashMap.Strict as HashMap
import Data.HashMap.Strict (HashMap) import Data.HashMap.Strict (HashMap)
import Data.HashSet (HashSet) import Data.HashSet (HashSet)
import qualified Data.HashSet as HashSet import qualified Data.HashSet as HashSet
import Data.List (groupBy, sortBy, sortOn) import Data.List (groupBy, sortBy, sortOn)
import Data.Maybe (fromMaybe, isNothing, mapMaybe) import Data.Maybe (catMaybes, fromMaybe, isJust, isNothing, mapMaybe)
import Data.List.NonEmpty (NonEmpty(..)) import Data.List.NonEmpty (NonEmpty(..))
import Data.Ord (comparing) import Data.Ord (comparing)
import Data.Sequence (Seq(..), (|>)) import Data.Sequence (Seq(..), (|>))
@ -108,6 +109,7 @@ specifiedRules =
, variablesAreInputTypesRule , variablesAreInputTypesRule
, noUndefinedVariablesRule , noUndefinedVariablesRule
, noUnusedVariablesRule , noUnusedVariablesRule
, variablesInAllowedPositionRule
] ]
-- | Definition must be OperationDefinition or FragmentDefinition. -- | Definition must be OperationDefinition or FragmentDefinition.
@ -651,12 +653,9 @@ variableUsageDifference difference errorMessage = OperationDefinitionRule $ \cas
lift $ lift $ mapArguments arguments <> mapDirectives directives' lift $ lift $ mapArguments arguments <> mapDirectives directives'
variableFilter (Full.FragmentSpreadSelection spread) variableFilter (Full.FragmentSpreadSelection spread)
| Full.FragmentSpread fragmentName _ _ <- spread = do | Full.FragmentSpread fragmentName _ _ <- spread = do
definitions <- lift $ asks ast nonVisitedFragmentDefinition <- visitFragmentDefinition fragmentName
visited <- gets (HashSet.member fragmentName) case nonVisitedFragmentDefinition of
modify (HashSet.insert fragmentName) Just fragmentDefinition -> diveIntoSpread fragmentDefinition
case find (isSpreadTarget fragmentName) definitions of
Just (viewFragment -> Just fragmentDefinition)
| not visited -> diveIntoSpread fragmentDefinition
_ -> lift $ lift mempty _ -> lift $ lift mempty
diveIntoSpread (Full.FragmentDefinition _ _ directives' selections _) diveIntoSpread (Full.FragmentDefinition _ _ directives' selections _)
= filterSelections' selections = filterSelections' selections
@ -1286,3 +1285,199 @@ findSpreadTarget fragmentName = do
let Full.FragmentDefinition _ typeCondition _ _ _ = fragmentDefinition let Full.FragmentDefinition _ typeCondition _ _ _ = fragmentDefinition
in Just typeCondition in Just typeCondition
extractTypeCondition _ = Nothing extractTypeCondition _ = Nothing
visitFragmentDefinition :: forall m
. Text
-> ValidationState m (Maybe Full.FragmentDefinition)
visitFragmentDefinition fragmentName = do
definitions <- lift $ asks ast
visited <- gets (HashSet.member fragmentName)
modify (HashSet.insert fragmentName)
case find (isSpreadTarget fragmentName) definitions of
Just (viewFragment -> Just fragmentDefinition)
| not visited -> pure $ Just fragmentDefinition
_ -> pure Nothing
-- | Variable usages must be compatible with the arguments they are passed to.
--
-- Validation failures occur when variables are used in the context of types
-- that are complete mismatches, or if a nullable type in a variable is passed
-- to a nonnull argument type.
variablesInAllowedPositionRule :: forall m. Rule m
variablesInAllowedPositionRule = OperationDefinitionRule $ \case
Full.OperationDefinition operationType _ variables _ selectionSet _ -> do
schema' <- asks schema
let root = go variables (toList selectionSet) . Type.CompositeObjectType
case operationType of
Full.Query -> root $ Schema.query schema'
Full.Mutation
| Just objectType <- Schema.mutation schema' -> root objectType
Full.Subscription
| Just objectType <- Schema.mutation schema' -> root objectType
_ -> lift mempty
_ -> lift mempty
where
go variables selections selectionType = mapReaderT (foldr (<>) Seq.empty)
$ flip evalStateT HashSet.empty
$ visitSelectionSet variables selectionType
$ toList selections
visitSelectionSet :: Foldable t
=> [Full.VariableDefinition]
-> Type.CompositeType m
-> t Full.Selection
-> ValidationState m (Seq Error)
visitSelectionSet variables selectionType selections =
foldM (evaluateSelection variables selectionType) mempty selections
evaluateSelection :: [Full.VariableDefinition]
-> Type.CompositeType m
-> Seq Error
-> Full.Selection
-> ValidationState m (Seq Error)
evaluateSelection variables selectionType accumulator selection
| Full.FragmentSpreadSelection spread <- selection
, Full.FragmentSpread fragmentName _ _ <- spread = do
types' <- lift $ asks $ Schema.types . schema
nonVisitedFragmentDefinition <- visitFragmentDefinition fragmentName
case nonVisitedFragmentDefinition of
Just fragmentDefinition
| Full.FragmentDefinition _ typeCondition _ _ _ <- fragmentDefinition
, Just spreadType <- Type.lookupTypeCondition typeCondition types' -> do
a <- spreadVariables variables spread
b <- diveIntoSpread variables spreadType fragmentDefinition
pure $ accumulator <> a <> b
_ -> lift $ lift mempty
| Full.FieldSelection fieldSelection <- selection
, Full.Field _ fieldName _ _ subselections _ <- fieldSelection =
case Type.lookupCompositeField fieldName selectionType of
Just (Out.Field _ typeField argumentTypes) -> do
a <- fieldVariables variables argumentTypes fieldSelection
case Type.outToComposite typeField of
Just newParentType -> do
b <- foldM (evaluateSelection variables newParentType) accumulator subselections
pure $ accumulator <> a <> b
Nothing -> pure $ accumulator <> a
Nothing -> pure accumulator
| Full.InlineFragmentSelection inlineSelection <- selection
, Full.InlineFragment typeCondition _ subselections _ <- inlineSelection = do
types' <- lift $ asks $ Schema.types . schema
let inlineType = fromMaybe selectionType
$ typeCondition
>>= flip Type.lookupTypeCondition types'
a <- inlineVariables variables inlineSelection
b <- foldM (evaluateSelection variables inlineType) accumulator subselections
pure $ accumulator <> a <> b
inlineVariables variables inline
| Full.InlineFragment _ directives' _ _ <- inline =
mapDirectives variables directives'
fieldVariables :: [Full.VariableDefinition]
-> In.Arguments
-> Full.Field
-> ValidationState m (Seq Error)
fieldVariables variables argumentTypes fieldSelection = do
let Full.Field _ _ arguments directives' _ _ = fieldSelection
argumentErrors <- mapArguments variables argumentTypes arguments
directiveErrors <- mapDirectives variables directives'
pure $ argumentErrors <> directiveErrors
spreadVariables variables (Full.FragmentSpread _ directives' _) =
mapDirectives variables directives'
diveIntoSpread variables fieldType fragmentDefinition = do
let Full.FragmentDefinition _ _ directives' selections _ =
fragmentDefinition
selectionErrors <- visitSelectionSet variables fieldType selections
directiveErrors <- mapDirectives variables directives'
pure $ selectionErrors <> directiveErrors
findDirectiveVariables variables directive = do
let Full.Directive directiveName arguments _ = directive
directiveDefinitions <- lift $ asks $ Schema.directives . schema
case HashMap.lookup directiveName directiveDefinitions of
Just (Schema.Directive _ _ directiveArguments) ->
mapArguments variables directiveArguments arguments
Nothing -> pure mempty
mapArguments variables argumentTypes
= fmap (Seq.fromList . catMaybes)
. traverse (findArgumentVariables variables argumentTypes)
mapDirectives variables = fmap fold
<$> traverse (findDirectiveVariables variables)
findArgumentVariables variables argumentTypes argument
| Full.Argument argumentName argumentValue _ <- argument
, Full.Node{ node = Full.Variable variableName, ..} <- argumentValue
, Just expectedType <- HashMap.lookup argumentName argumentTypes
, findVariableDefinition' <- findVariableDefinition variableName
, Just variableDefinition <- find findVariableDefinition' variables =
isVariableUsageAllowed expectedType variableDefinition
| otherwise = pure Nothing
findVariableDefinition variableName variableDefinition =
let Full.VariableDefinition variableName' _ _ _ = variableDefinition
in variableName == variableName'
isVariableUsageAllowed (In.Argument _ locationType locationDefaultValue) variableDefinition@(Full.VariableDefinition _ variableType variableDefaultValue _)
| Full.TypeNonNull _ <- variableType =
typesCompatibleOrError variableDefinition locationType
| Just nullableLocationType <- unwrapInType locationType
, hasNonNullVariableDefaultValue' <- hasNonNullVariableDefaultValue variableDefaultValue
, hasLocationDefaultValue <- isJust locationDefaultValue =
if hasNonNullVariableDefaultValue' || hasLocationDefaultValue
then typesCompatibleOrError variableDefinition nullableLocationType
else pure $ Just $ makeError variableDefinition locationType
| otherwise = typesCompatibleOrError variableDefinition locationType
typesCompatibleOrError variableDefinition locationType
| Full.VariableDefinition _ variableType _ _ <- variableDefinition
, areTypesCompatible variableType locationType = pure Nothing
| otherwise = pure $ Just $ makeError variableDefinition locationType
areTypesCompatible (Full.TypeNonNull nonNullType) (unwrapInType -> Just nullableLocationType) =
case nonNullType of
Full.NonNullTypeNamed n ->
areTypesCompatible (Full.TypeNamed n) nullableLocationType
Full.NonNullTypeList n ->
areTypesCompatible (Full.TypeList n) nullableLocationType
areTypesCompatible _ (In.isNonNullType -> True) = False
areTypesCompatible (Full.TypeNonNull nonNullType) locationType
| Full.NonNullTypeNamed namedType <- nonNullType =
areTypesCompatible (Full.TypeNamed namedType) locationType
| Full.NonNullTypeList namedType <- nonNullType =
areTypesCompatible (Full.TypeList namedType) locationType
areTypesCompatible variableType locationType
| Full.TypeList itemVariableType <- variableType
, In.ListType itemLocationType <- locationType =
areTypesCompatible itemVariableType itemLocationType
| areIdentical variableType locationType = True
| otherwise = False
areIdentical (Full.TypeList typeList) (In.ListType itemLocationType) =
areIdentical typeList itemLocationType
areIdentical (Full.TypeNonNull nonNullType) locationType
| Full.NonNullTypeList nonNullList <- nonNullType
, In.NonNullListType itemLocationType <- locationType =
areIdentical nonNullList itemLocationType
| Full.NonNullTypeNamed _ <- nonNullType
, In.ListBaseType _ <- locationType = False
| Full.NonNullTypeNamed nonNullList <- nonNullType
, In.isNonNullType locationType =
nonNullList == inputTypeName locationType
areIdentical (Full.TypeNamed _) (In.ListBaseType _) = False
areIdentical (Full.TypeNamed typeNamed) locationType
| not $ In.isNonNullType locationType =
typeNamed == inputTypeName locationType
areIdentical _ _ = False
hasNonNullVariableDefaultValue (Just (Full.Node Full.ConstNull _)) = False
hasNonNullVariableDefaultValue Nothing = False
hasNonNullVariableDefaultValue _ = True
unwrapInType (In.NonNullScalarType nonNullType) =
Just $ In.NamedScalarType nonNullType
unwrapInType (In.NonNullEnumType nonNullType) =
Just $ In.NamedEnumType nonNullType
unwrapInType (In.NonNullInputObjectType nonNullType) =
Just $ In.NamedInputObjectType nonNullType
unwrapInType (In.NonNullListType nonNullType) =
Just $ In.ListType nonNullType
unwrapInType _ = Nothing
makeError (Full.VariableDefinition variableName variableType _ location') expectedType = Error
{ message = concat
[ "Variable \"$"
, Text.unpack variableName
, "\" of type \""
, show variableType
, "\" used in position expecting type \""
, show expectedType
, "\"."
]
, locations = [location']
}

View File

@ -1,4 +1,4 @@
resolver: lts-16.26 resolver: lts-16.27
packages: packages:
- . - .

View File

@ -485,7 +485,7 @@ spec =
"Variable \"$dog\" cannot be non-input type \"Dog\"." "Variable \"$dog\" cannot be non-input type \"Dog\"."
, locations = [AST.Location 2 34] , locations = [AST.Location 2 34]
} }
in validate queryString `shouldBe` [expected] in validate queryString `shouldContain` [expected]
it "rejects undefined variables" $ it "rejects undefined variables" $
let queryString = [r| let queryString = [r|
@ -808,3 +808,35 @@ spec =
, locations = [AST.Location 4 19] , locations = [AST.Location 4 19]
} }
in validate queryString `shouldBe` [expected] in validate queryString `shouldBe` [expected]
it "wrongly typed variable arguments" $
let queryString = [r|
query catCommandArgQuery($catCommandArg: CatCommand) {
cat {
doesKnowCommand(catCommand: $catCommandArg)
}
}
|]
expected = Error
{ message =
"Variable \"$catCommandArg\" of type \"CatCommand\" \
\used in position expecting type \"!CatCommand\"."
, locations = [AST.Location 2 40]
}
in validate queryString `shouldBe` [expected]
it "wrongly typed variable arguments" $
let queryString = [r|
query intCannotGoIntoBoolean($intArg: Int) {
dog {
isHousetrained(atOtherHomes: $intArg)
}
}
|]
expected = Error
{ message =
"Variable \"$intArg\" of type \"Int\" used in position \
\expecting type \"Boolean\"."
, locations = [AST.Location 2 44]
}
in validate queryString `shouldBe` [expected]