diff --git a/src/Language/GraphQL/Validate/Rules.hs b/src/Language/GraphQL/Validate/Rules.hs index 3d66125..1c202fe 100644 --- a/src/Language/GraphQL/Validate/Rules.hs +++ b/src/Language/GraphQL/Validate/Rules.hs @@ -137,25 +137,28 @@ singleFieldSubscriptionsRule :: forall m. Rule m singleFieldSubscriptionsRule = OperationDefinitionRule $ \case Full.OperationDefinition Full.Subscription name' _ _ rootFields location' -> do groupedFieldSet <- evalStateT (collectFields rootFields) HashSet.empty - case HashSet.size groupedFieldSet of - 1 -> lift mempty - _ - | Just name <- name' -> pure $ Error - { message = concat - [ "Subscription \"" - , Text.unpack name - , "\" must select only one top level field." - ] - , locations = [location'] - } - | otherwise -> pure $ Error - { message = errorMessage - , locations = [location'] - } + case HashSet.toList groupedFieldSet of + [rootName] + | Text.isPrefixOf "__" rootName -> makeError location' name' + "exactly one top level field, which must not be an introspection field." + | otherwise -> lift mempty + [] -> makeError location' name' "exactly one top level field." + _ -> makeError location' name' "only one top level field." _ -> lift mempty where - errorMessage = - "Anonymous Subscription must select only one top level field." + makeError location' (Just operationName) errorLine = pure $ Error + { message = concat + [ "Subscription \"" + , Text.unpack operationName + , "\" must select " + , errorLine + ] + , locations = [location'] + } + makeError location' Nothing errorLine = pure $ Error + { message = "Anonymous Subscription must select " <> errorLine + , locations = [location'] + } collectFields = foldM forEach HashSet.empty forEach accumulator = \case Full.FieldSelection fieldSelection -> forField accumulator fieldSelection diff --git a/tests/Language/GraphQL/Validate/RulesSpec.hs b/tests/Language/GraphQL/Validate/RulesSpec.hs index af3b38e..9ec74c3 100644 --- a/tests/Language/GraphQL/Validate/RulesSpec.hs +++ b/tests/Language/GraphQL/Validate/RulesSpec.hs @@ -18,7 +18,7 @@ import qualified Language.GraphQL.AST.DirectiveLocation as DirectiveLocation import qualified Language.GraphQL.Type.In as In import qualified Language.GraphQL.Type.Out as Out import Language.GraphQL.Validate -import Test.Hspec (Spec, context, describe, it, shouldBe, shouldContain, xit) +import Test.Hspec (Spec, context, describe, it, shouldBe, shouldContain) import Text.Megaparsec (parse, errorBundlePretty) petSchema :: Schema IO @@ -206,14 +206,14 @@ spec = } in validate queryString `shouldContain` [expected] - xit "rejects an introspection field as the subscription root" $ + it "rejects an introspection field as the subscription root" $ let queryString = "subscription sub {\n\ \ __typename\n\ \}" expected = Error { message = - "Subscription \"sub\" must select only one top \ - \level field." + "Subscription \"sub\" must select exactly one top \ + \level field, which must not be an introspection field." , locations = [AST.Location 1 1] } in validate queryString `shouldContain` [expected]