From 4a3b4cb16d7da9c356b514ab48bdc0e527acd377 Mon Sep 17 00:00:00 2001 From: Eugen Wissner Date: Fri, 6 Nov 2020 08:33:51 +0100 Subject: [PATCH] Fix singleFieldSubscriptionsRule fragment lookup singleFieldSubscriptionsRule picks up a wrong fragment definition. --- src/Language/GraphQL/Validate/Rules.hs | 33 ++++++++++++----------- tests/Language/GraphQL/ValidateSpec.hs | 36 ++++++++++++++++++++++++-- 2 files changed, 52 insertions(+), 17 deletions(-) diff --git a/src/Language/GraphQL/Validate/Rules.hs b/src/Language/GraphQL/Validate/Rules.hs index b794b64..c67df1c 100644 --- a/src/Language/GraphQL/Validate/Rules.hs +++ b/src/Language/GraphQL/Validate/Rules.hs @@ -50,6 +50,7 @@ import Data.HashSet (HashSet) import qualified Data.HashSet as HashSet import Data.List (groupBy, sortBy, sortOn) import Data.Maybe (isNothing, mapMaybe) +import Data.List.NonEmpty (NonEmpty) import Data.Ord (comparing) import Data.Sequence (Seq(..), (|>)) import qualified Data.Sequence as Seq @@ -127,10 +128,10 @@ singleFieldSubscriptionsRule = OperationDefinitionRule $ \case 1 -> lift mempty _ | Just name <- name' -> pure $ Error - { message = unwords - [ "Subscription" + { message = concat + [ "Subscription \"" , Text.unpack name - , "must select only one top level field." + , "\" must select only one top level field." ] , locations = [location'] } @@ -172,10 +173,6 @@ singleFieldSubscriptionsRule = OperationDefinitionRule $ \case skip (Full.Directive "include" [Full.Argument "if" (Full.Node argumentValue _) _] _) = Full.Boolean False == argumentValue skip _ = False - findFragmentDefinition (Full.ExecutableDefinition executableDefinition) Nothing - | Full.DefinitionFragment fragmentDefinition <- executableDefinition = - Just fragmentDefinition - findFragmentDefinition _ accumulator = accumulator collectFromFragment typeCondition selectionSet accumulator = do types' <- lift $ asks $ Schema.types . schema schema' <- lift $ asks schema @@ -189,7 +186,7 @@ singleFieldSubscriptionsRule = OperationDefinitionRule $ \case collectFromSpread fragmentName accumulator = do modify $ HashSet.insert fragmentName ast' <- lift $ asks ast - case foldr findFragmentDefinition Nothing ast' of + case findFragmentDefinition fragmentName ast' of Nothing -> pure accumulator Just (Full.FragmentDefinition _ typeCondition _ selectionSet _) -> collectFromFragment typeCondition selectionSet accumulator @@ -493,18 +490,24 @@ noFragmentCyclesRule = FragmentDefinitionRule $ \case (accumulator <>) <$> collectFields selections forField accumulator (Full.Field _ _ _ _ selections _) = (accumulator <>) <$> collectFields selections - findFragmentDefinition n (Full.ExecutableDefinition executableDefinition) Nothing - | Full.DefinitionFragment fragmentDefinition <- executableDefinition - , Full.FragmentDefinition fragmentName _ _ _ _ <- fragmentDefinition - , fragmentName == n = Just fragmentDefinition - findFragmentDefinition _ _ accumulator = accumulator - collectFromSpread _fragmentName accumulator = do + collectFromSpread fragmentName accumulator = do ast' <- lift $ asks ast - case foldr (findFragmentDefinition _fragmentName) Nothing ast' of + case findFragmentDefinition fragmentName ast' of Nothing -> pure accumulator Just (Full.FragmentDefinition _ _ _ selections _) -> (accumulator <>) <$> collectFields selections +findFragmentDefinition :: Text + -> NonEmpty Full.Definition + -> Maybe Full.FragmentDefinition +findFragmentDefinition fragmentName = foldr compareDefinition Nothing + where + compareDefinition (Full.ExecutableDefinition executableDefinition) Nothing + | Full.DefinitionFragment fragmentDefinition <- executableDefinition + , Full.FragmentDefinition anotherName _ _ _ _ <- fragmentDefinition + , anotherName == fragmentName = Just fragmentDefinition + compareDefinition _ accumulator = accumulator + -- | Fields and directives treat arguments as a mapping of argument name to -- value. More than one argument with the same name in an argument set is -- ambiguous and invalid. diff --git a/tests/Language/GraphQL/ValidateSpec.hs b/tests/Language/GraphQL/ValidateSpec.hs index 3bfa018..318045c 100644 --- a/tests/Language/GraphQL/ValidateSpec.hs +++ b/tests/Language/GraphQL/ValidateSpec.hs @@ -92,6 +92,7 @@ petType = InterfaceType "Pet" Nothing [] subscriptionType :: ObjectType IO subscriptionType = ObjectType "Subscription" Nothing [] $ HashMap.fromList [ ("newMessage", newMessageResolver) + , ("disallowedSecondRootField", newMessageResolver) ] where newMessageField = Field Nothing (Out.NonNullObjectType messageType) mempty @@ -165,7 +166,8 @@ spec = |] expected = Error { message = - "Subscription sub must select only one top level field." + "Subscription \"sub\" must select only one top level \ + \field." , locations = [AST.Location 2 15] } in validate queryString `shouldContain` [expected] @@ -186,7 +188,8 @@ spec = |] expected = Error { message = - "Subscription sub must select only one top level field." + "Subscription \"sub\" must select only one top level \ + \field." , locations = [AST.Location 2 15] } in validate queryString `shouldContain` [expected] @@ -631,3 +634,32 @@ spec = , locations = [AST.Location 3 34] } in validate queryString `shouldBe` [expected] + + it "finds corresponding subscription fragment" $ + let queryString = [r| + subscription sub { + ...anotherSubscription + ...multipleSubscriptions + } + fragment multipleSubscriptions on Subscription { + newMessage { + body + } + disallowedSecondRootField { + sender + } + } + fragment anotherSubscription on Subscription { + newMessage { + body + sender + } + } + |] + expected = Error + { message = + "Subscription \"sub\" must select only one top level \ + \field." + , locations = [AST.Location 2 15] + } + in validate queryString `shouldBe` [expected]