summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/Language/GraphQL/Validate/Rules.hs33
-rw-r--r--tests/Language/GraphQL/ValidateSpec.hs36
2 files changed, 52 insertions, 17 deletions
diff --git a/src/Language/GraphQL/Validate/Rules.hs b/src/Language/GraphQL/Validate/Rules.hs
index b794b64..c67df1c 100644
--- a/src/Language/GraphQL/Validate/Rules.hs
+++ b/src/Language/GraphQL/Validate/Rules.hs
@@ -50,6 +50,7 @@ import Data.HashSet (HashSet)
import qualified Data.HashSet as HashSet
import Data.List (groupBy, sortBy, sortOn)
import Data.Maybe (isNothing, mapMaybe)
+import Data.List.NonEmpty (NonEmpty)
import Data.Ord (comparing)
import Data.Sequence (Seq(..), (|>))
import qualified Data.Sequence as Seq
@@ -127,10 +128,10 @@ singleFieldSubscriptionsRule = OperationDefinitionRule $ \case
1 -> lift mempty
_
| Just name <- name' -> pure $ Error
- { message = unwords
- [ "Subscription"
+ { message = concat
+ [ "Subscription \""
, Text.unpack name
- , "must select only one top level field."
+ , "\" must select only one top level field."
]
, locations = [location']
}
@@ -172,10 +173,6 @@ singleFieldSubscriptionsRule = OperationDefinitionRule $ \case
skip (Full.Directive "include" [Full.Argument "if" (Full.Node argumentValue _) _] _) =
Full.Boolean False == argumentValue
skip _ = False
- findFragmentDefinition (Full.ExecutableDefinition executableDefinition) Nothing
- | Full.DefinitionFragment fragmentDefinition <- executableDefinition =
- Just fragmentDefinition
- findFragmentDefinition _ accumulator = accumulator
collectFromFragment typeCondition selectionSet accumulator = do
types' <- lift $ asks $ Schema.types . schema
schema' <- lift $ asks schema
@@ -189,7 +186,7 @@ singleFieldSubscriptionsRule = OperationDefinitionRule $ \case
collectFromSpread fragmentName accumulator = do
modify $ HashSet.insert fragmentName
ast' <- lift $ asks ast
- case foldr findFragmentDefinition Nothing ast' of
+ case findFragmentDefinition fragmentName ast' of
Nothing -> pure accumulator
Just (Full.FragmentDefinition _ typeCondition _ selectionSet _) ->
collectFromFragment typeCondition selectionSet accumulator
@@ -493,18 +490,24 @@ noFragmentCyclesRule = FragmentDefinitionRule $ \case
(accumulator <>) <$> collectFields selections
forField accumulator (Full.Field _ _ _ _ selections _) =
(accumulator <>) <$> collectFields selections
- findFragmentDefinition n (Full.ExecutableDefinition executableDefinition) Nothing
- | Full.DefinitionFragment fragmentDefinition <- executableDefinition
- , Full.FragmentDefinition fragmentName _ _ _ _ <- fragmentDefinition
- , fragmentName == n = Just fragmentDefinition
- findFragmentDefinition _ _ accumulator = accumulator
- collectFromSpread _fragmentName accumulator = do
+ collectFromSpread fragmentName accumulator = do
ast' <- lift $ asks ast
- case foldr (findFragmentDefinition _fragmentName) Nothing ast' of
+ case findFragmentDefinition fragmentName ast' of
Nothing -> pure accumulator
Just (Full.FragmentDefinition _ _ _ selections _) ->
(accumulator <>) <$> collectFields selections
+findFragmentDefinition :: Text
+ -> NonEmpty Full.Definition
+ -> Maybe Full.FragmentDefinition
+findFragmentDefinition fragmentName = foldr compareDefinition Nothing
+ where
+ compareDefinition (Full.ExecutableDefinition executableDefinition) Nothing
+ | Full.DefinitionFragment fragmentDefinition <- executableDefinition
+ , Full.FragmentDefinition anotherName _ _ _ _ <- fragmentDefinition
+ , anotherName == fragmentName = Just fragmentDefinition
+ compareDefinition _ accumulator = accumulator
+
-- | Fields and directives treat arguments as a mapping of argument name to
-- value. More than one argument with the same name in an argument set is
-- ambiguous and invalid.
diff --git a/tests/Language/GraphQL/ValidateSpec.hs b/tests/Language/GraphQL/ValidateSpec.hs
index 3bfa018..318045c 100644
--- a/tests/Language/GraphQL/ValidateSpec.hs
+++ b/tests/Language/GraphQL/ValidateSpec.hs
@@ -92,6 +92,7 @@ petType = InterfaceType "Pet" Nothing []
subscriptionType :: ObjectType IO
subscriptionType = ObjectType "Subscription" Nothing [] $ HashMap.fromList
[ ("newMessage", newMessageResolver)
+ , ("disallowedSecondRootField", newMessageResolver)
]
where
newMessageField = Field Nothing (Out.NonNullObjectType messageType) mempty
@@ -165,7 +166,8 @@ spec =
|]
expected = Error
{ message =
- "Subscription sub must select only one top level field."
+ "Subscription \"sub\" must select only one top level \
+ \field."
, locations = [AST.Location 2 15]
}
in validate queryString `shouldContain` [expected]
@@ -186,7 +188,8 @@ spec =
|]
expected = Error
{ message =
- "Subscription sub must select only one top level field."
+ "Subscription \"sub\" must select only one top level \
+ \field."
, locations = [AST.Location 2 15]
}
in validate queryString `shouldContain` [expected]
@@ -631,3 +634,32 @@ spec =
, locations = [AST.Location 3 34]
}
in validate queryString `shouldBe` [expected]
+
+ it "finds corresponding subscription fragment" $
+ let queryString = [r|
+ subscription sub {
+ ...anotherSubscription
+ ...multipleSubscriptions
+ }
+ fragment multipleSubscriptions on Subscription {
+ newMessage {
+ body
+ }
+ disallowedSecondRootField {
+ sender
+ }
+ }
+ fragment anotherSubscription on Subscription {
+ newMessage {
+ body
+ sender
+ }
+ }
+ |]
+ expected = Error
+ { message =
+ "Subscription \"sub\" must select only one top level \
+ \field."
+ , locations = [AST.Location 2 15]
+ }
+ in validate queryString `shouldBe` [expected]