From f5209481aa28fdb5dcd92589839dab0f1cb8d1b9 Mon Sep 17 00:00:00 2001 From: Eugen Wissner Date: Wed, 11 Nov 2020 08:49:45 +0100 Subject: [PATCH] Extract collectFields function --- CHANGELOG.md | 3 ++ src/Language/GraphQL/Validate.hs | 2 +- src/Language/GraphQL/Validate/Rules.hs | 71 +++++++++++++++----------- stack.yaml | 2 +- 4 files changed, 47 insertions(+), 31 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 6c3cd40..bfbd79f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,8 @@ The format is based on and this project adheres to [Haskell Package Versioning Policy](https://pvp.haskell.org/). +## [Unreleased] + ## [0.11.0.0] - 2020-11-07 ### Changed - `AST.Document.Selection` wraps additional new types: `Field`, `FragmentSpread` @@ -400,6 +402,7 @@ and this project adheres to ### Added - Data types for the GraphQL language. +[Unreleased]: https://www.caraus.tech/projects/pub-graphql/repository/23/diff?rev=master&rev_to=v0.11.0.0 [0.11.0.0]: https://www.caraus.tech/projects/pub-graphql/repository/23/diff?rev=v0.11.0.0&rev_to=v0.10.0.0 [0.10.0.0]: https://www.caraus.tech/projects/pub-graphql/repository/23/diff?rev=v0.10.0.0&rev_to=v0.9.0.0 [0.9.0.0]: https://www.caraus.tech/projects/pub-graphql/repository/23/diff?rev=v0.9.0.0&rev_to=v0.8.0.0 diff --git a/src/Language/GraphQL/Validate.hs b/src/Language/GraphQL/Validate.hs index 277f84d..ea72018 100644 --- a/src/Language/GraphQL/Validate.hs +++ b/src/Language/GraphQL/Validate.hs @@ -210,7 +210,7 @@ typeDefinition context rule = \case directives context rule scalarLocation directives' Full.ObjectTypeDefinition _ _ _ directives' fields -> directives context rule objectLocation directives' - >< foldMap (fieldDefinition context rule) fields + >< foldMap (fieldDefinition context rule) fields Full.InterfaceTypeDefinition _ _ directives' fields -> directives context rule interfaceLocation directives' >< foldMap (fieldDefinition context rule) fields diff --git a/src/Language/GraphQL/Validate/Rules.hs b/src/Language/GraphQL/Validate/Rules.hs index c67df1c..740131a 100644 --- a/src/Language/GraphQL/Validate/Rules.hs +++ b/src/Language/GraphQL/Validate/Rules.hs @@ -123,8 +123,8 @@ executableDefinitionsRule = DefinitionRule $ \case singleFieldSubscriptionsRule :: forall m. Rule m singleFieldSubscriptionsRule = OperationDefinitionRule $ \case Full.OperationDefinition Full.Subscription name' _ _ rootFields location' -> do - groupedFieldSet <- evalStateT (collectFields rootFields) HashSet.empty - case HashSet.size groupedFieldSet of + groupedFieldSet <- collectFields rootFields + case length groupedFieldSet of 1 -> lift mempty _ | Just name <- name' -> pure $ Error @@ -143,18 +143,26 @@ singleFieldSubscriptionsRule = OperationDefinitionRule $ \case where errorMessage = "Anonymous Subscription must select only one top level field." - collectFields selectionSet = foldM forEach HashSet.empty selectionSet + +collectFields :: forall m + . NonEmpty Full.Selection + -> ReaderT (Validation m) Seq [[Full.Field]] +collectFields selectionSet = evalStateT go HashSet.empty + where + go = groupSorted getFieldName <$> accumulateFields [] selectionSet + getFieldName (Full.Field alias name _ _ _ _) + | Just aliasedName <- alias = aliasedName + | otherwise = name + accumulateFields = foldM forEach forEach accumulator = \case Full.FieldSelection fieldSelection -> forField accumulator fieldSelection Full.FragmentSpreadSelection fragmentSelection -> forSpread accumulator fragmentSelection Full.InlineFragmentSelection fragmentSelection -> forInline accumulator fragmentSelection - forField accumulator (Full.Field alias name _ directives' _ _) + forField accumulator field@(Full.Field _ _ _ directives' _ _) | any skip directives' = pure accumulator - | Just aliasedName <- alias = pure - $ HashSet.insert aliasedName accumulator - | otherwise = pure $ HashSet.insert name accumulator + | otherwise = pure $ field : accumulator forSpread accumulator (Full.FragmentSpread fragmentName directives' _) | any skip directives' = pure accumulator | otherwise = do @@ -166,14 +174,8 @@ singleFieldSubscriptionsRule = OperationDefinitionRule $ \case | any skip directives' = pure accumulator | Just typeCondition <- maybeType = collectFromFragment typeCondition selections accumulator - | otherwise = HashSet.union accumulator - <$> collectFields selections - skip (Full.Directive "skip" [Full.Argument "if" (Full.Node argumentValue _) _] _) = - Full.Boolean True == argumentValue - skip (Full.Directive "include" [Full.Argument "if" (Full.Node argumentValue _) _] _) = - Full.Boolean False == argumentValue - skip _ = False - collectFromFragment typeCondition selectionSet accumulator = do + | otherwise = accumulateFields accumulator selections + collectFromFragment typeCondition selectionSet' accumulator = do types' <- lift $ asks $ Schema.types . schema schema' <- lift $ asks schema case Type.lookupTypeCondition typeCondition types' of @@ -181,15 +183,20 @@ singleFieldSubscriptionsRule = OperationDefinitionRule $ \case Just compositeType | Just objectType <- Schema.subscription schema' , True <- Type.doesFragmentTypeApply compositeType objectType -> - HashSet.union accumulator <$> collectFields selectionSet + accumulateFields accumulator selectionSet' | otherwise -> pure accumulator collectFromSpread fragmentName accumulator = do modify $ HashSet.insert fragmentName ast' <- lift $ asks ast case findFragmentDefinition fragmentName ast' of Nothing -> pure accumulator - Just (Full.FragmentDefinition _ typeCondition _ selectionSet _) -> - collectFromFragment typeCondition selectionSet accumulator + Just (Full.FragmentDefinition _ typeCondition _ selectionSet' _) -> + collectFromFragment typeCondition selectionSet' accumulator + skip (Full.Directive "skip" [Full.Argument "if" (Full.Node argumentValue _) _] _) = + Full.Boolean True == argumentValue + skip (Full.Directive "include" [Full.Argument "if" (Full.Node argumentValue _) _] _) = + Full.Boolean False == argumentValue + skip _ = False -- | GraphQL allows a short‐hand form for defining query operations when only -- that one operation exists in the document. @@ -451,8 +458,7 @@ filterSelections applyFilter selections noFragmentCyclesRule :: forall m. Rule m noFragmentCyclesRule = FragmentDefinitionRule $ \case Full.FragmentDefinition fragmentName _ _ selections location' -> do - state <- evalStateT (collectFields selections) - (0, fragmentName) + state <- evalStateT (collectCycles selections) (0, fragmentName) let spreadPath = fst <$> sortBy (comparing snd) (HashMap.toList state) case reverse spreadPath of x : _ | x == fragmentName -> pure $ Error @@ -467,10 +473,10 @@ noFragmentCyclesRule = FragmentDefinitionRule $ \case } _ -> lift mempty where - collectFields :: Traversable t + collectCycles :: Traversable t => t Full.Selection -> StateT (Int, Full.Name) (ReaderT (Validation m) Seq) (HashMap Full.Name Int) - collectFields selectionSet = foldM forEach HashMap.empty selectionSet + collectCycles selectionSet = foldM forEach HashMap.empty selectionSet forEach accumulator = \case Full.FieldSelection fieldSelection -> forField accumulator fieldSelection Full.InlineFragmentSelection fragmentSelection -> @@ -487,15 +493,15 @@ noFragmentCyclesRule = FragmentDefinitionRule $ \case then pure newAccumulator else collectFromSpread fragmentName newAccumulator forInline accumulator (Full.InlineFragment _ _ selections _) = - (accumulator <>) <$> collectFields selections + (accumulator <>) <$> collectCycles selections forField accumulator (Full.Field _ _ _ _ selections _) = - (accumulator <>) <$> collectFields selections + (accumulator <>) <$> collectCycles selections collectFromSpread fragmentName accumulator = do ast' <- lift $ asks ast case findFragmentDefinition fragmentName ast' of Nothing -> pure accumulator Just (Full.FragmentDefinition _ _ _ selections _) -> - (accumulator <>) <$> collectFields selections + (accumulator <>) <$> collectCycles selections findFragmentDefinition :: Text -> NonEmpty Full.Definition @@ -531,15 +537,22 @@ uniqueDirectiveNamesRule = DirectivesRule extract (Full.Directive directiveName _ location') = (directiveName, location') -filterDuplicates :: (a -> (Text, Full.Location)) -> String -> [a] -> Seq Error +groupSorted :: forall a. (a -> Text) -> [a] -> [[a]] +groupSorted getName = groupBy equalByName . sortOn getName + where + equalByName lhs rhs = getName lhs == getName rhs + +filterDuplicates :: forall a + . (a -> (Text, Full.Location)) + -> String + -> [a] + -> Seq Error filterDuplicates extract nodeType = Seq.fromList . fmap makeError . filter ((> 1) . length) - . groupBy equalByName - . sortOn getName + . groupSorted getName where getName = fst . extract - equalByName lhs rhs = getName lhs == getName rhs makeError directives' = Error { message = makeMessage $ head directives' , locations = snd . extract <$> directives' diff --git a/stack.yaml b/stack.yaml index 170fb52..7346efb 100644 --- a/stack.yaml +++ b/stack.yaml @@ -1,4 +1,4 @@ -resolver: lts-16.20 +resolver: lts-16.21 packages: - .