From 7ea76865e628b348fb8d5089aed9cc300cc653be Mon Sep 17 00:00:00 2001 From: Eugen Wissner Date: Sun, 1 Dec 2024 21:47:29 +0100 Subject: Validate the subscription root MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit …not to be an introspection field. --- src/Language/GraphQL/Validate/Rules.hs | 37 ++++++++++++++++++---------------- 1 file changed, 20 insertions(+), 17 deletions(-) (limited to 'src/Language/GraphQL') diff --git a/src/Language/GraphQL/Validate/Rules.hs b/src/Language/GraphQL/Validate/Rules.hs index 3d66125..1c202fe 100644 --- a/src/Language/GraphQL/Validate/Rules.hs +++ b/src/Language/GraphQL/Validate/Rules.hs @@ -137,25 +137,28 @@ singleFieldSubscriptionsRule :: forall m. Rule m singleFieldSubscriptionsRule = OperationDefinitionRule $ \case Full.OperationDefinition Full.Subscription name' _ _ rootFields location' -> do groupedFieldSet <- evalStateT (collectFields rootFields) HashSet.empty - case HashSet.size groupedFieldSet of - 1 -> lift mempty - _ - | Just name <- name' -> pure $ Error - { message = concat - [ "Subscription \"" - , Text.unpack name - , "\" must select only one top level field." - ] - , locations = [location'] - } - | otherwise -> pure $ Error - { message = errorMessage - , locations = [location'] - } + case HashSet.toList groupedFieldSet of + [rootName] + | Text.isPrefixOf "__" rootName -> makeError location' name' + "exactly one top level field, which must not be an introspection field." + | otherwise -> lift mempty + [] -> makeError location' name' "exactly one top level field." + _ -> makeError location' name' "only one top level field." _ -> lift mempty where - errorMessage = - "Anonymous Subscription must select only one top level field." + makeError location' (Just operationName) errorLine = pure $ Error + { message = concat + [ "Subscription \"" + , Text.unpack operationName + , "\" must select " + , errorLine + ] + , locations = [location'] + } + makeError location' Nothing errorLine = pure $ Error + { message = "Anonymous Subscription must select " <> errorLine + , locations = [location'] + } collectFields = foldM forEach HashSet.empty forEach accumulator = \case Full.FieldSelection fieldSelection -> forField accumulator fieldSelection -- cgit v1.2.3