summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorEugen Wissner <belka@caraus.de>2020-12-26 06:31:56 +0100
committerEugen Wissner <belka@caraus.de>2020-12-26 06:31:56 +0100
commit22abf7ca58091a521de037bd9b22689e7309b8ba (patch)
tree051ddbce9291d65f29dcae3b4da4255ef019bd91
parent5a6709030ceee63adb417c0fa2d2abce24c5d5cb (diff)
downloadgraphql-22abf7ca58091a521de037bd9b22689e7309b8ba.tar.gz
Validate variable usages are allowed in arguments
-rw-r--r--CHANGELOG.md1
-rw-r--r--src/Language/GraphQL/Validate/Rules.hs211
-rw-r--r--stack.yaml2
-rw-r--r--tests/Language/GraphQL/ValidateSpec.hs34
4 files changed, 238 insertions, 10 deletions
diff --git a/CHANGELOG.md b/CHANGELOG.md
index d35a35a..817357e 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -11,6 +11,7 @@ and this project adheres to
- `Validate.Rules`:
- `overlappingFieldsCanBeMergedRule`
- `possibleFragmentSpreadsRule`
+ - `variablesInAllowedPositionRule`
- `Type.Schema.implementations` contains a map from interfaces and objects to
interfaces they implement.
- Show instances for GraphQL type definitions in the `Type` modules.
diff --git a/src/Language/GraphQL/Validate/Rules.hs b/src/Language/GraphQL/Validate/Rules.hs
index 56c839e..a159983 100644
--- a/src/Language/GraphQL/Validate/Rules.hs
+++ b/src/Language/GraphQL/Validate/Rules.hs
@@ -37,6 +37,7 @@ module Language.GraphQL.Validate.Rules
, uniqueInputFieldNamesRule
, uniqueOperationNamesRule
, uniqueVariableNamesRule
+ , variablesInAllowedPositionRule
, variablesAreInputTypesRule
) where
@@ -45,13 +46,13 @@ import Control.Monad.Trans.Class (MonadTrans(..))
import Control.Monad.Trans.Reader (ReaderT(..), ask, asks, mapReaderT)
import Control.Monad.Trans.State (StateT, evalStateT, gets, modify)
import Data.Bifunctor (first)
-import Data.Foldable (find, foldl', toList)
+import Data.Foldable (find, fold, foldl', toList)
import qualified Data.HashMap.Strict as HashMap
import Data.HashMap.Strict (HashMap)
import Data.HashSet (HashSet)
import qualified Data.HashSet as HashSet
import Data.List (groupBy, sortBy, sortOn)
-import Data.Maybe (fromMaybe, isNothing, mapMaybe)
+import Data.Maybe (catMaybes, fromMaybe, isJust, isNothing, mapMaybe)
import Data.List.NonEmpty (NonEmpty(..))
import Data.Ord (comparing)
import Data.Sequence (Seq(..), (|>))
@@ -108,6 +109,7 @@ specifiedRules =
, variablesAreInputTypesRule
, noUndefinedVariablesRule
, noUnusedVariablesRule
+ , variablesInAllowedPositionRule
]
-- | Definition must be OperationDefinition or FragmentDefinition.
@@ -651,12 +653,9 @@ variableUsageDifference difference errorMessage = OperationDefinitionRule $ \cas
lift $ lift $ mapArguments arguments <> mapDirectives directives'
variableFilter (Full.FragmentSpreadSelection spread)
| Full.FragmentSpread fragmentName _ _ <- spread = do
- definitions <- lift $ asks ast
- visited <- gets (HashSet.member fragmentName)
- modify (HashSet.insert fragmentName)
- case find (isSpreadTarget fragmentName) definitions of
- Just (viewFragment -> Just fragmentDefinition)
- | not visited -> diveIntoSpread fragmentDefinition
+ nonVisitedFragmentDefinition <- visitFragmentDefinition fragmentName
+ case nonVisitedFragmentDefinition of
+ Just fragmentDefinition -> diveIntoSpread fragmentDefinition
_ -> lift $ lift mempty
diveIntoSpread (Full.FragmentDefinition _ _ directives' selections _)
= filterSelections' selections
@@ -1286,3 +1285,199 @@ findSpreadTarget fragmentName = do
let Full.FragmentDefinition _ typeCondition _ _ _ = fragmentDefinition
in Just typeCondition
extractTypeCondition _ = Nothing
+
+visitFragmentDefinition :: forall m
+ . Text
+ -> ValidationState m (Maybe Full.FragmentDefinition)
+visitFragmentDefinition fragmentName = do
+ definitions <- lift $ asks ast
+ visited <- gets (HashSet.member fragmentName)
+ modify (HashSet.insert fragmentName)
+ case find (isSpreadTarget fragmentName) definitions of
+ Just (viewFragment -> Just fragmentDefinition)
+ | not visited -> pure $ Just fragmentDefinition
+ _ -> pure Nothing
+
+-- | Variable usages must be compatible with the arguments they are passed to.
+--
+-- Validation failures occur when variables are used in the context of types
+-- that are complete mismatches, or if a nullable type in a variable is passed
+-- to a non‐null argument type.
+variablesInAllowedPositionRule :: forall m. Rule m
+variablesInAllowedPositionRule = OperationDefinitionRule $ \case
+ Full.OperationDefinition operationType _ variables _ selectionSet _ -> do
+ schema' <- asks schema
+ let root = go variables (toList selectionSet) . Type.CompositeObjectType
+ case operationType of
+ Full.Query -> root $ Schema.query schema'
+ Full.Mutation
+ | Just objectType <- Schema.mutation schema' -> root objectType
+ Full.Subscription
+ | Just objectType <- Schema.mutation schema' -> root objectType
+ _ -> lift mempty
+ _ -> lift mempty
+ where
+ go variables selections selectionType = mapReaderT (foldr (<>) Seq.empty)
+ $ flip evalStateT HashSet.empty
+ $ visitSelectionSet variables selectionType
+ $ toList selections
+ visitSelectionSet :: Foldable t
+ => [Full.VariableDefinition]
+ -> Type.CompositeType m
+ -> t Full.Selection
+ -> ValidationState m (Seq Error)
+ visitSelectionSet variables selectionType selections =
+ foldM (evaluateSelection variables selectionType) mempty selections
+ evaluateSelection :: [Full.VariableDefinition]
+ -> Type.CompositeType m
+ -> Seq Error
+ -> Full.Selection
+ -> ValidationState m (Seq Error)
+ evaluateSelection variables selectionType accumulator selection
+ | Full.FragmentSpreadSelection spread <- selection
+ , Full.FragmentSpread fragmentName _ _ <- spread = do
+ types' <- lift $ asks $ Schema.types . schema
+ nonVisitedFragmentDefinition <- visitFragmentDefinition fragmentName
+ case nonVisitedFragmentDefinition of
+ Just fragmentDefinition
+ | Full.FragmentDefinition _ typeCondition _ _ _ <- fragmentDefinition
+ , Just spreadType <- Type.lookupTypeCondition typeCondition types' -> do
+ a <- spreadVariables variables spread
+ b <- diveIntoSpread variables spreadType fragmentDefinition
+ pure $ accumulator <> a <> b
+ _ -> lift $ lift mempty
+ | Full.FieldSelection fieldSelection <- selection
+ , Full.Field _ fieldName _ _ subselections _ <- fieldSelection =
+ case Type.lookupCompositeField fieldName selectionType of
+ Just (Out.Field _ typeField argumentTypes) -> do
+ a <- fieldVariables variables argumentTypes fieldSelection
+ case Type.outToComposite typeField of
+ Just newParentType -> do
+ b <- foldM (evaluateSelection variables newParentType) accumulator subselections
+ pure $ accumulator <> a <> b
+ Nothing -> pure $ accumulator <> a
+ Nothing -> pure accumulator
+ | Full.InlineFragmentSelection inlineSelection <- selection
+ , Full.InlineFragment typeCondition _ subselections _ <- inlineSelection = do
+ types' <- lift $ asks $ Schema.types . schema
+ let inlineType = fromMaybe selectionType
+ $ typeCondition
+ >>= flip Type.lookupTypeCondition types'
+ a <- inlineVariables variables inlineSelection
+ b <- foldM (evaluateSelection variables inlineType) accumulator subselections
+ pure $ accumulator <> a <> b
+ inlineVariables variables inline
+ | Full.InlineFragment _ directives' _ _ <- inline =
+ mapDirectives variables directives'
+ fieldVariables :: [Full.VariableDefinition]
+ -> In.Arguments
+ -> Full.Field
+ -> ValidationState m (Seq Error)
+ fieldVariables variables argumentTypes fieldSelection = do
+ let Full.Field _ _ arguments directives' _ _ = fieldSelection
+ argumentErrors <- mapArguments variables argumentTypes arguments
+ directiveErrors <- mapDirectives variables directives'
+ pure $ argumentErrors <> directiveErrors
+ spreadVariables variables (Full.FragmentSpread _ directives' _) =
+ mapDirectives variables directives'
+ diveIntoSpread variables fieldType fragmentDefinition = do
+ let Full.FragmentDefinition _ _ directives' selections _ =
+ fragmentDefinition
+ selectionErrors <- visitSelectionSet variables fieldType selections
+ directiveErrors <- mapDirectives variables directives'
+ pure $ selectionErrors <> directiveErrors
+ findDirectiveVariables variables directive = do
+ let Full.Directive directiveName arguments _ = directive
+ directiveDefinitions <- lift $ asks $ Schema.directives . schema
+ case HashMap.lookup directiveName directiveDefinitions of
+ Just (Schema.Directive _ _ directiveArguments) ->
+ mapArguments variables directiveArguments arguments
+ Nothing -> pure mempty
+ mapArguments variables argumentTypes
+ = fmap (Seq.fromList . catMaybes)
+ . traverse (findArgumentVariables variables argumentTypes)
+ mapDirectives variables = fmap fold
+ <$> traverse (findDirectiveVariables variables)
+ findArgumentVariables variables argumentTypes argument
+ | Full.Argument argumentName argumentValue _ <- argument
+ , Full.Node{ node = Full.Variable variableName, ..} <- argumentValue
+ , Just expectedType <- HashMap.lookup argumentName argumentTypes
+ , findVariableDefinition' <- findVariableDefinition variableName
+ , Just variableDefinition <- find findVariableDefinition' variables =
+ isVariableUsageAllowed expectedType variableDefinition
+ | otherwise = pure Nothing
+ findVariableDefinition variableName variableDefinition =
+ let Full.VariableDefinition variableName' _ _ _ = variableDefinition
+ in variableName == variableName'
+ isVariableUsageAllowed (In.Argument _ locationType locationDefaultValue) variableDefinition@(Full.VariableDefinition _ variableType variableDefaultValue _)
+ | Full.TypeNonNull _ <- variableType =
+ typesCompatibleOrError variableDefinition locationType
+ | Just nullableLocationType <- unwrapInType locationType
+ , hasNonNullVariableDefaultValue' <- hasNonNullVariableDefaultValue variableDefaultValue
+ , hasLocationDefaultValue <- isJust locationDefaultValue =
+ if hasNonNullVariableDefaultValue' || hasLocationDefaultValue
+ then typesCompatibleOrError variableDefinition nullableLocationType
+ else pure $ Just $ makeError variableDefinition locationType
+ | otherwise = typesCompatibleOrError variableDefinition locationType
+ typesCompatibleOrError variableDefinition locationType
+ | Full.VariableDefinition _ variableType _ _ <- variableDefinition
+ , areTypesCompatible variableType locationType = pure Nothing
+ | otherwise = pure $ Just $ makeError variableDefinition locationType
+ areTypesCompatible (Full.TypeNonNull nonNullType) (unwrapInType -> Just nullableLocationType) =
+ case nonNullType of
+ Full.NonNullTypeNamed n ->
+ areTypesCompatible (Full.TypeNamed n) nullableLocationType
+ Full.NonNullTypeList n ->
+ areTypesCompatible (Full.TypeList n) nullableLocationType
+ areTypesCompatible _ (In.isNonNullType -> True) = False
+ areTypesCompatible (Full.TypeNonNull nonNullType) locationType
+ | Full.NonNullTypeNamed namedType <- nonNullType =
+ areTypesCompatible (Full.TypeNamed namedType) locationType
+ | Full.NonNullTypeList namedType <- nonNullType =
+ areTypesCompatible (Full.TypeList namedType) locationType
+ areTypesCompatible variableType locationType
+ | Full.TypeList itemVariableType <- variableType
+ , In.ListType itemLocationType <- locationType =
+ areTypesCompatible itemVariableType itemLocationType
+ | areIdentical variableType locationType = True
+ | otherwise = False
+ areIdentical (Full.TypeList typeList) (In.ListType itemLocationType) =
+ areIdentical typeList itemLocationType
+ areIdentical (Full.TypeNonNull nonNullType) locationType
+ | Full.NonNullTypeList nonNullList <- nonNullType
+ , In.NonNullListType itemLocationType <- locationType =
+ areIdentical nonNullList itemLocationType
+ | Full.NonNullTypeNamed _ <- nonNullType
+ , In.ListBaseType _ <- locationType = False
+ | Full.NonNullTypeNamed nonNullList <- nonNullType
+ , In.isNonNullType locationType =
+ nonNullList == inputTypeName locationType
+ areIdentical (Full.TypeNamed _) (In.ListBaseType _) = False
+ areIdentical (Full.TypeNamed typeNamed) locationType
+ | not $ In.isNonNullType locationType =
+ typeNamed == inputTypeName locationType
+ areIdentical _ _ = False
+ hasNonNullVariableDefaultValue (Just (Full.Node Full.ConstNull _)) = False
+ hasNonNullVariableDefaultValue Nothing = False
+ hasNonNullVariableDefaultValue _ = True
+ unwrapInType (In.NonNullScalarType nonNullType) =
+ Just $ In.NamedScalarType nonNullType
+ unwrapInType (In.NonNullEnumType nonNullType) =
+ Just $ In.NamedEnumType nonNullType
+ unwrapInType (In.NonNullInputObjectType nonNullType) =
+ Just $ In.NamedInputObjectType nonNullType
+ unwrapInType (In.NonNullListType nonNullType) =
+ Just $ In.ListType nonNullType
+ unwrapInType _ = Nothing
+ makeError (Full.VariableDefinition variableName variableType _ location') expectedType = Error
+ { message = concat
+ [ "Variable \"$"
+ , Text.unpack variableName
+ , "\" of type \""
+ , show variableType
+ , "\" used in position expecting type \""
+ , show expectedType
+ , "\"."
+ ]
+ , locations = [location']
+ }
diff --git a/stack.yaml b/stack.yaml
index 9ee8fd3..be869c7 100644
--- a/stack.yaml
+++ b/stack.yaml
@@ -1,4 +1,4 @@
-resolver: lts-16.26
+resolver: lts-16.27
packages:
- .
diff --git a/tests/Language/GraphQL/ValidateSpec.hs b/tests/Language/GraphQL/ValidateSpec.hs
index 4063b57..d340d4e 100644
--- a/tests/Language/GraphQL/ValidateSpec.hs
+++ b/tests/Language/GraphQL/ValidateSpec.hs
@@ -485,7 +485,7 @@ spec =
"Variable \"$dog\" cannot be non-input type \"Dog\"."
, locations = [AST.Location 2 34]
}
- in validate queryString `shouldBe` [expected]
+ in validate queryString `shouldContain` [expected]
it "rejects undefined variables" $
let queryString = [r|
@@ -808,3 +808,35 @@ spec =
, locations = [AST.Location 4 19]
}
in validate queryString `shouldBe` [expected]
+
+ it "wrongly typed variable arguments" $
+ let queryString = [r|
+ query catCommandArgQuery($catCommandArg: CatCommand) {
+ cat {
+ doesKnowCommand(catCommand: $catCommandArg)
+ }
+ }
+ |]
+ expected = Error
+ { message =
+ "Variable \"$catCommandArg\" of type \"CatCommand\" \
+ \used in position expecting type \"!CatCommand\"."
+ , locations = [AST.Location 2 40]
+ }
+ in validate queryString `shouldBe` [expected]
+
+ it "wrongly typed variable arguments" $
+ let queryString = [r|
+ query intCannotGoIntoBoolean($intArg: Int) {
+ dog {
+ isHousetrained(atOtherHomes: $intArg)
+ }
+ }
+ |]
+ expected = Error
+ { message =
+ "Variable \"$intArg\" of type \"Int\" used in position \
+ \expecting type \"Boolean\"."
+ , locations = [AST.Location 2 44]
+ }
+ in validate queryString `shouldBe` [expected]