From 73555332681a3702db5e277f21a53c628c3a524f Mon Sep 17 00:00:00 2001 From: Eugen Wissner Date: Tue, 25 Aug 2020 21:03:42 +0200 Subject: [PATCH] Validate single root field in subscriptions --- CHANGELOG.md | 8 ++ graphql.cabal | 3 +- src/Language/GraphQL/AST/Document.hs | 7 +- src/Language/GraphQL/AST/Encoder.hs | 14 ++-- src/Language/GraphQL/AST/Parser.hs | 45 +++++----- src/Language/GraphQL/Execute/Execution.hs | 24 ------ src/Language/GraphQL/Execute/Transform.hs | 36 +++----- src/Language/GraphQL/Type/Internal.hs | 39 +++++++++ src/Language/GraphQL/Validate.hs | 72 ++++++++++------ src/Language/GraphQL/Validate/Rules.hs | 93 ++++++++++++++++++--- src/Language/GraphQL/Validate/Validation.hs | 34 ++++++++ tests/Language/GraphQL/AST/EncoderSpec.hs | 4 +- tests/Language/GraphQL/ValidateSpec.hs | 42 +++++++++- 13 files changed, 301 insertions(+), 120 deletions(-) create mode 100644 src/Language/GraphQL/Validate/Validation.hs diff --git a/CHANGELOG.md b/CHANGELOG.md index 129621e..d1b5216 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,14 @@ and this project adheres to ## [Unreleased] ## Changed - `Test.Hspec.GraphQL.*`: replace `IO` in the resolver with any `MonadCatch`. +- The `Location` argument of `AST.Document.Definition.ExecutableDefinition` was + moved to `OperationDefinition` and `FragmentDefinition` since these are the + actual elements that have a location in the document. +- `Validate.Rules` get the whole validation context (AST and schema). + +## Added +- `Validate.Validation` contains data structures and functions used by the + validator and concretet rules. ## [0.9.0.0] - 2020-07-24 ## Fixed diff --git a/graphql.cabal b/graphql.cabal index 73e71e7..46c3d6c 100644 --- a/graphql.cabal +++ b/graphql.cabal @@ -4,7 +4,7 @@ cabal-version: 1.12 -- -- see: https://github.com/sol/hpack -- --- hash: 1d8c32c00a882ccd1fefc4c083d5fe4e83a1825fbf8e0dcfd551ff2c8cd2dda0 +-- hash: 59e2949d07cb5e678b493b77771db1bd64947de480f3da93ca07b3f2458cc495 name: graphql version: 0.9.0.0 @@ -50,6 +50,7 @@ library Language.GraphQL.Type.Out Language.GraphQL.Type.Schema Language.GraphQL.Validate + Language.GraphQL.Validate.Validation Test.Hspec.GraphQL other-modules: Language.GraphQL.Execute.Execution diff --git a/src/Language/GraphQL/AST/Document.hs b/src/Language/GraphQL/AST/Document.hs index 3394bfa..f60ddda 100644 --- a/src/Language/GraphQL/AST/Document.hs +++ b/src/Language/GraphQL/AST/Document.hs @@ -69,7 +69,7 @@ type Document = NonEmpty Definition -- | All kinds of definitions that can occur in a GraphQL document. data Definition - = ExecutableDefinition ExecutableDefinition Location + = ExecutableDefinition ExecutableDefinition | TypeSystemDefinition TypeSystemDefinition Location | TypeSystemExtension TypeSystemExtension Location deriving (Eq, Show) @@ -84,13 +84,14 @@ data ExecutableDefinition -- | Operation definition. data OperationDefinition - = SelectionSet SelectionSet + = SelectionSet SelectionSet Location | OperationDefinition OperationType (Maybe Name) [VariableDefinition] [Directive] SelectionSet + Location deriving (Eq, Show) -- | GraphQL has 3 operation types: @@ -195,7 +196,7 @@ type Alias = Name -- | Fragment definition. data FragmentDefinition - = FragmentDefinition Name TypeCondition [Directive] SelectionSet + = FragmentDefinition Name TypeCondition [Directive] SelectionSet Location deriving (Eq, Show) -- | Type condition. diff --git a/src/Language/GraphQL/AST/Encoder.hs b/src/Language/GraphQL/AST/Encoder.hs index a0dac5b..ba89d36 100644 --- a/src/Language/GraphQL/AST/Encoder.hs +++ b/src/Language/GraphQL/AST/Encoder.hs @@ -50,8 +50,8 @@ document formatter defs | Minified <-formatter = Lazy.Text.snoc (mconcat encodeDocument) '\n' where encodeDocument = foldr executableDefinition [] defs - executableDefinition (ExecutableDefinition x _) acc = - definition formatter x : acc + executableDefinition (ExecutableDefinition executableDefinition') acc = + definition formatter executableDefinition' : acc executableDefinition _ acc = acc -- | Converts a t'ExecutableDefinition' into a string. @@ -68,12 +68,12 @@ definition formatter x -- | Converts a 'OperationDefinition into a string. operationDefinition :: Formatter -> OperationDefinition -> Lazy.Text operationDefinition formatter = \case - SelectionSet sels -> selectionSet formatter sels - OperationDefinition Query name vars dirs sels -> + SelectionSet sels _ -> selectionSet formatter sels + OperationDefinition Query name vars dirs sels _ -> "query " <> node formatter name vars dirs sels - OperationDefinition Mutation name vars dirs sels -> + OperationDefinition Mutation name vars dirs sels _ -> "mutation " <> node formatter name vars dirs sels - OperationDefinition Subscription name vars dirs sels -> + OperationDefinition Subscription name vars dirs sels _ -> "subscription " <> node formatter name vars dirs sels -- | Converts a Query or Mutation into a string. @@ -190,7 +190,7 @@ inlineFragment formatter tc dirs sels = "... on " <> selectionSet formatter sels fragmentDefinition :: Formatter -> FragmentDefinition -> Lazy.Text -fragmentDefinition formatter (FragmentDefinition name tc dirs sels) +fragmentDefinition formatter (FragmentDefinition name tc dirs sels _) = "fragment " <> Lazy.Text.fromStrict name <> " on " <> Lazy.Text.fromStrict tc <> optempty (directives formatter) dirs diff --git a/src/Language/GraphQL/AST/Parser.hs b/src/Language/GraphQL/AST/Parser.hs index 687d8f5..7bc51cb 100644 --- a/src/Language/GraphQL/AST/Parser.hs +++ b/src/Language/GraphQL/AST/Parser.hs @@ -21,7 +21,8 @@ import Language.GraphQL.AST.DirectiveLocation import Language.GraphQL.AST.Document import Language.GraphQL.AST.Lexer import Text.Megaparsec - ( SourcePos(..) + ( MonadParsec(..) + , SourcePos(..) , getSourcePos , lookAhead , option @@ -37,15 +38,11 @@ document = unicodeBOM *> lexeme (NonEmpty.some definition) definition :: Parser Definition -definition = executableDefinition' +definition = ExecutableDefinition <$> executableDefinition <|> typeSystemDefinition' <|> typeSystemExtension' "Definition" where - executableDefinition' = do - location <- getLocation - definition' <- executableDefinition - pure $ ExecutableDefinition definition' location typeSystemDefinition' = do location <- getLocation definition' <- typeSystemDefinition @@ -349,16 +346,22 @@ operationTypeDefinition = OperationTypeDefinition "OperationTypeDefinition" operationDefinition :: Parser OperationDefinition -operationDefinition = SelectionSet <$> selectionSet +operationDefinition = shorthand <|> operationDefinition' "OperationDefinition" where - operationDefinition' - = OperationDefinition <$> operationType - <*> optional name - <*> variableDefinitions - <*> directives - <*> selectionSet + shorthand = do + location <- getLocation + selectionSet' <- selectionSet + pure $ SelectionSet selectionSet' location + operationDefinition' = do + location <- getLocation + operationType' <- operationType + operationName <- optional name + variableDefinitions' <- variableDefinitions + directives' <- directives + selectionSet' <- selectionSet + pure $ OperationDefinition operationType' operationName variableDefinitions' directives' selectionSet' location operationType :: Parser OperationType operationType = Query <$ symbol "query" @@ -412,13 +415,15 @@ inlineFragment = InlineFragment "InlineFragment" fragmentDefinition :: Parser FragmentDefinition -fragmentDefinition = FragmentDefinition - <$ symbol "fragment" - <*> name - <*> typeCondition - <*> directives - <*> selectionSet - "FragmentDefinition" +fragmentDefinition = label "FragmentDefinition" $ do + location <- getLocation + _ <- symbol "fragment" + fragmentName' <- name + typeCondition' <- typeCondition + directives' <- directives + selectionSet' <- selectionSet + pure $ FragmentDefinition + fragmentName' typeCondition' directives' selectionSet' location fragmentName :: Parser Name fragmentName = but (symbol "on") *> name "FragmentName" diff --git a/src/Language/GraphQL/Execute/Execution.hs b/src/Language/GraphQL/Execute/Execution.hs index d8d5b13..71a2baa 100644 --- a/src/Language/GraphQL/Execute/Execution.hs +++ b/src/Language/GraphQL/Execute/Execution.hs @@ -83,30 +83,6 @@ resolveAbstractType abstractType values' _ -> pure Nothing | otherwise = pure Nothing -doesFragmentTypeApply :: forall m - . CompositeType m - -> Out.ObjectType m - -> Bool -doesFragmentTypeApply (CompositeObjectType fragmentType) objectType = - fragmentType == objectType -doesFragmentTypeApply (CompositeInterfaceType fragmentType) objectType = - instanceOf objectType $ AbstractInterfaceType fragmentType -doesFragmentTypeApply (CompositeUnionType fragmentType) objectType = - instanceOf objectType $ AbstractUnionType fragmentType - -instanceOf :: forall m. Out.ObjectType m -> AbstractType m -> Bool -instanceOf objectType (AbstractInterfaceType interfaceType) = - let Out.ObjectType _ _ interfaces _ = objectType - in foldr go False interfaces - where - go objectInterfaceType@(Out.InterfaceType _ _ interfaces _) acc = - acc || foldr go (interfaceType == objectInterfaceType) interfaces -instanceOf objectType (AbstractUnionType unionType) = - let Out.UnionType _ _ members = unionType - in foldr go False members - where - go unionMemberType acc = acc || objectType == unionMemberType - executeField :: (MonadCatch m, Serialize a) => Out.Resolver m -> Type.Value diff --git a/src/Language/GraphQL/Execute/Transform.hs b/src/Language/GraphQL/Execute/Transform.hs index 76d1fe7..9c7ad0a 100644 --- a/src/Language/GraphQL/Execute/Transform.hs +++ b/src/Language/GraphQL/Execute/Transform.hs @@ -255,18 +255,18 @@ defragment ast = in (, fragmentTable) <$> maybe emptyDocument Right nonEmptyOperations where defragment' definition (operations, fragments') - | (Full.ExecutableDefinition executable _) <- definition + | (Full.ExecutableDefinition executable) <- definition , (Full.DefinitionOperation operation') <- executable = (transform operation' : operations, fragments') - | (Full.ExecutableDefinition executable _) <- definition + | (Full.ExecutableDefinition executable) <- definition , (Full.DefinitionFragment fragment) <- executable - , (Full.FragmentDefinition name _ _ _) <- fragment = + , (Full.FragmentDefinition name _ _ _ _) <- fragment = (operations, HashMap.insert name fragment fragments') defragment' _ acc = acc transform = \case - Full.OperationDefinition type' name variables directives' selections -> + Full.OperationDefinition type' name variables directives' selections _ -> OperationDefinition type' name variables directives' selections - Full.SelectionSet selectionSet -> + Full.SelectionSet selectionSet _ -> OperationDefinition Full.Query Nothing mempty mempty selectionSet -- * Operation @@ -324,8 +324,8 @@ selection (Full.InlineFragment type' directives' selections) = do case type' of Nothing -> pure $ Left fragmentSelectionSet Just typeName -> do - typeCondition' <- lookupTypeCondition typeName - case typeCondition' of + types' <- gets types + case lookupTypeCondition typeName types' of Just typeCondition -> pure $ selectionFragment typeCondition fragmentSelectionSet Nothing -> pure $ Left mempty @@ -364,29 +364,17 @@ collectFragments = do _ <- fragmentDefinition nextValue collectFragments -lookupTypeCondition :: Full.Name -> State (Replacement m) (Maybe (CompositeType m)) -lookupTypeCondition type' = do - types' <- gets types - case HashMap.lookup type' types' of - Just (ObjectType objectType) -> - lift $ pure $ Just $ CompositeObjectType objectType - Just (UnionType unionType) -> - lift $ pure $ Just $ CompositeUnionType unionType - Just (InterfaceType interfaceType) -> - lift $ pure $ Just $ CompositeInterfaceType interfaceType - _ -> lift $ pure Nothing - fragmentDefinition :: Full.FragmentDefinition -> State (Replacement m) (Maybe (Fragment m)) -fragmentDefinition (Full.FragmentDefinition name type' _ selections) = do +fragmentDefinition (Full.FragmentDefinition name type' _ selections _) = do modify deleteFragmentDefinition fragmentSelection <- appendSelection selections - compositeType <- lookupTypeCondition type' + types' <- gets types - case compositeType of - Just compositeType' -> do - let newValue = Fragment compositeType' fragmentSelection + case lookupTypeCondition type' types' of + Just compositeType -> do + let newValue = Fragment compositeType fragmentSelection modify $ insertFragment newValue lift $ pure $ Just newValue _ -> lift $ pure Nothing diff --git a/src/Language/GraphQL/Type/Internal.hs b/src/Language/GraphQL/Type/Internal.hs index 9121d13..6f25777 100644 --- a/src/Language/GraphQL/Type/Internal.hs +++ b/src/Language/GraphQL/Type/Internal.hs @@ -8,6 +8,9 @@ module Language.GraphQL.Type.Internal ( AbstractType(..) , CompositeType(..) , collectReferencedTypes + , doesFragmentTypeApply + , instanceOf + , lookupTypeCondition ) where import Data.HashMap.Strict (HashMap) @@ -89,3 +92,39 @@ collectReferencedTypes schema = polymorphicTraverser interfaces fields = flip (foldr visitFields) fields . flip (foldr traverseInterfaceType) interfaces + +doesFragmentTypeApply :: forall m + . CompositeType m + -> Out.ObjectType m + -> Bool +doesFragmentTypeApply (CompositeObjectType fragmentType) objectType = + fragmentType == objectType +doesFragmentTypeApply (CompositeInterfaceType fragmentType) objectType = + instanceOf objectType $ AbstractInterfaceType fragmentType +doesFragmentTypeApply (CompositeUnionType fragmentType) objectType = + instanceOf objectType $ AbstractUnionType fragmentType + +instanceOf :: forall m. Out.ObjectType m -> AbstractType m -> Bool +instanceOf objectType (AbstractInterfaceType interfaceType) = + let Out.ObjectType _ _ interfaces _ = objectType + in foldr go False interfaces + where + go objectInterfaceType@(Out.InterfaceType _ _ interfaces _) acc = + acc || foldr go (interfaceType == objectInterfaceType) interfaces +instanceOf objectType (AbstractUnionType unionType) = + let Out.UnionType _ _ members = unionType + in foldr go False members + where + go unionMemberType acc = acc || objectType == unionMemberType + +lookupTypeCondition :: forall m + . Name + -> HashMap Name (Type m) + -> Maybe (CompositeType m) +lookupTypeCondition type' types' = + case HashMap.lookup type' types' of + Just (ObjectType objectType) -> Just $ CompositeObjectType objectType + Just (UnionType unionType) -> Just $ CompositeUnionType unionType + Just (InterfaceType interfaceType) -> + Just $ CompositeInterfaceType interfaceType + _ -> Nothing diff --git a/src/Language/GraphQL/Validate.hs b/src/Language/GraphQL/Validate.hs index 5768615..95f7462 100644 --- a/src/Language/GraphQL/Validate.hs +++ b/src/Language/GraphQL/Validate.hs @@ -13,22 +13,19 @@ module Language.GraphQL.Validate , module Language.GraphQL.Validate.Rules ) where -import Control.Monad.Trans.Reader (Reader, asks, runReader) +import Control.Monad (foldM) +import Control.Monad.Trans.Reader (Reader, asks, mapReaderT, runReader) import Data.Foldable (foldrM) import Data.Sequence (Seq(..), (><), (|>)) import qualified Data.Sequence as Seq import Data.Text (Text) import Language.GraphQL.AST.Document -import Language.GraphQL.Type.Schema +import Language.GraphQL.Type.Internal +import Language.GraphQL.Type.Schema (Schema(..)) import Language.GraphQL.Validate.Rules +import Language.GraphQL.Validate.Validation -data Context m = Context - { ast :: Document - , schema :: Schema m - , rules :: [Rule] - } - -type ValidateT m = Reader (Context m) (Seq Error) +type ValidateT m = Reader (Validation m) (Seq Error) -- | If an error can be associated to a particular field in the GraphQL result, -- it must contain an entry with the key path that details the path of the @@ -48,36 +45,46 @@ data Error = Error -- | Validates a document and returns a list of found errors. If the returned -- list is empty, the document is valid. -document :: forall m. Schema m -> [Rule] -> Document -> Seq Error +document :: forall m. Schema m -> [Rule m] -> Document -> Seq Error document schema' rules' document' = runReader (foldrM go Seq.empty document') context where - context = Context + context = Validation { ast = document' , schema = schema' + , types = collectReferencedTypes schema' , rules = rules' } go definition' accumulator = (accumulator ><) <$> definition definition' definition :: forall m. Definition -> ValidateT m definition = \case - definition'@(ExecutableDefinition executableDefinition' _) -> do + definition'@(ExecutableDefinition executableDefinition') -> do applied <- applyRules definition' children <- executableDefinition executableDefinition' pure $ children >< applied definition' -> applyRules definition' where - applyRules definition' = foldr (ruleFilter definition') Seq.empty - <$> asks rules - ruleFilter definition' (DefinitionRule rule) accumulator - | Just message' <- rule definition' = - accumulator |> Error - { message = message' - , locations = [definitionLocation definition'] - , path = [] - } - | otherwise = accumulator - definitionLocation (ExecutableDefinition _ location) = location + applyRules definition' = + asks rules >>= foldM (ruleFilter definition') Seq.empty + ruleFilter definition' accumulator (DefinitionRule rule) = + flip mapReaderT (rule definition') $ \case + Just message' -> + pure $ accumulator |> Error + { message = message' + , locations = [definitionLocation definition'] + , path = [] + } + Nothing -> pure accumulator + ruleFilter _ accumulator _ = pure accumulator + definitionLocation (ExecutableDefinition executableDefinition') + | DefinitionOperation definitionOperation <- executableDefinition' + , SelectionSet _ location <- definitionOperation = location + | DefinitionOperation definitionOperation <- executableDefinition' + , OperationDefinition _ _ _ _ _ location <- definitionOperation = + location + | DefinitionFragment fragmentDefinition' <- executableDefinition' + , FragmentDefinition _ _ _ _ location <- fragmentDefinition' = location definitionLocation (TypeSystemDefinition _ location) = location definitionLocation (TypeSystemExtension _ location) = location @@ -88,10 +95,21 @@ executableDefinition (DefinitionFragment definition') = fragmentDefinition definition' operationDefinition :: forall m. OperationDefinition -> ValidateT m -operationDefinition (SelectionSet _operation) = - pure Seq.empty -operationDefinition (OperationDefinition _type _name _variables _directives _selection) = - pure Seq.empty +operationDefinition operation = + asks rules >>= foldM (ruleFilter operation) Seq.empty + where + ruleFilter definition' accumulator (OperationDefinitionRule rule) = + flip mapReaderT (rule definition') $ \case + Just message' -> + pure $ accumulator |> Error + { message = message' + , locations = [definitionLocation operation] + , path = [] + } + Nothing -> pure accumulator + ruleFilter _ accumulator _ = pure accumulator + definitionLocation (SelectionSet _ location) = location + definitionLocation (OperationDefinition _ _ _ _ _ location) = location fragmentDefinition :: forall m. FragmentDefinition -> ValidateT m fragmentDefinition _fragment = pure Seq.empty diff --git a/src/Language/GraphQL/Validate/Rules.hs b/src/Language/GraphQL/Validate/Rules.hs index 9faaedd..a3b1a59 100644 --- a/src/Language/GraphQL/Validate/Rules.hs +++ b/src/Language/GraphQL/Validate/Rules.hs @@ -2,30 +2,99 @@ v. 2.0. If a copy of the MPL was not distributed with this file, You can obtain one at https://mozilla.org/MPL/2.0/. -} +{-# LANGUAGE ExplicitForAll #-} +{-# LANGUAGE OverloadedStrings #-} + -- | This module contains default rules defined in the GraphQL specification. module Language.GraphQL.Validate.Rules - ( Rule(..) - , executableDefinitionsRule + ( executableDefinitionsRule , specifiedRules ) where +import Control.Monad (foldM) +import Control.Monad.Trans.Class (MonadTrans(..)) +import Control.Monad.Trans.Reader (asks) +import Control.Monad.Trans.State (evalStateT, gets, modify) +import qualified Data.HashSet as HashSet +import qualified Data.Text as Text import Language.GraphQL.AST.Document - --- | 'Rule' assigns a function to each AST node that can be validated. If the --- validation fails, the function should return an error message, or 'Nothing' --- otherwise. -newtype Rule - = DefinitionRule (Definition -> Maybe String) +import Language.GraphQL.Type.Internal +import qualified Language.GraphQL.Type.Schema as Schema +import Language.GraphQL.Validate.Validation -- | Default rules given in the specification. -specifiedRules :: [Rule] +specifiedRules :: forall m. [Rule m] specifiedRules = [ executableDefinitionsRule + , singleFieldSubscriptionsRule ] -- | Definition must be OperationDefinition or FragmentDefinition. -executableDefinitionsRule :: Rule +executableDefinitionsRule :: forall m. Rule m executableDefinitionsRule = DefinitionRule go where - go (ExecutableDefinition _definition _) = Nothing - go _ = Just "Definition must be OperationDefinition or FragmentDefinition." + go :: Definition -> RuleT m + go (ExecutableDefinition _) = lift Nothing + go _ = pure + "Definition must be OperationDefinition or FragmentDefinition." + +-- | Subscription operations must have exactly one root field. +singleFieldSubscriptionsRule :: forall m. Rule m +singleFieldSubscriptionsRule = OperationDefinitionRule go + where + go (OperationDefinition Subscription name' _ _ rootFields _) = do + groupedFieldSet <- evalStateT (collectFields rootFields) HashSet.empty + case HashSet.size groupedFieldSet of + 1 -> lift Nothing + _ + | Just name <- name' -> pure $ unwords + [ "Subscription" + , Text.unpack name + , "must select only one top level field." + ] + | otherwise -> pure + "Anonymous Subscription must select only one top level field." + go _ = lift Nothing + collectFields selectionSet = foldM forEach HashSet.empty selectionSet + forEach accumulator (Field alias name _ directives _) + | any skip directives = pure accumulator + | Just aliasedName <- alias = pure + $ HashSet.insert aliasedName accumulator + | otherwise = pure $ HashSet.insert name accumulator + forEach accumulator (FragmentSpread fragmentName directives) + | any skip directives = pure accumulator + | otherwise = do + inVisitetFragments <- gets $ HashSet.member fragmentName + if inVisitetFragments + then pure accumulator + else collectFromSpread fragmentName accumulator + forEach accumulator (InlineFragment typeCondition' directives selectionSet) + | any skip directives = pure accumulator + | Just typeCondition <- typeCondition' = + collectFromFragment typeCondition selectionSet accumulator + | otherwise = HashSet.union accumulator + <$> collectFields selectionSet + skip (Directive "skip" [Argument "if" (Boolean True)]) = True + skip (Directive "include" [Argument "if" (Boolean False)]) = True + skip _ = False + findFragmentDefinition (ExecutableDefinition executableDefinition) Nothing + | DefinitionFragment fragmentDefinition <- executableDefinition = + Just fragmentDefinition + findFragmentDefinition _ accumulator = accumulator + collectFromFragment typeCondition selectionSet accumulator = do + types' <- lift $ asks types + schema' <- lift $ asks schema + case lookupTypeCondition typeCondition types' of + Nothing -> pure accumulator + Just compositeType + | Just objectType <- Schema.subscription schema' + , True <- doesFragmentTypeApply compositeType objectType -> + HashSet.union accumulator<$> collectFields selectionSet + | otherwise -> pure accumulator + collectFromSpread fragmentName accumulator = do + modify $ HashSet.insert fragmentName + ast' <- lift $ asks ast + case foldr findFragmentDefinition Nothing ast' of + Nothing -> pure accumulator + Just (FragmentDefinition _ typeCondition _ selectionSet _) -> + collectFromFragment typeCondition selectionSet accumulator diff --git a/src/Language/GraphQL/Validate/Validation.hs b/src/Language/GraphQL/Validate/Validation.hs new file mode 100644 index 0000000..f6edc7a --- /dev/null +++ b/src/Language/GraphQL/Validate/Validation.hs @@ -0,0 +1,34 @@ +{- This Source Code Form is subject to the terms of the Mozilla Public License, + v. 2.0. If a copy of the MPL was not distributed with this file, You can + obtain one at https://mozilla.org/MPL/2.0/. -} + +-- | Definitions used by the validation rules and the validator itself. +module Language.GraphQL.Validate.Validation + ( Validation(..) + , Rule(..) + , RuleT + ) where + +import Control.Monad.Trans.Reader (ReaderT(..)) +import Data.HashMap.Strict (HashMap) +import Language.GraphQL.AST.Document +import Language.GraphQL.Type.Schema (Schema) +import qualified Language.GraphQL.Type.Schema as Schema + +-- | Validation rule context. +data Validation m = Validation + { ast :: Document + , schema :: Schema m + , types :: HashMap Name (Schema.Type m) + , rules :: [Rule m] + } + +-- | 'Rule' assigns a function to each AST node that can be validated. If the +-- validation fails, the function should return an error message, or 'Nothing' +-- otherwise. +data Rule m + = DefinitionRule (Definition -> RuleT m) + | OperationDefinitionRule (OperationDefinition -> RuleT m) + +-- | Monad transformer used by the rules. +type RuleT m = ReaderT (Validation m) Maybe String diff --git a/tests/Language/GraphQL/AST/EncoderSpec.hs b/tests/Language/GraphQL/AST/EncoderSpec.hs index 71ee948..5fa3706 100644 --- a/tests/Language/GraphQL/AST/EncoderSpec.hs +++ b/tests/Language/GraphQL/AST/EncoderSpec.hs @@ -123,7 +123,9 @@ spec = do it "indents block strings in arguments" $ let arguments = [Argument "message" (String "line1\nline2")] field = Field Nothing "field" arguments [] [] - operation = DefinitionOperation $ SelectionSet $ pure field + operation = DefinitionOperation + $ SelectionSet (pure field) + $ Location 0 0 in definition pretty operation `shouldBe` [r|{ field(message: """ line1 diff --git a/tests/Language/GraphQL/ValidateSpec.hs b/tests/Language/GraphQL/ValidateSpec.hs index f84322d..c463dd9 100644 --- a/tests/Language/GraphQL/ValidateSpec.hs +++ b/tests/Language/GraphQL/ValidateSpec.hs @@ -148,7 +148,7 @@ validate queryString = spec :: Spec spec = - describe "document" $ + describe "document" $ do it "rejects type definitions" $ let queryString = [r| query getDogName { @@ -169,3 +169,43 @@ spec = , path = [] } in validate queryString `shouldBe` Seq.singleton expected + + it "rejects multiple subscription root fields" $ + let queryString = [r| + subscription sub { + newMessage { + body + sender + } + disallowedSecondRootField + } + |] + expected = Error + { message = + "Subscription sub must select only one top level field." + , locations = [AST.Location 2 15] + , path = [] + } + in validate queryString `shouldBe` Seq.singleton expected + + it "rejects multiple subscription root fields coming from a fragment" $ + let queryString = [r| + subscription sub { + ...multipleSubscriptions + } + + fragment multipleSubscriptions on Subscription { + newMessage { + body + sender + } + disallowedSecondRootField + } + |] + expected = Error + { message = + "Subscription sub must select only one top level field." + , locations = [AST.Location 2 15] + , path = [] + } + in validate queryString `shouldBe` Seq.singleton expected