summaryrefslogtreecommitdiff
path: root/src/Language/GraphQL/Validate/Rules.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/Language/GraphQL/Validate/Rules.hs')
-rw-r--r--src/Language/GraphQL/Validate/Rules.hs71
1 files changed, 42 insertions, 29 deletions
diff --git a/src/Language/GraphQL/Validate/Rules.hs b/src/Language/GraphQL/Validate/Rules.hs
index c67df1c..740131a 100644
--- a/src/Language/GraphQL/Validate/Rules.hs
+++ b/src/Language/GraphQL/Validate/Rules.hs
@@ -123,8 +123,8 @@ executableDefinitionsRule = DefinitionRule $ \case
singleFieldSubscriptionsRule :: forall m. Rule m
singleFieldSubscriptionsRule = OperationDefinitionRule $ \case
Full.OperationDefinition Full.Subscription name' _ _ rootFields location' -> do
- groupedFieldSet <- evalStateT (collectFields rootFields) HashSet.empty
- case HashSet.size groupedFieldSet of
+ groupedFieldSet <- collectFields rootFields
+ case length groupedFieldSet of
1 -> lift mempty
_
| Just name <- name' -> pure $ Error
@@ -143,18 +143,26 @@ singleFieldSubscriptionsRule = OperationDefinitionRule $ \case
where
errorMessage =
"Anonymous Subscription must select only one top level field."
- collectFields selectionSet = foldM forEach HashSet.empty selectionSet
+
+collectFields :: forall m
+ . NonEmpty Full.Selection
+ -> ReaderT (Validation m) Seq [[Full.Field]]
+collectFields selectionSet = evalStateT go HashSet.empty
+ where
+ go = groupSorted getFieldName <$> accumulateFields [] selectionSet
+ getFieldName (Full.Field alias name _ _ _ _)
+ | Just aliasedName <- alias = aliasedName
+ | otherwise = name
+ accumulateFields = foldM forEach
forEach accumulator = \case
Full.FieldSelection fieldSelection -> forField accumulator fieldSelection
Full.FragmentSpreadSelection fragmentSelection ->
forSpread accumulator fragmentSelection
Full.InlineFragmentSelection fragmentSelection ->
forInline accumulator fragmentSelection
- forField accumulator (Full.Field alias name _ directives' _ _)
+ forField accumulator field@(Full.Field _ _ _ directives' _ _)
| any skip directives' = pure accumulator
- | Just aliasedName <- alias = pure
- $ HashSet.insert aliasedName accumulator
- | otherwise = pure $ HashSet.insert name accumulator
+ | otherwise = pure $ field : accumulator
forSpread accumulator (Full.FragmentSpread fragmentName directives' _)
| any skip directives' = pure accumulator
| otherwise = do
@@ -166,14 +174,8 @@ singleFieldSubscriptionsRule = OperationDefinitionRule $ \case
| any skip directives' = pure accumulator
| Just typeCondition <- maybeType =
collectFromFragment typeCondition selections accumulator
- | otherwise = HashSet.union accumulator
- <$> collectFields selections
- skip (Full.Directive "skip" [Full.Argument "if" (Full.Node argumentValue _) _] _) =
- Full.Boolean True == argumentValue
- skip (Full.Directive "include" [Full.Argument "if" (Full.Node argumentValue _) _] _) =
- Full.Boolean False == argumentValue
- skip _ = False
- collectFromFragment typeCondition selectionSet accumulator = do
+ | otherwise = accumulateFields accumulator selections
+ collectFromFragment typeCondition selectionSet' accumulator = do
types' <- lift $ asks $ Schema.types . schema
schema' <- lift $ asks schema
case Type.lookupTypeCondition typeCondition types' of
@@ -181,15 +183,20 @@ singleFieldSubscriptionsRule = OperationDefinitionRule $ \case
Just compositeType
| Just objectType <- Schema.subscription schema'
, True <- Type.doesFragmentTypeApply compositeType objectType ->
- HashSet.union accumulator <$> collectFields selectionSet
+ accumulateFields accumulator selectionSet'
| otherwise -> pure accumulator
collectFromSpread fragmentName accumulator = do
modify $ HashSet.insert fragmentName
ast' <- lift $ asks ast
case findFragmentDefinition fragmentName ast' of
Nothing -> pure accumulator
- Just (Full.FragmentDefinition _ typeCondition _ selectionSet _) ->
- collectFromFragment typeCondition selectionSet accumulator
+ Just (Full.FragmentDefinition _ typeCondition _ selectionSet' _) ->
+ collectFromFragment typeCondition selectionSet' accumulator
+ skip (Full.Directive "skip" [Full.Argument "if" (Full.Node argumentValue _) _] _) =
+ Full.Boolean True == argumentValue
+ skip (Full.Directive "include" [Full.Argument "if" (Full.Node argumentValue _) _] _) =
+ Full.Boolean False == argumentValue
+ skip _ = False
-- | GraphQL allows a short‐hand form for defining query operations when only
-- that one operation exists in the document.
@@ -451,8 +458,7 @@ filterSelections applyFilter selections
noFragmentCyclesRule :: forall m. Rule m
noFragmentCyclesRule = FragmentDefinitionRule $ \case
Full.FragmentDefinition fragmentName _ _ selections location' -> do
- state <- evalStateT (collectFields selections)
- (0, fragmentName)
+ state <- evalStateT (collectCycles selections) (0, fragmentName)
let spreadPath = fst <$> sortBy (comparing snd) (HashMap.toList state)
case reverse spreadPath of
x : _ | x == fragmentName -> pure $ Error
@@ -467,10 +473,10 @@ noFragmentCyclesRule = FragmentDefinitionRule $ \case
}
_ -> lift mempty
where
- collectFields :: Traversable t
+ collectCycles :: Traversable t
=> t Full.Selection
-> StateT (Int, Full.Name) (ReaderT (Validation m) Seq) (HashMap Full.Name Int)
- collectFields selectionSet = foldM forEach HashMap.empty selectionSet
+ collectCycles selectionSet = foldM forEach HashMap.empty selectionSet
forEach accumulator = \case
Full.FieldSelection fieldSelection -> forField accumulator fieldSelection
Full.InlineFragmentSelection fragmentSelection ->
@@ -487,15 +493,15 @@ noFragmentCyclesRule = FragmentDefinitionRule $ \case
then pure newAccumulator
else collectFromSpread fragmentName newAccumulator
forInline accumulator (Full.InlineFragment _ _ selections _) =
- (accumulator <>) <$> collectFields selections
+ (accumulator <>) <$> collectCycles selections
forField accumulator (Full.Field _ _ _ _ selections _) =
- (accumulator <>) <$> collectFields selections
+ (accumulator <>) <$> collectCycles selections
collectFromSpread fragmentName accumulator = do
ast' <- lift $ asks ast
case findFragmentDefinition fragmentName ast' of
Nothing -> pure accumulator
Just (Full.FragmentDefinition _ _ _ selections _) ->
- (accumulator <>) <$> collectFields selections
+ (accumulator <>) <$> collectCycles selections
findFragmentDefinition :: Text
-> NonEmpty Full.Definition
@@ -531,15 +537,22 @@ uniqueDirectiveNamesRule = DirectivesRule
extract (Full.Directive directiveName _ location') =
(directiveName, location')
-filterDuplicates :: (a -> (Text, Full.Location)) -> String -> [a] -> Seq Error
+groupSorted :: forall a. (a -> Text) -> [a] -> [[a]]
+groupSorted getName = groupBy equalByName . sortOn getName
+ where
+ equalByName lhs rhs = getName lhs == getName rhs
+
+filterDuplicates :: forall a
+ . (a -> (Text, Full.Location))
+ -> String
+ -> [a]
+ -> Seq Error
filterDuplicates extract nodeType = Seq.fromList
. fmap makeError
. filter ((> 1) . length)
- . groupBy equalByName
- . sortOn getName
+ . groupSorted getName
where
getName = fst . extract
- equalByName lhs rhs = getName lhs == getName rhs
makeError directives' = Error
{ message = makeMessage $ head directives'
, locations = snd . extract <$> directives'