forked from OSS/graphql
Validate variable usages are allowed in arguments
This commit is contained in:
parent
5a6709030c
commit
22abf7ca58
@ -11,6 +11,7 @@ and this project adheres to
|
|||||||
- `Validate.Rules`:
|
- `Validate.Rules`:
|
||||||
- `overlappingFieldsCanBeMergedRule`
|
- `overlappingFieldsCanBeMergedRule`
|
||||||
- `possibleFragmentSpreadsRule`
|
- `possibleFragmentSpreadsRule`
|
||||||
|
- `variablesInAllowedPositionRule`
|
||||||
- `Type.Schema.implementations` contains a map from interfaces and objects to
|
- `Type.Schema.implementations` contains a map from interfaces and objects to
|
||||||
interfaces they implement.
|
interfaces they implement.
|
||||||
- Show instances for GraphQL type definitions in the `Type` modules.
|
- Show instances for GraphQL type definitions in the `Type` modules.
|
||||||
|
@ -37,6 +37,7 @@ module Language.GraphQL.Validate.Rules
|
|||||||
, uniqueInputFieldNamesRule
|
, uniqueInputFieldNamesRule
|
||||||
, uniqueOperationNamesRule
|
, uniqueOperationNamesRule
|
||||||
, uniqueVariableNamesRule
|
, uniqueVariableNamesRule
|
||||||
|
, variablesInAllowedPositionRule
|
||||||
, variablesAreInputTypesRule
|
, variablesAreInputTypesRule
|
||||||
) where
|
) where
|
||||||
|
|
||||||
@ -45,13 +46,13 @@ import Control.Monad.Trans.Class (MonadTrans(..))
|
|||||||
import Control.Monad.Trans.Reader (ReaderT(..), ask, asks, mapReaderT)
|
import Control.Monad.Trans.Reader (ReaderT(..), ask, asks, mapReaderT)
|
||||||
import Control.Monad.Trans.State (StateT, evalStateT, gets, modify)
|
import Control.Monad.Trans.State (StateT, evalStateT, gets, modify)
|
||||||
import Data.Bifunctor (first)
|
import Data.Bifunctor (first)
|
||||||
import Data.Foldable (find, foldl', toList)
|
import Data.Foldable (find, fold, foldl', toList)
|
||||||
import qualified Data.HashMap.Strict as HashMap
|
import qualified Data.HashMap.Strict as HashMap
|
||||||
import Data.HashMap.Strict (HashMap)
|
import Data.HashMap.Strict (HashMap)
|
||||||
import Data.HashSet (HashSet)
|
import Data.HashSet (HashSet)
|
||||||
import qualified Data.HashSet as HashSet
|
import qualified Data.HashSet as HashSet
|
||||||
import Data.List (groupBy, sortBy, sortOn)
|
import Data.List (groupBy, sortBy, sortOn)
|
||||||
import Data.Maybe (fromMaybe, isNothing, mapMaybe)
|
import Data.Maybe (catMaybes, fromMaybe, isJust, isNothing, mapMaybe)
|
||||||
import Data.List.NonEmpty (NonEmpty(..))
|
import Data.List.NonEmpty (NonEmpty(..))
|
||||||
import Data.Ord (comparing)
|
import Data.Ord (comparing)
|
||||||
import Data.Sequence (Seq(..), (|>))
|
import Data.Sequence (Seq(..), (|>))
|
||||||
@ -108,6 +109,7 @@ specifiedRules =
|
|||||||
, variablesAreInputTypesRule
|
, variablesAreInputTypesRule
|
||||||
, noUndefinedVariablesRule
|
, noUndefinedVariablesRule
|
||||||
, noUnusedVariablesRule
|
, noUnusedVariablesRule
|
||||||
|
, variablesInAllowedPositionRule
|
||||||
]
|
]
|
||||||
|
|
||||||
-- | Definition must be OperationDefinition or FragmentDefinition.
|
-- | Definition must be OperationDefinition or FragmentDefinition.
|
||||||
@ -651,12 +653,9 @@ variableUsageDifference difference errorMessage = OperationDefinitionRule $ \cas
|
|||||||
lift $ lift $ mapArguments arguments <> mapDirectives directives'
|
lift $ lift $ mapArguments arguments <> mapDirectives directives'
|
||||||
variableFilter (Full.FragmentSpreadSelection spread)
|
variableFilter (Full.FragmentSpreadSelection spread)
|
||||||
| Full.FragmentSpread fragmentName _ _ <- spread = do
|
| Full.FragmentSpread fragmentName _ _ <- spread = do
|
||||||
definitions <- lift $ asks ast
|
nonVisitedFragmentDefinition <- visitFragmentDefinition fragmentName
|
||||||
visited <- gets (HashSet.member fragmentName)
|
case nonVisitedFragmentDefinition of
|
||||||
modify (HashSet.insert fragmentName)
|
Just fragmentDefinition -> diveIntoSpread fragmentDefinition
|
||||||
case find (isSpreadTarget fragmentName) definitions of
|
|
||||||
Just (viewFragment -> Just fragmentDefinition)
|
|
||||||
| not visited -> diveIntoSpread fragmentDefinition
|
|
||||||
_ -> lift $ lift mempty
|
_ -> lift $ lift mempty
|
||||||
diveIntoSpread (Full.FragmentDefinition _ _ directives' selections _)
|
diveIntoSpread (Full.FragmentDefinition _ _ directives' selections _)
|
||||||
= filterSelections' selections
|
= filterSelections' selections
|
||||||
@ -1286,3 +1285,199 @@ findSpreadTarget fragmentName = do
|
|||||||
let Full.FragmentDefinition _ typeCondition _ _ _ = fragmentDefinition
|
let Full.FragmentDefinition _ typeCondition _ _ _ = fragmentDefinition
|
||||||
in Just typeCondition
|
in Just typeCondition
|
||||||
extractTypeCondition _ = Nothing
|
extractTypeCondition _ = Nothing
|
||||||
|
|
||||||
|
visitFragmentDefinition :: forall m
|
||||||
|
. Text
|
||||||
|
-> ValidationState m (Maybe Full.FragmentDefinition)
|
||||||
|
visitFragmentDefinition fragmentName = do
|
||||||
|
definitions <- lift $ asks ast
|
||||||
|
visited <- gets (HashSet.member fragmentName)
|
||||||
|
modify (HashSet.insert fragmentName)
|
||||||
|
case find (isSpreadTarget fragmentName) definitions of
|
||||||
|
Just (viewFragment -> Just fragmentDefinition)
|
||||||
|
| not visited -> pure $ Just fragmentDefinition
|
||||||
|
_ -> pure Nothing
|
||||||
|
|
||||||
|
-- | Variable usages must be compatible with the arguments they are passed to.
|
||||||
|
--
|
||||||
|
-- Validation failures occur when variables are used in the context of types
|
||||||
|
-- that are complete mismatches, or if a nullable type in a variable is passed
|
||||||
|
-- to a non‐null argument type.
|
||||||
|
variablesInAllowedPositionRule :: forall m. Rule m
|
||||||
|
variablesInAllowedPositionRule = OperationDefinitionRule $ \case
|
||||||
|
Full.OperationDefinition operationType _ variables _ selectionSet _ -> do
|
||||||
|
schema' <- asks schema
|
||||||
|
let root = go variables (toList selectionSet) . Type.CompositeObjectType
|
||||||
|
case operationType of
|
||||||
|
Full.Query -> root $ Schema.query schema'
|
||||||
|
Full.Mutation
|
||||||
|
| Just objectType <- Schema.mutation schema' -> root objectType
|
||||||
|
Full.Subscription
|
||||||
|
| Just objectType <- Schema.mutation schema' -> root objectType
|
||||||
|
_ -> lift mempty
|
||||||
|
_ -> lift mempty
|
||||||
|
where
|
||||||
|
go variables selections selectionType = mapReaderT (foldr (<>) Seq.empty)
|
||||||
|
$ flip evalStateT HashSet.empty
|
||||||
|
$ visitSelectionSet variables selectionType
|
||||||
|
$ toList selections
|
||||||
|
visitSelectionSet :: Foldable t
|
||||||
|
=> [Full.VariableDefinition]
|
||||||
|
-> Type.CompositeType m
|
||||||
|
-> t Full.Selection
|
||||||
|
-> ValidationState m (Seq Error)
|
||||||
|
visitSelectionSet variables selectionType selections =
|
||||||
|
foldM (evaluateSelection variables selectionType) mempty selections
|
||||||
|
evaluateSelection :: [Full.VariableDefinition]
|
||||||
|
-> Type.CompositeType m
|
||||||
|
-> Seq Error
|
||||||
|
-> Full.Selection
|
||||||
|
-> ValidationState m (Seq Error)
|
||||||
|
evaluateSelection variables selectionType accumulator selection
|
||||||
|
| Full.FragmentSpreadSelection spread <- selection
|
||||||
|
, Full.FragmentSpread fragmentName _ _ <- spread = do
|
||||||
|
types' <- lift $ asks $ Schema.types . schema
|
||||||
|
nonVisitedFragmentDefinition <- visitFragmentDefinition fragmentName
|
||||||
|
case nonVisitedFragmentDefinition of
|
||||||
|
Just fragmentDefinition
|
||||||
|
| Full.FragmentDefinition _ typeCondition _ _ _ <- fragmentDefinition
|
||||||
|
, Just spreadType <- Type.lookupTypeCondition typeCondition types' -> do
|
||||||
|
a <- spreadVariables variables spread
|
||||||
|
b <- diveIntoSpread variables spreadType fragmentDefinition
|
||||||
|
pure $ accumulator <> a <> b
|
||||||
|
_ -> lift $ lift mempty
|
||||||
|
| Full.FieldSelection fieldSelection <- selection
|
||||||
|
, Full.Field _ fieldName _ _ subselections _ <- fieldSelection =
|
||||||
|
case Type.lookupCompositeField fieldName selectionType of
|
||||||
|
Just (Out.Field _ typeField argumentTypes) -> do
|
||||||
|
a <- fieldVariables variables argumentTypes fieldSelection
|
||||||
|
case Type.outToComposite typeField of
|
||||||
|
Just newParentType -> do
|
||||||
|
b <- foldM (evaluateSelection variables newParentType) accumulator subselections
|
||||||
|
pure $ accumulator <> a <> b
|
||||||
|
Nothing -> pure $ accumulator <> a
|
||||||
|
Nothing -> pure accumulator
|
||||||
|
| Full.InlineFragmentSelection inlineSelection <- selection
|
||||||
|
, Full.InlineFragment typeCondition _ subselections _ <- inlineSelection = do
|
||||||
|
types' <- lift $ asks $ Schema.types . schema
|
||||||
|
let inlineType = fromMaybe selectionType
|
||||||
|
$ typeCondition
|
||||||
|
>>= flip Type.lookupTypeCondition types'
|
||||||
|
a <- inlineVariables variables inlineSelection
|
||||||
|
b <- foldM (evaluateSelection variables inlineType) accumulator subselections
|
||||||
|
pure $ accumulator <> a <> b
|
||||||
|
inlineVariables variables inline
|
||||||
|
| Full.InlineFragment _ directives' _ _ <- inline =
|
||||||
|
mapDirectives variables directives'
|
||||||
|
fieldVariables :: [Full.VariableDefinition]
|
||||||
|
-> In.Arguments
|
||||||
|
-> Full.Field
|
||||||
|
-> ValidationState m (Seq Error)
|
||||||
|
fieldVariables variables argumentTypes fieldSelection = do
|
||||||
|
let Full.Field _ _ arguments directives' _ _ = fieldSelection
|
||||||
|
argumentErrors <- mapArguments variables argumentTypes arguments
|
||||||
|
directiveErrors <- mapDirectives variables directives'
|
||||||
|
pure $ argumentErrors <> directiveErrors
|
||||||
|
spreadVariables variables (Full.FragmentSpread _ directives' _) =
|
||||||
|
mapDirectives variables directives'
|
||||||
|
diveIntoSpread variables fieldType fragmentDefinition = do
|
||||||
|
let Full.FragmentDefinition _ _ directives' selections _ =
|
||||||
|
fragmentDefinition
|
||||||
|
selectionErrors <- visitSelectionSet variables fieldType selections
|
||||||
|
directiveErrors <- mapDirectives variables directives'
|
||||||
|
pure $ selectionErrors <> directiveErrors
|
||||||
|
findDirectiveVariables variables directive = do
|
||||||
|
let Full.Directive directiveName arguments _ = directive
|
||||||
|
directiveDefinitions <- lift $ asks $ Schema.directives . schema
|
||||||
|
case HashMap.lookup directiveName directiveDefinitions of
|
||||||
|
Just (Schema.Directive _ _ directiveArguments) ->
|
||||||
|
mapArguments variables directiveArguments arguments
|
||||||
|
Nothing -> pure mempty
|
||||||
|
mapArguments variables argumentTypes
|
||||||
|
= fmap (Seq.fromList . catMaybes)
|
||||||
|
. traverse (findArgumentVariables variables argumentTypes)
|
||||||
|
mapDirectives variables = fmap fold
|
||||||
|
<$> traverse (findDirectiveVariables variables)
|
||||||
|
findArgumentVariables variables argumentTypes argument
|
||||||
|
| Full.Argument argumentName argumentValue _ <- argument
|
||||||
|
, Full.Node{ node = Full.Variable variableName, ..} <- argumentValue
|
||||||
|
, Just expectedType <- HashMap.lookup argumentName argumentTypes
|
||||||
|
, findVariableDefinition' <- findVariableDefinition variableName
|
||||||
|
, Just variableDefinition <- find findVariableDefinition' variables =
|
||||||
|
isVariableUsageAllowed expectedType variableDefinition
|
||||||
|
| otherwise = pure Nothing
|
||||||
|
findVariableDefinition variableName variableDefinition =
|
||||||
|
let Full.VariableDefinition variableName' _ _ _ = variableDefinition
|
||||||
|
in variableName == variableName'
|
||||||
|
isVariableUsageAllowed (In.Argument _ locationType locationDefaultValue) variableDefinition@(Full.VariableDefinition _ variableType variableDefaultValue _)
|
||||||
|
| Full.TypeNonNull _ <- variableType =
|
||||||
|
typesCompatibleOrError variableDefinition locationType
|
||||||
|
| Just nullableLocationType <- unwrapInType locationType
|
||||||
|
, hasNonNullVariableDefaultValue' <- hasNonNullVariableDefaultValue variableDefaultValue
|
||||||
|
, hasLocationDefaultValue <- isJust locationDefaultValue =
|
||||||
|
if hasNonNullVariableDefaultValue' || hasLocationDefaultValue
|
||||||
|
then typesCompatibleOrError variableDefinition nullableLocationType
|
||||||
|
else pure $ Just $ makeError variableDefinition locationType
|
||||||
|
| otherwise = typesCompatibleOrError variableDefinition locationType
|
||||||
|
typesCompatibleOrError variableDefinition locationType
|
||||||
|
| Full.VariableDefinition _ variableType _ _ <- variableDefinition
|
||||||
|
, areTypesCompatible variableType locationType = pure Nothing
|
||||||
|
| otherwise = pure $ Just $ makeError variableDefinition locationType
|
||||||
|
areTypesCompatible (Full.TypeNonNull nonNullType) (unwrapInType -> Just nullableLocationType) =
|
||||||
|
case nonNullType of
|
||||||
|
Full.NonNullTypeNamed n ->
|
||||||
|
areTypesCompatible (Full.TypeNamed n) nullableLocationType
|
||||||
|
Full.NonNullTypeList n ->
|
||||||
|
areTypesCompatible (Full.TypeList n) nullableLocationType
|
||||||
|
areTypesCompatible _ (In.isNonNullType -> True) = False
|
||||||
|
areTypesCompatible (Full.TypeNonNull nonNullType) locationType
|
||||||
|
| Full.NonNullTypeNamed namedType <- nonNullType =
|
||||||
|
areTypesCompatible (Full.TypeNamed namedType) locationType
|
||||||
|
| Full.NonNullTypeList namedType <- nonNullType =
|
||||||
|
areTypesCompatible (Full.TypeList namedType) locationType
|
||||||
|
areTypesCompatible variableType locationType
|
||||||
|
| Full.TypeList itemVariableType <- variableType
|
||||||
|
, In.ListType itemLocationType <- locationType =
|
||||||
|
areTypesCompatible itemVariableType itemLocationType
|
||||||
|
| areIdentical variableType locationType = True
|
||||||
|
| otherwise = False
|
||||||
|
areIdentical (Full.TypeList typeList) (In.ListType itemLocationType) =
|
||||||
|
areIdentical typeList itemLocationType
|
||||||
|
areIdentical (Full.TypeNonNull nonNullType) locationType
|
||||||
|
| Full.NonNullTypeList nonNullList <- nonNullType
|
||||||
|
, In.NonNullListType itemLocationType <- locationType =
|
||||||
|
areIdentical nonNullList itemLocationType
|
||||||
|
| Full.NonNullTypeNamed _ <- nonNullType
|
||||||
|
, In.ListBaseType _ <- locationType = False
|
||||||
|
| Full.NonNullTypeNamed nonNullList <- nonNullType
|
||||||
|
, In.isNonNullType locationType =
|
||||||
|
nonNullList == inputTypeName locationType
|
||||||
|
areIdentical (Full.TypeNamed _) (In.ListBaseType _) = False
|
||||||
|
areIdentical (Full.TypeNamed typeNamed) locationType
|
||||||
|
| not $ In.isNonNullType locationType =
|
||||||
|
typeNamed == inputTypeName locationType
|
||||||
|
areIdentical _ _ = False
|
||||||
|
hasNonNullVariableDefaultValue (Just (Full.Node Full.ConstNull _)) = False
|
||||||
|
hasNonNullVariableDefaultValue Nothing = False
|
||||||
|
hasNonNullVariableDefaultValue _ = True
|
||||||
|
unwrapInType (In.NonNullScalarType nonNullType) =
|
||||||
|
Just $ In.NamedScalarType nonNullType
|
||||||
|
unwrapInType (In.NonNullEnumType nonNullType) =
|
||||||
|
Just $ In.NamedEnumType nonNullType
|
||||||
|
unwrapInType (In.NonNullInputObjectType nonNullType) =
|
||||||
|
Just $ In.NamedInputObjectType nonNullType
|
||||||
|
unwrapInType (In.NonNullListType nonNullType) =
|
||||||
|
Just $ In.ListType nonNullType
|
||||||
|
unwrapInType _ = Nothing
|
||||||
|
makeError (Full.VariableDefinition variableName variableType _ location') expectedType = Error
|
||||||
|
{ message = concat
|
||||||
|
[ "Variable \"$"
|
||||||
|
, Text.unpack variableName
|
||||||
|
, "\" of type \""
|
||||||
|
, show variableType
|
||||||
|
, "\" used in position expecting type \""
|
||||||
|
, show expectedType
|
||||||
|
, "\"."
|
||||||
|
]
|
||||||
|
, locations = [location']
|
||||||
|
}
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
resolver: lts-16.26
|
resolver: lts-16.27
|
||||||
|
|
||||||
packages:
|
packages:
|
||||||
- .
|
- .
|
||||||
|
@ -485,7 +485,7 @@ spec =
|
|||||||
"Variable \"$dog\" cannot be non-input type \"Dog\"."
|
"Variable \"$dog\" cannot be non-input type \"Dog\"."
|
||||||
, locations = [AST.Location 2 34]
|
, locations = [AST.Location 2 34]
|
||||||
}
|
}
|
||||||
in validate queryString `shouldBe` [expected]
|
in validate queryString `shouldContain` [expected]
|
||||||
|
|
||||||
it "rejects undefined variables" $
|
it "rejects undefined variables" $
|
||||||
let queryString = [r|
|
let queryString = [r|
|
||||||
@ -808,3 +808,35 @@ spec =
|
|||||||
, locations = [AST.Location 4 19]
|
, locations = [AST.Location 4 19]
|
||||||
}
|
}
|
||||||
in validate queryString `shouldBe` [expected]
|
in validate queryString `shouldBe` [expected]
|
||||||
|
|
||||||
|
it "wrongly typed variable arguments" $
|
||||||
|
let queryString = [r|
|
||||||
|
query catCommandArgQuery($catCommandArg: CatCommand) {
|
||||||
|
cat {
|
||||||
|
doesKnowCommand(catCommand: $catCommandArg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|]
|
||||||
|
expected = Error
|
||||||
|
{ message =
|
||||||
|
"Variable \"$catCommandArg\" of type \"CatCommand\" \
|
||||||
|
\used in position expecting type \"!CatCommand\"."
|
||||||
|
, locations = [AST.Location 2 40]
|
||||||
|
}
|
||||||
|
in validate queryString `shouldBe` [expected]
|
||||||
|
|
||||||
|
it "wrongly typed variable arguments" $
|
||||||
|
let queryString = [r|
|
||||||
|
query intCannotGoIntoBoolean($intArg: Int) {
|
||||||
|
dog {
|
||||||
|
isHousetrained(atOtherHomes: $intArg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|]
|
||||||
|
expected = Error
|
||||||
|
{ message =
|
||||||
|
"Variable \"$intArg\" of type \"Int\" used in position \
|
||||||
|
\expecting type \"Boolean\"."
|
||||||
|
, locations = [AST.Location 2 44]
|
||||||
|
}
|
||||||
|
in validate queryString `shouldBe` [expected]
|
||||||
|
Loading…
Reference in New Issue
Block a user