From 6daae8a5219f62de98b4a65788e436fb1eac8cba Mon Sep 17 00:00:00 2001 From: Eugen Wissner Date: Fri, 2 Oct 2020 06:31:38 +0200 Subject: [PATCH] Validate directives are in valid locations --- src/Language/GraphQL/AST/DirectiveLocation.hs | 36 +++- src/Language/GraphQL/AST/Document.hs | 1 + src/Language/GraphQL/AST/Encoder.hs | 1 + src/Language/GraphQL/Validate.hs | 174 +++++++++++++----- src/Language/GraphQL/Validate/Rules.hs | 32 +++- src/Language/GraphQL/Validate/Validation.hs | 3 +- tests/Language/GraphQL/ValidateSpec.hs | 14 ++ 7 files changed, 209 insertions(+), 52 deletions(-) diff --git a/src/Language/GraphQL/AST/DirectiveLocation.hs b/src/Language/GraphQL/AST/DirectiveLocation.hs index c38c9ff..511225f 100644 --- a/src/Language/GraphQL/AST/DirectiveLocation.hs +++ b/src/Language/GraphQL/AST/DirectiveLocation.hs @@ -2,6 +2,8 @@ v. 2.0. If a copy of the MPL was not distributed with this file, You can obtain one at https://mozilla.org/MPL/2.0/. -} +{-# LANGUAGE Safe #-} + -- | Various parts of a GraphQL document can be annotated with directives. -- This module describes locations in a document where directives can appear. module Language.GraphQL.AST.DirectiveLocation @@ -16,7 +18,13 @@ module Language.GraphQL.AST.DirectiveLocation data DirectiveLocation = ExecutableDirectiveLocation ExecutableDirectiveLocation | TypeSystemDirectiveLocation TypeSystemDirectiveLocation - deriving (Eq, Show) + deriving Eq + +instance Show DirectiveLocation where + show (ExecutableDirectiveLocation directiveLocation) = + show directiveLocation + show (TypeSystemDirectiveLocation directiveLocation) = + show directiveLocation -- | Where directives can appear in an executable definition, like a query. data ExecutableDirectiveLocation @@ -27,7 +35,16 @@ data ExecutableDirectiveLocation | FragmentDefinition | FragmentSpread | InlineFragment - deriving (Eq, Show) + deriving Eq + +instance Show ExecutableDirectiveLocation where + show Query = "QUERY" + show Mutation = "MUTATION" + show Subscription = "SUBSCRIPTION" + show Field = "FIELD" + show FragmentDefinition = "FRAGMENT_DEFINITION" + show FragmentSpread = "FRAGMENT_SPREAD" + show InlineFragment = "INLINE_FRAGMENT" -- | Where directives can appear in a type system definition. data TypeSystemDirectiveLocation @@ -42,4 +59,17 @@ data TypeSystemDirectiveLocation | EnumValue | InputObject | InputFieldDefinition - deriving (Eq, Show) + deriving Eq + +instance Show TypeSystemDirectiveLocation where + show Schema = "SCHEMA" + show Scalar = "SCALAR" + show Object = "OBJECT" + show FieldDefinition = "FIELD_DEFINITION" + show ArgumentDefinition = "ARGUMENT_DEFINITION" + show Interface = "INTERFACE" + show Union = "UNION" + show Enum = "ENUM" + show EnumValue = "ENUM_VALUE" + show InputObject = "INPUT_OBJECT" + show InputFieldDefinition = "INPUT_FIELD_DEFINITION" diff --git a/src/Language/GraphQL/AST/Document.hs b/src/Language/GraphQL/AST/Document.hs index 0b118af..489a242 100644 --- a/src/Language/GraphQL/AST/Document.hs +++ b/src/Language/GraphQL/AST/Document.hs @@ -1,5 +1,6 @@ {-# LANGUAGE ExplicitForAll #-} {-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE Safe #-} -- | This module defines an abstract syntax tree for the @GraphQL@ language. It -- follows closely the structure given in the specification. Please refer to diff --git a/src/Language/GraphQL/AST/Encoder.hs b/src/Language/GraphQL/AST/Encoder.hs index dd464c2..51a801e 100644 --- a/src/Language/GraphQL/AST/Encoder.hs +++ b/src/Language/GraphQL/AST/Encoder.hs @@ -1,6 +1,7 @@ {-# LANGUAGE ExplicitForAll #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE LambdaCase #-} +{-# LANGUAGE Safe #-} -- | This module defines a minifier and a printer for the @GraphQL@ language. module Language.GraphQL.AST.Encoder diff --git a/src/Language/GraphQL/Validate.hs b/src/Language/GraphQL/Validate.hs index b4ac29e..5acb26a 100644 --- a/src/Language/GraphQL/Validate.hs +++ b/src/Language/GraphQL/Validate.hs @@ -132,38 +132,104 @@ typeSystemExtension context rule = \case typeExtension :: forall m. Validation m -> ApplyRule m Full.TypeExtension typeExtension context rule = \case - Full.ScalarTypeExtension _ directives' -> directives context rule directives' + Full.ScalarTypeExtension _ directives' -> + directives context rule scalarLocation directives' Full.ObjectTypeFieldsDefinitionExtension _ _ directives' fields - -> directives context rule directives' + -> directives context rule objectLocation directives' >< foldMap (fieldDefinition context rule) fields Full.ObjectTypeDirectivesExtension _ _ directives' -> - directives context rule directives' + directives context rule objectLocation directives' Full.ObjectTypeImplementsInterfacesExtension _ _ -> mempty Full.InterfaceTypeFieldsDefinitionExtension _ directives' fields - -> directives context rule directives' + -> directives context rule interfaceLocation directives' >< foldMap (fieldDefinition context rule) fields Full.InterfaceTypeDirectivesExtension _ directives' -> - directives context rule directives' + directives context rule interfaceLocation directives' Full.UnionTypeUnionMemberTypesExtension _ directives' _ -> - directives context rule directives' + directives context rule unionLocation directives' Full.UnionTypeDirectivesExtension _ directives' -> - directives context rule directives' + directives context rule unionLocation directives' Full.EnumTypeEnumValuesDefinitionExtension _ directives' values - -> directives context rule directives' + -> directives context rule enumLocation directives' >< foldMap (enumValueDefinition context rule) values Full.EnumTypeDirectivesExtension _ directives' -> - directives context rule directives' + directives context rule enumLocation directives' Full.InputObjectTypeInputFieldsDefinitionExtension _ directives' fields - -> directives context rule directives' - >< foldMap (inputValueDefinition context rule) fields + -> directives context rule inputObjectLocation directives' + >< foldMap forEachInputFieldDefinition fields Full.InputObjectTypeDirectivesExtension _ directives' -> - directives context rule directives' + directives context rule inputObjectLocation directives' + where + forEachInputFieldDefinition = + inputValueDefinition context rule inputFieldDefinitionLocation schemaExtension :: forall m. Validation m -> ApplyRule m Full.SchemaExtension schemaExtension context rule = \case Full.SchemaOperationExtension directives' _ -> - directives context rule directives' - Full.SchemaDirectivesExtension directives' -> directives context rule directives' + directives context rule schemaLocation directives' + Full.SchemaDirectivesExtension directives' -> + directives context rule schemaLocation directives' + +schemaLocation :: DirectiveLocation +schemaLocation = TypeSystemDirectiveLocation DirectiveLocation.Schema + +interfaceLocation :: DirectiveLocation +interfaceLocation = TypeSystemDirectiveLocation DirectiveLocation.Interface + +objectLocation :: DirectiveLocation +objectLocation = TypeSystemDirectiveLocation DirectiveLocation.Object + +unionLocation :: DirectiveLocation +unionLocation = TypeSystemDirectiveLocation DirectiveLocation.Union + +enumLocation :: DirectiveLocation +enumLocation = TypeSystemDirectiveLocation DirectiveLocation.Enum + +inputObjectLocation :: DirectiveLocation +inputObjectLocation = TypeSystemDirectiveLocation DirectiveLocation.InputObject + +scalarLocation :: DirectiveLocation +scalarLocation = TypeSystemDirectiveLocation DirectiveLocation.Scalar + +enumValueLocation :: DirectiveLocation +enumValueLocation = TypeSystemDirectiveLocation DirectiveLocation.EnumValue + +fieldDefinitionLocation :: DirectiveLocation +fieldDefinitionLocation = + TypeSystemDirectiveLocation DirectiveLocation.FieldDefinition + +inputFieldDefinitionLocation :: DirectiveLocation +inputFieldDefinitionLocation = + TypeSystemDirectiveLocation DirectiveLocation.InputFieldDefinition + +argumentDefinitionLocation :: DirectiveLocation +argumentDefinitionLocation = + TypeSystemDirectiveLocation DirectiveLocation.ArgumentDefinition + +queryLocation :: DirectiveLocation +queryLocation = ExecutableDirectiveLocation DirectiveLocation.Query + +mutationLocation :: DirectiveLocation +mutationLocation = ExecutableDirectiveLocation DirectiveLocation.Mutation + +subscriptionLocation :: DirectiveLocation +subscriptionLocation = + ExecutableDirectiveLocation DirectiveLocation.Subscription + +fieldLocation :: DirectiveLocation +fieldLocation = ExecutableDirectiveLocation DirectiveLocation.Field + +fragmentDefinitionLocation :: DirectiveLocation +fragmentDefinitionLocation = + ExecutableDirectiveLocation DirectiveLocation.FragmentDefinition + +fragmentSpreadLocation :: DirectiveLocation +fragmentSpreadLocation = + ExecutableDirectiveLocation DirectiveLocation.FragmentSpread + +inlineFragmentLocation :: DirectiveLocation +inlineFragmentLocation = + ExecutableDirectiveLocation DirectiveLocation.InlineFragment executableDefinition :: forall m . Validation.Rule m @@ -179,7 +245,8 @@ typeSystemDefinition :: forall m . Validation m -> ApplyRule m Full.TypeSystemDefinition typeSystemDefinition context rule = \case - Full.SchemaDefinition directives' _ -> directives context rule directives' + Full.SchemaDefinition directives' _ -> + directives context rule schemaLocation directives' Full.TypeDefinition typeDefinition' -> typeDefinition context rule typeDefinition' Full.DirectiveDefinition _ _ arguments' _ -> @@ -188,44 +255,54 @@ typeSystemDefinition context rule = \case typeDefinition :: forall m. Validation m -> ApplyRule m Full.TypeDefinition typeDefinition context rule = \case Full.ScalarTypeDefinition _ _ directives' -> - directives context rule directives' + directives context rule scalarLocation directives' Full.ObjectTypeDefinition _ _ _ directives' fields - -> directives context rule directives' + -> directives context rule objectLocation directives' >< foldMap (fieldDefinition context rule) fields Full.InterfaceTypeDefinition _ _ directives' fields - -> directives context rule directives' + -> directives context rule interfaceLocation directives' >< foldMap (fieldDefinition context rule) fields Full.UnionTypeDefinition _ _ directives' _ -> - directives context rule directives' + directives context rule unionLocation directives' Full.EnumTypeDefinition _ _ directives' values - -> directives context rule directives' + -> directives context rule enumLocation directives' >< foldMap (enumValueDefinition context rule) values Full.InputObjectTypeDefinition _ _ directives' fields - -> directives context rule directives' - <> foldMap (inputValueDefinition context rule) fields + -> directives context rule inputObjectLocation directives' + <> foldMap forEachInputFieldDefinition fields + where + forEachInputFieldDefinition = + inputValueDefinition context rule inputFieldDefinitionLocation enumValueDefinition :: forall m . Validation m -> ApplyRule m Full.EnumValueDefinition enumValueDefinition context rule (Full.EnumValueDefinition _ _ directives') = - directives context rule directives' + directives context rule enumValueLocation directives' fieldDefinition :: forall m. Validation m -> ApplyRule m Full.FieldDefinition fieldDefinition context rule (Full.FieldDefinition _ _ arguments' _ directives') - = directives context rule directives' + = directives context rule fieldDefinitionLocation directives' >< argumentsDefinition context rule arguments' argumentsDefinition :: forall m . Validation m -> ApplyRule m Full.ArgumentsDefinition argumentsDefinition context rule (Full.ArgumentsDefinition definitions) = - foldMap (inputValueDefinition context rule) definitions + foldMap forEachArgument definitions + where + forEachArgument = + inputValueDefinition context rule argumentDefinitionLocation inputValueDefinition :: forall m . Validation m - -> ApplyRule m Full.InputValueDefinition -inputValueDefinition context rule (Full.InputValueDefinition _ _ _ _ directives') = - directives context rule directives' + -> Validation.Rule m + -> DirectiveLocation + -> Full.InputValueDefinition + -> Seq (Validation.RuleT m) +inputValueDefinition context rule directiveLocation definition' = + let Full.InputValueDefinition _ _ _ _ directives' = definition' + in directives context rule directiveLocation directives' operationDefinition :: forall m . Validation.Rule m @@ -239,18 +316,22 @@ operationDefinition rule context operation , Full.OperationDefinition _ _ variables _ _ _ <- operation = foldMap (variableDefinition context rule) variables |> variablesRule variables | Full.SelectionSet selections _ <- operation = - selectionSet context types' rule (getRootType Full.Query) selections - | Full.OperationDefinition operationType _ _ directives' selections _ <- operation - = selectionSet context types' rule (getRootType operationType) selections - >< directives context rule directives' + selectionSet context types' rule queryRoot selections + | Full.OperationDefinition Full.Query _ _ directives' selections _ <- operation + = selectionSet context types' rule queryRoot selections + >< directives context rule queryLocation directives' + | Full.OperationDefinition Full.Mutation _ _ directives' selections _ <- operation = + let root = Out.NamedObjectType <$> Schema.mutation schema' + in selectionSet context types' rule root selections + >< directives context rule mutationLocation directives' + | Full.OperationDefinition Full.Subscription _ _ directives' selections _ <- operation = + let root = Out.NamedObjectType <$> Schema.subscription schema' + in selectionSet context types' rule root selections + >< directives context rule subscriptionLocation directives' where + schema' = Validation.schema context + queryRoot = Just $ Out.NamedObjectType $ Schema.query schema' types' = Validation.types context - getRootType Full.Query = - Just $ Out.NamedObjectType $ Schema.query $ Validation.schema context - getRootType Full.Mutation = - Out.NamedObjectType <$> Schema.mutation (Validation.schema context) - getRootType Full.Subscription = - Out.NamedObjectType <$> Schema.subscription (Validation.schema context) typeToOut :: forall m. Schema.Type m -> Maybe (Out.Type m) typeToOut (Schema.ObjectType objectType) = @@ -320,7 +401,7 @@ fragmentDefinition rule context definition' types' = Validation.types context applyToChildren typeCondition directives' selections = selectionSet context types' rule (lookupType' typeCondition) selections - >< directives context rule directives' + >< directives context rule fragmentDefinitionLocation directives' lookupType' = flip lookupType types' lookupType :: forall m @@ -367,7 +448,7 @@ field context types' rule objectType field' = go field' typeField = objectType >>= lookupTypeField fieldName argumentTypes = maybe mempty typeFieldArguments typeField in selectionSet context types' rule (typeFieldType <$> typeField) selections - >< directives context rule directives' + >< directives context rule fieldLocation directives' >< arguments rule argumentTypes arguments' arguments :: forall m @@ -424,7 +505,7 @@ inlineFragment context types' rule objectType inlineFragment' = refineTarget Nothing = objectType applyToChildren objectType' directives' selections = selectionSet context types' rule objectType' selections - >< directives context rule directives' + >< directives context rule inlineFragmentLocation directives' fragmentSpread :: forall m. Validation m -> ApplyRule m Full.FragmentSpread fragmentSpread context rule fragmentSpread'@(Full.FragmentSpread _ directives' _) @@ -432,15 +513,18 @@ fragmentSpread context rule fragmentSpread'@(Full.FragmentSpread _ directives' _ applyToChildren |> fragmentRule fragmentSpread' | otherwise = applyToChildren where - applyToChildren = directives context rule directives' + applyToChildren = directives context rule fragmentSpreadLocation directives' directives :: Traversable t => forall m . Validation m - -> ApplyRule m (t Full.Directive) -directives context rule directives' + -> Validation.Rule m + -> DirectiveLocation + -> t Full.Directive + -> Seq (Validation.RuleT m) +directives context rule directiveLocation directives' | Validation.DirectivesRule directivesRule <- rule = - applyToChildren |> directivesRule directiveList + applyToChildren |> directivesRule directiveLocation directiveList | otherwise = applyToChildren where directiveList = toList directives' diff --git a/src/Language/GraphQL/Validate/Rules.hs b/src/Language/GraphQL/Validate/Rules.hs index 7cfa712..6c35f70 100644 --- a/src/Language/GraphQL/Validate/Rules.hs +++ b/src/Language/GraphQL/Validate/Rules.hs @@ -10,7 +10,8 @@ -- | This module contains default rules defined in the GraphQL specification. module Language.GraphQL.Validate.Rules - ( executableDefinitionsRule + ( directivesInValidLocationsRule + , executableDefinitionsRule , fieldsOnCorrectTypeRule , fragmentsOnCompositeTypesRule , fragmentSpreadTargetDefinedRule @@ -90,6 +91,7 @@ specifiedRules = , uniqueInputFieldNamesRule -- Directives. , knownDirectiveNamesRule + , directivesInValidLocationsRule , uniqueDirectiveNamesRule -- Variables. , uniqueVariableNamesRule @@ -514,7 +516,7 @@ uniqueArgumentNamesRule = ArgumentsRule fieldRule directiveRule -- of each directive is allowed per location. uniqueDirectiveNamesRule :: forall m. Rule m uniqueDirectiveNamesRule = DirectivesRule - $ lift . filterDuplicates extract "directive" + $ const $ lift . filterDuplicates extract "directive" where extract (Directive directiveName _ location') = (directiveName, location') @@ -818,7 +820,7 @@ knownArgumentNamesRule = ArgumentsRule fieldRule directiveRule -- | GraphQL servers define what directives they support. For each usage of a -- directive, the directive must be available on that server. knownDirectiveNamesRule :: Rule m -knownDirectiveNamesRule = DirectivesRule $ \directives' -> do +knownDirectiveNamesRule = DirectivesRule $ const $ \directives' -> do definitions' <- asks directives let directiveSet = HashSet.fromList $ fmap directiveName directives' let definitionSet = HashSet.fromList $ HashMap.keys definitions' @@ -867,3 +869,27 @@ knownInputFieldNamesRule = ValueRule go constGo , Text.unpack typeName , "\"." ] + +-- | GraphQL servers define what directives they support and where they support +-- them. For each usage of a directive, the directive must be used in a location +-- that the server has declared support for. +directivesInValidLocationsRule :: Rule m +directivesInValidLocationsRule = DirectivesRule directivesRule + where + directivesRule directiveLocation directives' = do + Directive directiveName _ location <- lift $ Seq.fromList directives' + maybeDefinition <- asks $ HashMap.lookup directiveName . directives + case maybeDefinition of + Just (Schema.Directive _ allowedLocations _) + | directiveLocation `notElem` allowedLocations -> pure $ Error + { message = errorMessage directiveName directiveLocation + , locations = [location] + } + _ -> lift mempty + errorMessage directiveName directiveLocation = concat + [ "Directive \"@" + , Text.unpack directiveName + , "\" may not be used on " + , show directiveLocation + , "." + ] diff --git a/src/Language/GraphQL/Validate/Validation.hs b/src/Language/GraphQL/Validate/Validation.hs index 0e9f1a8..32a454e 100644 --- a/src/Language/GraphQL/Validate/Validation.hs +++ b/src/Language/GraphQL/Validate/Validation.hs @@ -13,6 +13,7 @@ module Language.GraphQL.Validate.Validation import Control.Monad.Trans.Reader (ReaderT) import Data.HashMap.Strict (HashMap) import Data.Sequence (Seq) +import Language.GraphQL.AST.DirectiveLocation (DirectiveLocation(..)) import Language.GraphQL.AST.Document import qualified Language.GraphQL.Type.In as In import qualified Language.GraphQL.Type.Out as Out @@ -45,7 +46,7 @@ data Rule m | FragmentSpreadRule (FragmentSpread -> RuleT m) | FieldRule (Maybe (Out.Type m) -> Field -> RuleT m) | ArgumentsRule (Maybe (Out.Type m) -> Field -> RuleT m) (Directive -> RuleT m) - | DirectivesRule ([Directive] -> RuleT m) + | DirectivesRule (DirectiveLocation -> [Directive] -> RuleT m) | VariablesRule ([VariableDefinition] -> RuleT m) | ValueRule (Maybe In.Type -> Value -> RuleT m) (Maybe In.Type -> ConstValue -> RuleT m) diff --git a/tests/Language/GraphQL/ValidateSpec.hs b/tests/Language/GraphQL/ValidateSpec.hs index b4433ca..6bf0a04 100644 --- a/tests/Language/GraphQL/ValidateSpec.hs +++ b/tests/Language/GraphQL/ValidateSpec.hs @@ -606,3 +606,17 @@ spec = , locations = [AST.Location 3 36] } in validate queryString `shouldBe` [expected] + + it "rejects directives in invalid locations" $ + let queryString = [r| + query @skip(if: $foo) { + dog { + name + } + } + |] + expected = Error + { message = "Directive \"@skip\" may not be used on QUERY." + , locations = [AST.Location 2 21] + } + in validate queryString `shouldBe` [expected]