Validate single root field in subscriptions

This commit is contained in:
Eugen Wissner 2020-08-25 21:03:42 +02:00
parent 54dbf1df16
commit 7355533268
13 changed files with 301 additions and 120 deletions

View File

@ -9,6 +9,14 @@ and this project adheres to
## [Unreleased] ## [Unreleased]
## Changed ## Changed
- `Test.Hspec.GraphQL.*`: replace `IO` in the resolver with any `MonadCatch`. - `Test.Hspec.GraphQL.*`: replace `IO` in the resolver with any `MonadCatch`.
- The `Location` argument of `AST.Document.Definition.ExecutableDefinition` was
moved to `OperationDefinition` and `FragmentDefinition` since these are the
actual elements that have a location in the document.
- `Validate.Rules` get the whole validation context (AST and schema).
## Added
- `Validate.Validation` contains data structures and functions used by the
validator and concretet rules.
## [0.9.0.0] - 2020-07-24 ## [0.9.0.0] - 2020-07-24
## Fixed ## Fixed

View File

@ -4,7 +4,7 @@ cabal-version: 1.12
-- --
-- see: https://github.com/sol/hpack -- see: https://github.com/sol/hpack
-- --
-- hash: 1d8c32c00a882ccd1fefc4c083d5fe4e83a1825fbf8e0dcfd551ff2c8cd2dda0 -- hash: 59e2949d07cb5e678b493b77771db1bd64947de480f3da93ca07b3f2458cc495
name: graphql name: graphql
version: 0.9.0.0 version: 0.9.0.0
@ -50,6 +50,7 @@ library
Language.GraphQL.Type.Out Language.GraphQL.Type.Out
Language.GraphQL.Type.Schema Language.GraphQL.Type.Schema
Language.GraphQL.Validate Language.GraphQL.Validate
Language.GraphQL.Validate.Validation
Test.Hspec.GraphQL Test.Hspec.GraphQL
other-modules: other-modules:
Language.GraphQL.Execute.Execution Language.GraphQL.Execute.Execution

View File

@ -69,7 +69,7 @@ type Document = NonEmpty Definition
-- | All kinds of definitions that can occur in a GraphQL document. -- | All kinds of definitions that can occur in a GraphQL document.
data Definition data Definition
= ExecutableDefinition ExecutableDefinition Location = ExecutableDefinition ExecutableDefinition
| TypeSystemDefinition TypeSystemDefinition Location | TypeSystemDefinition TypeSystemDefinition Location
| TypeSystemExtension TypeSystemExtension Location | TypeSystemExtension TypeSystemExtension Location
deriving (Eq, Show) deriving (Eq, Show)
@ -84,13 +84,14 @@ data ExecutableDefinition
-- | Operation definition. -- | Operation definition.
data OperationDefinition data OperationDefinition
= SelectionSet SelectionSet = SelectionSet SelectionSet Location
| OperationDefinition | OperationDefinition
OperationType OperationType
(Maybe Name) (Maybe Name)
[VariableDefinition] [VariableDefinition]
[Directive] [Directive]
SelectionSet SelectionSet
Location
deriving (Eq, Show) deriving (Eq, Show)
-- | GraphQL has 3 operation types: -- | GraphQL has 3 operation types:
@ -195,7 +196,7 @@ type Alias = Name
-- | Fragment definition. -- | Fragment definition.
data FragmentDefinition data FragmentDefinition
= FragmentDefinition Name TypeCondition [Directive] SelectionSet = FragmentDefinition Name TypeCondition [Directive] SelectionSet Location
deriving (Eq, Show) deriving (Eq, Show)
-- | Type condition. -- | Type condition.

View File

@ -50,8 +50,8 @@ document formatter defs
| Minified <-formatter = Lazy.Text.snoc (mconcat encodeDocument) '\n' | Minified <-formatter = Lazy.Text.snoc (mconcat encodeDocument) '\n'
where where
encodeDocument = foldr executableDefinition [] defs encodeDocument = foldr executableDefinition [] defs
executableDefinition (ExecutableDefinition x _) acc = executableDefinition (ExecutableDefinition executableDefinition') acc =
definition formatter x : acc definition formatter executableDefinition' : acc
executableDefinition _ acc = acc executableDefinition _ acc = acc
-- | Converts a t'ExecutableDefinition' into a string. -- | Converts a t'ExecutableDefinition' into a string.
@ -68,12 +68,12 @@ definition formatter x
-- | Converts a 'OperationDefinition into a string. -- | Converts a 'OperationDefinition into a string.
operationDefinition :: Formatter -> OperationDefinition -> Lazy.Text operationDefinition :: Formatter -> OperationDefinition -> Lazy.Text
operationDefinition formatter = \case operationDefinition formatter = \case
SelectionSet sels -> selectionSet formatter sels SelectionSet sels _ -> selectionSet formatter sels
OperationDefinition Query name vars dirs sels -> OperationDefinition Query name vars dirs sels _ ->
"query " <> node formatter name vars dirs sels "query " <> node formatter name vars dirs sels
OperationDefinition Mutation name vars dirs sels -> OperationDefinition Mutation name vars dirs sels _ ->
"mutation " <> node formatter name vars dirs sels "mutation " <> node formatter name vars dirs sels
OperationDefinition Subscription name vars dirs sels -> OperationDefinition Subscription name vars dirs sels _ ->
"subscription " <> node formatter name vars dirs sels "subscription " <> node formatter name vars dirs sels
-- | Converts a Query or Mutation into a string. -- | Converts a Query or Mutation into a string.
@ -190,7 +190,7 @@ inlineFragment formatter tc dirs sels = "... on "
<> selectionSet formatter sels <> selectionSet formatter sels
fragmentDefinition :: Formatter -> FragmentDefinition -> Lazy.Text fragmentDefinition :: Formatter -> FragmentDefinition -> Lazy.Text
fragmentDefinition formatter (FragmentDefinition name tc dirs sels) fragmentDefinition formatter (FragmentDefinition name tc dirs sels _)
= "fragment " <> Lazy.Text.fromStrict name = "fragment " <> Lazy.Text.fromStrict name
<> " on " <> Lazy.Text.fromStrict tc <> " on " <> Lazy.Text.fromStrict tc
<> optempty (directives formatter) dirs <> optempty (directives formatter) dirs

View File

@ -21,7 +21,8 @@ import Language.GraphQL.AST.DirectiveLocation
import Language.GraphQL.AST.Document import Language.GraphQL.AST.Document
import Language.GraphQL.AST.Lexer import Language.GraphQL.AST.Lexer
import Text.Megaparsec import Text.Megaparsec
( SourcePos(..) ( MonadParsec(..)
, SourcePos(..)
, getSourcePos , getSourcePos
, lookAhead , lookAhead
, option , option
@ -37,15 +38,11 @@ document = unicodeBOM
*> lexeme (NonEmpty.some definition) *> lexeme (NonEmpty.some definition)
definition :: Parser Definition definition :: Parser Definition
definition = executableDefinition' definition = ExecutableDefinition <$> executableDefinition
<|> typeSystemDefinition' <|> typeSystemDefinition'
<|> typeSystemExtension' <|> typeSystemExtension'
<?> "Definition" <?> "Definition"
where where
executableDefinition' = do
location <- getLocation
definition' <- executableDefinition
pure $ ExecutableDefinition definition' location
typeSystemDefinition' = do typeSystemDefinition' = do
location <- getLocation location <- getLocation
definition' <- typeSystemDefinition definition' <- typeSystemDefinition
@ -349,16 +346,22 @@ operationTypeDefinition = OperationTypeDefinition
<?> "OperationTypeDefinition" <?> "OperationTypeDefinition"
operationDefinition :: Parser OperationDefinition operationDefinition :: Parser OperationDefinition
operationDefinition = SelectionSet <$> selectionSet operationDefinition = shorthand
<|> operationDefinition' <|> operationDefinition'
<?> "OperationDefinition" <?> "OperationDefinition"
where where
operationDefinition' shorthand = do
= OperationDefinition <$> operationType location <- getLocation
<*> optional name selectionSet' <- selectionSet
<*> variableDefinitions pure $ SelectionSet selectionSet' location
<*> directives operationDefinition' = do
<*> selectionSet location <- getLocation
operationType' <- operationType
operationName <- optional name
variableDefinitions' <- variableDefinitions
directives' <- directives
selectionSet' <- selectionSet
pure $ OperationDefinition operationType' operationName variableDefinitions' directives' selectionSet' location
operationType :: Parser OperationType operationType :: Parser OperationType
operationType = Query <$ symbol "query" operationType = Query <$ symbol "query"
@ -412,13 +415,15 @@ inlineFragment = InlineFragment
<?> "InlineFragment" <?> "InlineFragment"
fragmentDefinition :: Parser FragmentDefinition fragmentDefinition :: Parser FragmentDefinition
fragmentDefinition = FragmentDefinition fragmentDefinition = label "FragmentDefinition" $ do
<$ symbol "fragment" location <- getLocation
<*> name _ <- symbol "fragment"
<*> typeCondition fragmentName' <- name
<*> directives typeCondition' <- typeCondition
<*> selectionSet directives' <- directives
<?> "FragmentDefinition" selectionSet' <- selectionSet
pure $ FragmentDefinition
fragmentName' typeCondition' directives' selectionSet' location
fragmentName :: Parser Name fragmentName :: Parser Name
fragmentName = but (symbol "on") *> name <?> "FragmentName" fragmentName = but (symbol "on") *> name <?> "FragmentName"

View File

@ -83,30 +83,6 @@ resolveAbstractType abstractType values'
_ -> pure Nothing _ -> pure Nothing
| otherwise = pure Nothing | otherwise = pure Nothing
doesFragmentTypeApply :: forall m
. CompositeType m
-> Out.ObjectType m
-> Bool
doesFragmentTypeApply (CompositeObjectType fragmentType) objectType =
fragmentType == objectType
doesFragmentTypeApply (CompositeInterfaceType fragmentType) objectType =
instanceOf objectType $ AbstractInterfaceType fragmentType
doesFragmentTypeApply (CompositeUnionType fragmentType) objectType =
instanceOf objectType $ AbstractUnionType fragmentType
instanceOf :: forall m. Out.ObjectType m -> AbstractType m -> Bool
instanceOf objectType (AbstractInterfaceType interfaceType) =
let Out.ObjectType _ _ interfaces _ = objectType
in foldr go False interfaces
where
go objectInterfaceType@(Out.InterfaceType _ _ interfaces _) acc =
acc || foldr go (interfaceType == objectInterfaceType) interfaces
instanceOf objectType (AbstractUnionType unionType) =
let Out.UnionType _ _ members = unionType
in foldr go False members
where
go unionMemberType acc = acc || objectType == unionMemberType
executeField :: (MonadCatch m, Serialize a) executeField :: (MonadCatch m, Serialize a)
=> Out.Resolver m => Out.Resolver m
-> Type.Value -> Type.Value

View File

@ -255,18 +255,18 @@ defragment ast =
in (, fragmentTable) <$> maybe emptyDocument Right nonEmptyOperations in (, fragmentTable) <$> maybe emptyDocument Right nonEmptyOperations
where where
defragment' definition (operations, fragments') defragment' definition (operations, fragments')
| (Full.ExecutableDefinition executable _) <- definition | (Full.ExecutableDefinition executable) <- definition
, (Full.DefinitionOperation operation') <- executable = , (Full.DefinitionOperation operation') <- executable =
(transform operation' : operations, fragments') (transform operation' : operations, fragments')
| (Full.ExecutableDefinition executable _) <- definition | (Full.ExecutableDefinition executable) <- definition
, (Full.DefinitionFragment fragment) <- executable , (Full.DefinitionFragment fragment) <- executable
, (Full.FragmentDefinition name _ _ _) <- fragment = , (Full.FragmentDefinition name _ _ _ _) <- fragment =
(operations, HashMap.insert name fragment fragments') (operations, HashMap.insert name fragment fragments')
defragment' _ acc = acc defragment' _ acc = acc
transform = \case transform = \case
Full.OperationDefinition type' name variables directives' selections -> Full.OperationDefinition type' name variables directives' selections _ ->
OperationDefinition type' name variables directives' selections OperationDefinition type' name variables directives' selections
Full.SelectionSet selectionSet -> Full.SelectionSet selectionSet _ ->
OperationDefinition Full.Query Nothing mempty mempty selectionSet OperationDefinition Full.Query Nothing mempty mempty selectionSet
-- * Operation -- * Operation
@ -324,8 +324,8 @@ selection (Full.InlineFragment type' directives' selections) = do
case type' of case type' of
Nothing -> pure $ Left fragmentSelectionSet Nothing -> pure $ Left fragmentSelectionSet
Just typeName -> do Just typeName -> do
typeCondition' <- lookupTypeCondition typeName types' <- gets types
case typeCondition' of case lookupTypeCondition typeName types' of
Just typeCondition -> pure $ Just typeCondition -> pure $
selectionFragment typeCondition fragmentSelectionSet selectionFragment typeCondition fragmentSelectionSet
Nothing -> pure $ Left mempty Nothing -> pure $ Left mempty
@ -364,29 +364,17 @@ collectFragments = do
_ <- fragmentDefinition nextValue _ <- fragmentDefinition nextValue
collectFragments collectFragments
lookupTypeCondition :: Full.Name -> State (Replacement m) (Maybe (CompositeType m))
lookupTypeCondition type' = do
types' <- gets types
case HashMap.lookup type' types' of
Just (ObjectType objectType) ->
lift $ pure $ Just $ CompositeObjectType objectType
Just (UnionType unionType) ->
lift $ pure $ Just $ CompositeUnionType unionType
Just (InterfaceType interfaceType) ->
lift $ pure $ Just $ CompositeInterfaceType interfaceType
_ -> lift $ pure Nothing
fragmentDefinition fragmentDefinition
:: Full.FragmentDefinition :: Full.FragmentDefinition
-> State (Replacement m) (Maybe (Fragment m)) -> State (Replacement m) (Maybe (Fragment m))
fragmentDefinition (Full.FragmentDefinition name type' _ selections) = do fragmentDefinition (Full.FragmentDefinition name type' _ selections _) = do
modify deleteFragmentDefinition modify deleteFragmentDefinition
fragmentSelection <- appendSelection selections fragmentSelection <- appendSelection selections
compositeType <- lookupTypeCondition type' types' <- gets types
case compositeType of case lookupTypeCondition type' types' of
Just compositeType' -> do Just compositeType -> do
let newValue = Fragment compositeType' fragmentSelection let newValue = Fragment compositeType fragmentSelection
modify $ insertFragment newValue modify $ insertFragment newValue
lift $ pure $ Just newValue lift $ pure $ Just newValue
_ -> lift $ pure Nothing _ -> lift $ pure Nothing

View File

@ -8,6 +8,9 @@ module Language.GraphQL.Type.Internal
( AbstractType(..) ( AbstractType(..)
, CompositeType(..) , CompositeType(..)
, collectReferencedTypes , collectReferencedTypes
, doesFragmentTypeApply
, instanceOf
, lookupTypeCondition
) where ) where
import Data.HashMap.Strict (HashMap) import Data.HashMap.Strict (HashMap)
@ -89,3 +92,39 @@ collectReferencedTypes schema =
polymorphicTraverser interfaces fields polymorphicTraverser interfaces fields
= flip (foldr visitFields) fields = flip (foldr visitFields) fields
. flip (foldr traverseInterfaceType) interfaces . flip (foldr traverseInterfaceType) interfaces
doesFragmentTypeApply :: forall m
. CompositeType m
-> Out.ObjectType m
-> Bool
doesFragmentTypeApply (CompositeObjectType fragmentType) objectType =
fragmentType == objectType
doesFragmentTypeApply (CompositeInterfaceType fragmentType) objectType =
instanceOf objectType $ AbstractInterfaceType fragmentType
doesFragmentTypeApply (CompositeUnionType fragmentType) objectType =
instanceOf objectType $ AbstractUnionType fragmentType
instanceOf :: forall m. Out.ObjectType m -> AbstractType m -> Bool
instanceOf objectType (AbstractInterfaceType interfaceType) =
let Out.ObjectType _ _ interfaces _ = objectType
in foldr go False interfaces
where
go objectInterfaceType@(Out.InterfaceType _ _ interfaces _) acc =
acc || foldr go (interfaceType == objectInterfaceType) interfaces
instanceOf objectType (AbstractUnionType unionType) =
let Out.UnionType _ _ members = unionType
in foldr go False members
where
go unionMemberType acc = acc || objectType == unionMemberType
lookupTypeCondition :: forall m
. Name
-> HashMap Name (Type m)
-> Maybe (CompositeType m)
lookupTypeCondition type' types' =
case HashMap.lookup type' types' of
Just (ObjectType objectType) -> Just $ CompositeObjectType objectType
Just (UnionType unionType) -> Just $ CompositeUnionType unionType
Just (InterfaceType interfaceType) ->
Just $ CompositeInterfaceType interfaceType
_ -> Nothing

View File

@ -13,22 +13,19 @@ module Language.GraphQL.Validate
, module Language.GraphQL.Validate.Rules , module Language.GraphQL.Validate.Rules
) where ) where
import Control.Monad.Trans.Reader (Reader, asks, runReader) import Control.Monad (foldM)
import Control.Monad.Trans.Reader (Reader, asks, mapReaderT, runReader)
import Data.Foldable (foldrM) import Data.Foldable (foldrM)
import Data.Sequence (Seq(..), (><), (|>)) import Data.Sequence (Seq(..), (><), (|>))
import qualified Data.Sequence as Seq import qualified Data.Sequence as Seq
import Data.Text (Text) import Data.Text (Text)
import Language.GraphQL.AST.Document import Language.GraphQL.AST.Document
import Language.GraphQL.Type.Schema import Language.GraphQL.Type.Internal
import Language.GraphQL.Type.Schema (Schema(..))
import Language.GraphQL.Validate.Rules import Language.GraphQL.Validate.Rules
import Language.GraphQL.Validate.Validation
data Context m = Context type ValidateT m = Reader (Validation m) (Seq Error)
{ ast :: Document
, schema :: Schema m
, rules :: [Rule]
}
type ValidateT m = Reader (Context m) (Seq Error)
-- | If an error can be associated to a particular field in the GraphQL result, -- | If an error can be associated to a particular field in the GraphQL result,
-- it must contain an entry with the key path that details the path of the -- it must contain an entry with the key path that details the path of the
@ -48,36 +45,46 @@ data Error = Error
-- | Validates a document and returns a list of found errors. If the returned -- | Validates a document and returns a list of found errors. If the returned
-- list is empty, the document is valid. -- list is empty, the document is valid.
document :: forall m. Schema m -> [Rule] -> Document -> Seq Error document :: forall m. Schema m -> [Rule m] -> Document -> Seq Error
document schema' rules' document' = document schema' rules' document' =
runReader (foldrM go Seq.empty document') context runReader (foldrM go Seq.empty document') context
where where
context = Context context = Validation
{ ast = document' { ast = document'
, schema = schema' , schema = schema'
, types = collectReferencedTypes schema'
, rules = rules' , rules = rules'
} }
go definition' accumulator = (accumulator ><) <$> definition definition' go definition' accumulator = (accumulator ><) <$> definition definition'
definition :: forall m. Definition -> ValidateT m definition :: forall m. Definition -> ValidateT m
definition = \case definition = \case
definition'@(ExecutableDefinition executableDefinition' _) -> do definition'@(ExecutableDefinition executableDefinition') -> do
applied <- applyRules definition' applied <- applyRules definition'
children <- executableDefinition executableDefinition' children <- executableDefinition executableDefinition'
pure $ children >< applied pure $ children >< applied
definition' -> applyRules definition' definition' -> applyRules definition'
where where
applyRules definition' = foldr (ruleFilter definition') Seq.empty applyRules definition' =
<$> asks rules asks rules >>= foldM (ruleFilter definition') Seq.empty
ruleFilter definition' (DefinitionRule rule) accumulator ruleFilter definition' accumulator (DefinitionRule rule) =
| Just message' <- rule definition' = flip mapReaderT (rule definition') $ \case
accumulator |> Error Just message' ->
pure $ accumulator |> Error
{ message = message' { message = message'
, locations = [definitionLocation definition'] , locations = [definitionLocation definition']
, path = [] , path = []
} }
| otherwise = accumulator Nothing -> pure accumulator
definitionLocation (ExecutableDefinition _ location) = location ruleFilter _ accumulator _ = pure accumulator
definitionLocation (ExecutableDefinition executableDefinition')
| DefinitionOperation definitionOperation <- executableDefinition'
, SelectionSet _ location <- definitionOperation = location
| DefinitionOperation definitionOperation <- executableDefinition'
, OperationDefinition _ _ _ _ _ location <- definitionOperation =
location
| DefinitionFragment fragmentDefinition' <- executableDefinition'
, FragmentDefinition _ _ _ _ location <- fragmentDefinition' = location
definitionLocation (TypeSystemDefinition _ location) = location definitionLocation (TypeSystemDefinition _ location) = location
definitionLocation (TypeSystemExtension _ location) = location definitionLocation (TypeSystemExtension _ location) = location
@ -88,10 +95,21 @@ executableDefinition (DefinitionFragment definition') =
fragmentDefinition definition' fragmentDefinition definition'
operationDefinition :: forall m. OperationDefinition -> ValidateT m operationDefinition :: forall m. OperationDefinition -> ValidateT m
operationDefinition (SelectionSet _operation) = operationDefinition operation =
pure Seq.empty asks rules >>= foldM (ruleFilter operation) Seq.empty
operationDefinition (OperationDefinition _type _name _variables _directives _selection) = where
pure Seq.empty ruleFilter definition' accumulator (OperationDefinitionRule rule) =
flip mapReaderT (rule definition') $ \case
Just message' ->
pure $ accumulator |> Error
{ message = message'
, locations = [definitionLocation operation]
, path = []
}
Nothing -> pure accumulator
ruleFilter _ accumulator _ = pure accumulator
definitionLocation (SelectionSet _ location) = location
definitionLocation (OperationDefinition _ _ _ _ _ location) = location
fragmentDefinition :: forall m. FragmentDefinition -> ValidateT m fragmentDefinition :: forall m. FragmentDefinition -> ValidateT m
fragmentDefinition _fragment = pure Seq.empty fragmentDefinition _fragment = pure Seq.empty

View File

@ -2,30 +2,99 @@
v. 2.0. If a copy of the MPL was not distributed with this file, You can v. 2.0. If a copy of the MPL was not distributed with this file, You can
obtain one at https://mozilla.org/MPL/2.0/. -} obtain one at https://mozilla.org/MPL/2.0/. -}
{-# LANGUAGE ExplicitForAll #-}
{-# LANGUAGE OverloadedStrings #-}
-- | This module contains default rules defined in the GraphQL specification. -- | This module contains default rules defined in the GraphQL specification.
module Language.GraphQL.Validate.Rules module Language.GraphQL.Validate.Rules
( Rule(..) ( executableDefinitionsRule
, executableDefinitionsRule
, specifiedRules , specifiedRules
) where ) where
import Control.Monad (foldM)
import Control.Monad.Trans.Class (MonadTrans(..))
import Control.Monad.Trans.Reader (asks)
import Control.Monad.Trans.State (evalStateT, gets, modify)
import qualified Data.HashSet as HashSet
import qualified Data.Text as Text
import Language.GraphQL.AST.Document import Language.GraphQL.AST.Document
import Language.GraphQL.Type.Internal
-- | 'Rule' assigns a function to each AST node that can be validated. If the import qualified Language.GraphQL.Type.Schema as Schema
-- validation fails, the function should return an error message, or 'Nothing' import Language.GraphQL.Validate.Validation
-- otherwise.
newtype Rule
= DefinitionRule (Definition -> Maybe String)
-- | Default rules given in the specification. -- | Default rules given in the specification.
specifiedRules :: [Rule] specifiedRules :: forall m. [Rule m]
specifiedRules = specifiedRules =
[ executableDefinitionsRule [ executableDefinitionsRule
, singleFieldSubscriptionsRule
] ]
-- | Definition must be OperationDefinition or FragmentDefinition. -- | Definition must be OperationDefinition or FragmentDefinition.
executableDefinitionsRule :: Rule executableDefinitionsRule :: forall m. Rule m
executableDefinitionsRule = DefinitionRule go executableDefinitionsRule = DefinitionRule go
where where
go (ExecutableDefinition _definition _) = Nothing go :: Definition -> RuleT m
go _ = Just "Definition must be OperationDefinition or FragmentDefinition." go (ExecutableDefinition _) = lift Nothing
go _ = pure
"Definition must be OperationDefinition or FragmentDefinition."
-- | Subscription operations must have exactly one root field.
singleFieldSubscriptionsRule :: forall m. Rule m
singleFieldSubscriptionsRule = OperationDefinitionRule go
where
go (OperationDefinition Subscription name' _ _ rootFields _) = do
groupedFieldSet <- evalStateT (collectFields rootFields) HashSet.empty
case HashSet.size groupedFieldSet of
1 -> lift Nothing
_
| Just name <- name' -> pure $ unwords
[ "Subscription"
, Text.unpack name
, "must select only one top level field."
]
| otherwise -> pure
"Anonymous Subscription must select only one top level field."
go _ = lift Nothing
collectFields selectionSet = foldM forEach HashSet.empty selectionSet
forEach accumulator (Field alias name _ directives _)
| any skip directives = pure accumulator
| Just aliasedName <- alias = pure
$ HashSet.insert aliasedName accumulator
| otherwise = pure $ HashSet.insert name accumulator
forEach accumulator (FragmentSpread fragmentName directives)
| any skip directives = pure accumulator
| otherwise = do
inVisitetFragments <- gets $ HashSet.member fragmentName
if inVisitetFragments
then pure accumulator
else collectFromSpread fragmentName accumulator
forEach accumulator (InlineFragment typeCondition' directives selectionSet)
| any skip directives = pure accumulator
| Just typeCondition <- typeCondition' =
collectFromFragment typeCondition selectionSet accumulator
| otherwise = HashSet.union accumulator
<$> collectFields selectionSet
skip (Directive "skip" [Argument "if" (Boolean True)]) = True
skip (Directive "include" [Argument "if" (Boolean False)]) = True
skip _ = False
findFragmentDefinition (ExecutableDefinition executableDefinition) Nothing
| DefinitionFragment fragmentDefinition <- executableDefinition =
Just fragmentDefinition
findFragmentDefinition _ accumulator = accumulator
collectFromFragment typeCondition selectionSet accumulator = do
types' <- lift $ asks types
schema' <- lift $ asks schema
case lookupTypeCondition typeCondition types' of
Nothing -> pure accumulator
Just compositeType
| Just objectType <- Schema.subscription schema'
, True <- doesFragmentTypeApply compositeType objectType ->
HashSet.union accumulator<$> collectFields selectionSet
| otherwise -> pure accumulator
collectFromSpread fragmentName accumulator = do
modify $ HashSet.insert fragmentName
ast' <- lift $ asks ast
case foldr findFragmentDefinition Nothing ast' of
Nothing -> pure accumulator
Just (FragmentDefinition _ typeCondition _ selectionSet _) ->
collectFromFragment typeCondition selectionSet accumulator

View File

@ -0,0 +1,34 @@
{- This Source Code Form is subject to the terms of the Mozilla Public License,
v. 2.0. If a copy of the MPL was not distributed with this file, You can
obtain one at https://mozilla.org/MPL/2.0/. -}
-- | Definitions used by the validation rules and the validator itself.
module Language.GraphQL.Validate.Validation
( Validation(..)
, Rule(..)
, RuleT
) where
import Control.Monad.Trans.Reader (ReaderT(..))
import Data.HashMap.Strict (HashMap)
import Language.GraphQL.AST.Document
import Language.GraphQL.Type.Schema (Schema)
import qualified Language.GraphQL.Type.Schema as Schema
-- | Validation rule context.
data Validation m = Validation
{ ast :: Document
, schema :: Schema m
, types :: HashMap Name (Schema.Type m)
, rules :: [Rule m]
}
-- | 'Rule' assigns a function to each AST node that can be validated. If the
-- validation fails, the function should return an error message, or 'Nothing'
-- otherwise.
data Rule m
= DefinitionRule (Definition -> RuleT m)
| OperationDefinitionRule (OperationDefinition -> RuleT m)
-- | Monad transformer used by the rules.
type RuleT m = ReaderT (Validation m) Maybe String

View File

@ -123,7 +123,9 @@ spec = do
it "indents block strings in arguments" $ it "indents block strings in arguments" $
let arguments = [Argument "message" (String "line1\nline2")] let arguments = [Argument "message" (String "line1\nline2")]
field = Field Nothing "field" arguments [] [] field = Field Nothing "field" arguments [] []
operation = DefinitionOperation $ SelectionSet $ pure field operation = DefinitionOperation
$ SelectionSet (pure field)
$ Location 0 0
in definition pretty operation `shouldBe` [r|{ in definition pretty operation `shouldBe` [r|{
field(message: """ field(message: """
line1 line1

View File

@ -148,7 +148,7 @@ validate queryString =
spec :: Spec spec :: Spec
spec = spec =
describe "document" $ describe "document" $ do
it "rejects type definitions" $ it "rejects type definitions" $
let queryString = [r| let queryString = [r|
query getDogName { query getDogName {
@ -169,3 +169,43 @@ spec =
, path = [] , path = []
} }
in validate queryString `shouldBe` Seq.singleton expected in validate queryString `shouldBe` Seq.singleton expected
it "rejects multiple subscription root fields" $
let queryString = [r|
subscription sub {
newMessage {
body
sender
}
disallowedSecondRootField
}
|]
expected = Error
{ message =
"Subscription sub must select only one top level field."
, locations = [AST.Location 2 15]
, path = []
}
in validate queryString `shouldBe` Seq.singleton expected
it "rejects multiple subscription root fields coming from a fragment" $
let queryString = [r|
subscription sub {
...multipleSubscriptions
}
fragment multipleSubscriptions on Subscription {
newMessage {
body
sender
}
disallowedSecondRootField
}
|]
expected = Error
{ message =
"Subscription sub must select only one top level field."
, locations = [AST.Location 2 15]
, path = []
}
in validate queryString `shouldBe` Seq.singleton expected