From 22abf7ca58091a521de037bd9b22689e7309b8ba Mon Sep 17 00:00:00 2001 From: Eugen Wissner Date: Sat, 26 Dec 2020 06:31:56 +0100 Subject: [PATCH] Validate variable usages are allowed in arguments --- CHANGELOG.md | 1 + src/Language/GraphQL/Validate/Rules.hs | 211 ++++++++++++++++++++++++- stack.yaml | 2 +- tests/Language/GraphQL/ValidateSpec.hs | 34 +++- 4 files changed, 238 insertions(+), 10 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index d35a35a..817357e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,7 @@ and this project adheres to - `Validate.Rules`: - `overlappingFieldsCanBeMergedRule` - `possibleFragmentSpreadsRule` + - `variablesInAllowedPositionRule` - `Type.Schema.implementations` contains a map from interfaces and objects to interfaces they implement. - Show instances for GraphQL type definitions in the `Type` modules. diff --git a/src/Language/GraphQL/Validate/Rules.hs b/src/Language/GraphQL/Validate/Rules.hs index 56c839e..a159983 100644 --- a/src/Language/GraphQL/Validate/Rules.hs +++ b/src/Language/GraphQL/Validate/Rules.hs @@ -37,6 +37,7 @@ module Language.GraphQL.Validate.Rules , uniqueInputFieldNamesRule , uniqueOperationNamesRule , uniqueVariableNamesRule + , variablesInAllowedPositionRule , variablesAreInputTypesRule ) where @@ -45,13 +46,13 @@ import Control.Monad.Trans.Class (MonadTrans(..)) import Control.Monad.Trans.Reader (ReaderT(..), ask, asks, mapReaderT) import Control.Monad.Trans.State (StateT, evalStateT, gets, modify) import Data.Bifunctor (first) -import Data.Foldable (find, foldl', toList) +import Data.Foldable (find, fold, foldl', toList) import qualified Data.HashMap.Strict as HashMap import Data.HashMap.Strict (HashMap) import Data.HashSet (HashSet) import qualified Data.HashSet as HashSet import Data.List (groupBy, sortBy, sortOn) -import Data.Maybe (fromMaybe, isNothing, mapMaybe) +import Data.Maybe (catMaybes, fromMaybe, isJust, isNothing, mapMaybe) import Data.List.NonEmpty (NonEmpty(..)) import Data.Ord (comparing) import Data.Sequence (Seq(..), (|>)) @@ -108,6 +109,7 @@ specifiedRules = , variablesAreInputTypesRule , noUndefinedVariablesRule , noUnusedVariablesRule + , variablesInAllowedPositionRule ] -- | Definition must be OperationDefinition or FragmentDefinition. @@ -651,12 +653,9 @@ variableUsageDifference difference errorMessage = OperationDefinitionRule $ \cas lift $ lift $ mapArguments arguments <> mapDirectives directives' variableFilter (Full.FragmentSpreadSelection spread) | Full.FragmentSpread fragmentName _ _ <- spread = do - definitions <- lift $ asks ast - visited <- gets (HashSet.member fragmentName) - modify (HashSet.insert fragmentName) - case find (isSpreadTarget fragmentName) definitions of - Just (viewFragment -> Just fragmentDefinition) - | not visited -> diveIntoSpread fragmentDefinition + nonVisitedFragmentDefinition <- visitFragmentDefinition fragmentName + case nonVisitedFragmentDefinition of + Just fragmentDefinition -> diveIntoSpread fragmentDefinition _ -> lift $ lift mempty diveIntoSpread (Full.FragmentDefinition _ _ directives' selections _) = filterSelections' selections @@ -1286,3 +1285,199 @@ findSpreadTarget fragmentName = do let Full.FragmentDefinition _ typeCondition _ _ _ = fragmentDefinition in Just typeCondition extractTypeCondition _ = Nothing + +visitFragmentDefinition :: forall m + . Text + -> ValidationState m (Maybe Full.FragmentDefinition) +visitFragmentDefinition fragmentName = do + definitions <- lift $ asks ast + visited <- gets (HashSet.member fragmentName) + modify (HashSet.insert fragmentName) + case find (isSpreadTarget fragmentName) definitions of + Just (viewFragment -> Just fragmentDefinition) + | not visited -> pure $ Just fragmentDefinition + _ -> pure Nothing + +-- | Variable usages must be compatible with the arguments they are passed to. +-- +-- Validation failures occur when variables are used in the context of types +-- that are complete mismatches, or if a nullable type in a variable is passed +-- to a non‐null argument type. +variablesInAllowedPositionRule :: forall m. Rule m +variablesInAllowedPositionRule = OperationDefinitionRule $ \case + Full.OperationDefinition operationType _ variables _ selectionSet _ -> do + schema' <- asks schema + let root = go variables (toList selectionSet) . Type.CompositeObjectType + case operationType of + Full.Query -> root $ Schema.query schema' + Full.Mutation + | Just objectType <- Schema.mutation schema' -> root objectType + Full.Subscription + | Just objectType <- Schema.mutation schema' -> root objectType + _ -> lift mempty + _ -> lift mempty + where + go variables selections selectionType = mapReaderT (foldr (<>) Seq.empty) + $ flip evalStateT HashSet.empty + $ visitSelectionSet variables selectionType + $ toList selections + visitSelectionSet :: Foldable t + => [Full.VariableDefinition] + -> Type.CompositeType m + -> t Full.Selection + -> ValidationState m (Seq Error) + visitSelectionSet variables selectionType selections = + foldM (evaluateSelection variables selectionType) mempty selections + evaluateSelection :: [Full.VariableDefinition] + -> Type.CompositeType m + -> Seq Error + -> Full.Selection + -> ValidationState m (Seq Error) + evaluateSelection variables selectionType accumulator selection + | Full.FragmentSpreadSelection spread <- selection + , Full.FragmentSpread fragmentName _ _ <- spread = do + types' <- lift $ asks $ Schema.types . schema + nonVisitedFragmentDefinition <- visitFragmentDefinition fragmentName + case nonVisitedFragmentDefinition of + Just fragmentDefinition + | Full.FragmentDefinition _ typeCondition _ _ _ <- fragmentDefinition + , Just spreadType <- Type.lookupTypeCondition typeCondition types' -> do + a <- spreadVariables variables spread + b <- diveIntoSpread variables spreadType fragmentDefinition + pure $ accumulator <> a <> b + _ -> lift $ lift mempty + | Full.FieldSelection fieldSelection <- selection + , Full.Field _ fieldName _ _ subselections _ <- fieldSelection = + case Type.lookupCompositeField fieldName selectionType of + Just (Out.Field _ typeField argumentTypes) -> do + a <- fieldVariables variables argumentTypes fieldSelection + case Type.outToComposite typeField of + Just newParentType -> do + b <- foldM (evaluateSelection variables newParentType) accumulator subselections + pure $ accumulator <> a <> b + Nothing -> pure $ accumulator <> a + Nothing -> pure accumulator + | Full.InlineFragmentSelection inlineSelection <- selection + , Full.InlineFragment typeCondition _ subselections _ <- inlineSelection = do + types' <- lift $ asks $ Schema.types . schema + let inlineType = fromMaybe selectionType + $ typeCondition + >>= flip Type.lookupTypeCondition types' + a <- inlineVariables variables inlineSelection + b <- foldM (evaluateSelection variables inlineType) accumulator subselections + pure $ accumulator <> a <> b + inlineVariables variables inline + | Full.InlineFragment _ directives' _ _ <- inline = + mapDirectives variables directives' + fieldVariables :: [Full.VariableDefinition] + -> In.Arguments + -> Full.Field + -> ValidationState m (Seq Error) + fieldVariables variables argumentTypes fieldSelection = do + let Full.Field _ _ arguments directives' _ _ = fieldSelection + argumentErrors <- mapArguments variables argumentTypes arguments + directiveErrors <- mapDirectives variables directives' + pure $ argumentErrors <> directiveErrors + spreadVariables variables (Full.FragmentSpread _ directives' _) = + mapDirectives variables directives' + diveIntoSpread variables fieldType fragmentDefinition = do + let Full.FragmentDefinition _ _ directives' selections _ = + fragmentDefinition + selectionErrors <- visitSelectionSet variables fieldType selections + directiveErrors <- mapDirectives variables directives' + pure $ selectionErrors <> directiveErrors + findDirectiveVariables variables directive = do + let Full.Directive directiveName arguments _ = directive + directiveDefinitions <- lift $ asks $ Schema.directives . schema + case HashMap.lookup directiveName directiveDefinitions of + Just (Schema.Directive _ _ directiveArguments) -> + mapArguments variables directiveArguments arguments + Nothing -> pure mempty + mapArguments variables argumentTypes + = fmap (Seq.fromList . catMaybes) + . traverse (findArgumentVariables variables argumentTypes) + mapDirectives variables = fmap fold + <$> traverse (findDirectiveVariables variables) + findArgumentVariables variables argumentTypes argument + | Full.Argument argumentName argumentValue _ <- argument + , Full.Node{ node = Full.Variable variableName, ..} <- argumentValue + , Just expectedType <- HashMap.lookup argumentName argumentTypes + , findVariableDefinition' <- findVariableDefinition variableName + , Just variableDefinition <- find findVariableDefinition' variables = + isVariableUsageAllowed expectedType variableDefinition + | otherwise = pure Nothing + findVariableDefinition variableName variableDefinition = + let Full.VariableDefinition variableName' _ _ _ = variableDefinition + in variableName == variableName' + isVariableUsageAllowed (In.Argument _ locationType locationDefaultValue) variableDefinition@(Full.VariableDefinition _ variableType variableDefaultValue _) + | Full.TypeNonNull _ <- variableType = + typesCompatibleOrError variableDefinition locationType + | Just nullableLocationType <- unwrapInType locationType + , hasNonNullVariableDefaultValue' <- hasNonNullVariableDefaultValue variableDefaultValue + , hasLocationDefaultValue <- isJust locationDefaultValue = + if hasNonNullVariableDefaultValue' || hasLocationDefaultValue + then typesCompatibleOrError variableDefinition nullableLocationType + else pure $ Just $ makeError variableDefinition locationType + | otherwise = typesCompatibleOrError variableDefinition locationType + typesCompatibleOrError variableDefinition locationType + | Full.VariableDefinition _ variableType _ _ <- variableDefinition + , areTypesCompatible variableType locationType = pure Nothing + | otherwise = pure $ Just $ makeError variableDefinition locationType + areTypesCompatible (Full.TypeNonNull nonNullType) (unwrapInType -> Just nullableLocationType) = + case nonNullType of + Full.NonNullTypeNamed n -> + areTypesCompatible (Full.TypeNamed n) nullableLocationType + Full.NonNullTypeList n -> + areTypesCompatible (Full.TypeList n) nullableLocationType + areTypesCompatible _ (In.isNonNullType -> True) = False + areTypesCompatible (Full.TypeNonNull nonNullType) locationType + | Full.NonNullTypeNamed namedType <- nonNullType = + areTypesCompatible (Full.TypeNamed namedType) locationType + | Full.NonNullTypeList namedType <- nonNullType = + areTypesCompatible (Full.TypeList namedType) locationType + areTypesCompatible variableType locationType + | Full.TypeList itemVariableType <- variableType + , In.ListType itemLocationType <- locationType = + areTypesCompatible itemVariableType itemLocationType + | areIdentical variableType locationType = True + | otherwise = False + areIdentical (Full.TypeList typeList) (In.ListType itemLocationType) = + areIdentical typeList itemLocationType + areIdentical (Full.TypeNonNull nonNullType) locationType + | Full.NonNullTypeList nonNullList <- nonNullType + , In.NonNullListType itemLocationType <- locationType = + areIdentical nonNullList itemLocationType + | Full.NonNullTypeNamed _ <- nonNullType + , In.ListBaseType _ <- locationType = False + | Full.NonNullTypeNamed nonNullList <- nonNullType + , In.isNonNullType locationType = + nonNullList == inputTypeName locationType + areIdentical (Full.TypeNamed _) (In.ListBaseType _) = False + areIdentical (Full.TypeNamed typeNamed) locationType + | not $ In.isNonNullType locationType = + typeNamed == inputTypeName locationType + areIdentical _ _ = False + hasNonNullVariableDefaultValue (Just (Full.Node Full.ConstNull _)) = False + hasNonNullVariableDefaultValue Nothing = False + hasNonNullVariableDefaultValue _ = True + unwrapInType (In.NonNullScalarType nonNullType) = + Just $ In.NamedScalarType nonNullType + unwrapInType (In.NonNullEnumType nonNullType) = + Just $ In.NamedEnumType nonNullType + unwrapInType (In.NonNullInputObjectType nonNullType) = + Just $ In.NamedInputObjectType nonNullType + unwrapInType (In.NonNullListType nonNullType) = + Just $ In.ListType nonNullType + unwrapInType _ = Nothing + makeError (Full.VariableDefinition variableName variableType _ location') expectedType = Error + { message = concat + [ "Variable \"$" + , Text.unpack variableName + , "\" of type \"" + , show variableType + , "\" used in position expecting type \"" + , show expectedType + , "\"." + ] + , locations = [location'] + } diff --git a/stack.yaml b/stack.yaml index 9ee8fd3..be869c7 100644 --- a/stack.yaml +++ b/stack.yaml @@ -1,4 +1,4 @@ -resolver: lts-16.26 +resolver: lts-16.27 packages: - . diff --git a/tests/Language/GraphQL/ValidateSpec.hs b/tests/Language/GraphQL/ValidateSpec.hs index 4063b57..d340d4e 100644 --- a/tests/Language/GraphQL/ValidateSpec.hs +++ b/tests/Language/GraphQL/ValidateSpec.hs @@ -485,7 +485,7 @@ spec = "Variable \"$dog\" cannot be non-input type \"Dog\"." , locations = [AST.Location 2 34] } - in validate queryString `shouldBe` [expected] + in validate queryString `shouldContain` [expected] it "rejects undefined variables" $ let queryString = [r| @@ -808,3 +808,35 @@ spec = , locations = [AST.Location 4 19] } in validate queryString `shouldBe` [expected] + + it "wrongly typed variable arguments" $ + let queryString = [r| + query catCommandArgQuery($catCommandArg: CatCommand) { + cat { + doesKnowCommand(catCommand: $catCommandArg) + } + } + |] + expected = Error + { message = + "Variable \"$catCommandArg\" of type \"CatCommand\" \ + \used in position expecting type \"!CatCommand\"." + , locations = [AST.Location 2 40] + } + in validate queryString `shouldBe` [expected] + + it "wrongly typed variable arguments" $ + let queryString = [r| + query intCannotGoIntoBoolean($intArg: Int) { + dog { + isHousetrained(atOtherHomes: $intArg) + } + } + |] + expected = Error + { message = + "Variable \"$intArg\" of type \"Int\" used in position \ + \expecting type \"Boolean\"." + , locations = [AST.Location 2 44] + } + in validate queryString `shouldBe` [expected]