Extract collectFields function

This commit is contained in:
Eugen Wissner 2020-11-11 08:49:45 +01:00
parent 445f33dcf3
commit f5209481aa
4 changed files with 47 additions and 31 deletions

View File

@ -6,6 +6,8 @@ The format is based on
and this project adheres to and this project adheres to
[Haskell Package Versioning Policy](https://pvp.haskell.org/). [Haskell Package Versioning Policy](https://pvp.haskell.org/).
## [Unreleased]
## [0.11.0.0] - 2020-11-07 ## [0.11.0.0] - 2020-11-07
### Changed ### Changed
- `AST.Document.Selection` wraps additional new types: `Field`, `FragmentSpread` - `AST.Document.Selection` wraps additional new types: `Field`, `FragmentSpread`
@ -400,6 +402,7 @@ and this project adheres to
### Added ### Added
- Data types for the GraphQL language. - Data types for the GraphQL language.
[Unreleased]: https://www.caraus.tech/projects/pub-graphql/repository/23/diff?rev=master&rev_to=v0.11.0.0
[0.11.0.0]: https://www.caraus.tech/projects/pub-graphql/repository/23/diff?rev=v0.11.0.0&rev_to=v0.10.0.0 [0.11.0.0]: https://www.caraus.tech/projects/pub-graphql/repository/23/diff?rev=v0.11.0.0&rev_to=v0.10.0.0
[0.10.0.0]: https://www.caraus.tech/projects/pub-graphql/repository/23/diff?rev=v0.10.0.0&rev_to=v0.9.0.0 [0.10.0.0]: https://www.caraus.tech/projects/pub-graphql/repository/23/diff?rev=v0.10.0.0&rev_to=v0.9.0.0
[0.9.0.0]: https://www.caraus.tech/projects/pub-graphql/repository/23/diff?rev=v0.9.0.0&rev_to=v0.8.0.0 [0.9.0.0]: https://www.caraus.tech/projects/pub-graphql/repository/23/diff?rev=v0.9.0.0&rev_to=v0.8.0.0

View File

@ -123,8 +123,8 @@ executableDefinitionsRule = DefinitionRule $ \case
singleFieldSubscriptionsRule :: forall m. Rule m singleFieldSubscriptionsRule :: forall m. Rule m
singleFieldSubscriptionsRule = OperationDefinitionRule $ \case singleFieldSubscriptionsRule = OperationDefinitionRule $ \case
Full.OperationDefinition Full.Subscription name' _ _ rootFields location' -> do Full.OperationDefinition Full.Subscription name' _ _ rootFields location' -> do
groupedFieldSet <- evalStateT (collectFields rootFields) HashSet.empty groupedFieldSet <- collectFields rootFields
case HashSet.size groupedFieldSet of case length groupedFieldSet of
1 -> lift mempty 1 -> lift mempty
_ _
| Just name <- name' -> pure $ Error | Just name <- name' -> pure $ Error
@ -143,18 +143,26 @@ singleFieldSubscriptionsRule = OperationDefinitionRule $ \case
where where
errorMessage = errorMessage =
"Anonymous Subscription must select only one top level field." "Anonymous Subscription must select only one top level field."
collectFields selectionSet = foldM forEach HashSet.empty selectionSet
collectFields :: forall m
. NonEmpty Full.Selection
-> ReaderT (Validation m) Seq [[Full.Field]]
collectFields selectionSet = evalStateT go HashSet.empty
where
go = groupSorted getFieldName <$> accumulateFields [] selectionSet
getFieldName (Full.Field alias name _ _ _ _)
| Just aliasedName <- alias = aliasedName
| otherwise = name
accumulateFields = foldM forEach
forEach accumulator = \case forEach accumulator = \case
Full.FieldSelection fieldSelection -> forField accumulator fieldSelection Full.FieldSelection fieldSelection -> forField accumulator fieldSelection
Full.FragmentSpreadSelection fragmentSelection -> Full.FragmentSpreadSelection fragmentSelection ->
forSpread accumulator fragmentSelection forSpread accumulator fragmentSelection
Full.InlineFragmentSelection fragmentSelection -> Full.InlineFragmentSelection fragmentSelection ->
forInline accumulator fragmentSelection forInline accumulator fragmentSelection
forField accumulator (Full.Field alias name _ directives' _ _) forField accumulator field@(Full.Field _ _ _ directives' _ _)
| any skip directives' = pure accumulator | any skip directives' = pure accumulator
| Just aliasedName <- alias = pure | otherwise = pure $ field : accumulator
$ HashSet.insert aliasedName accumulator
| otherwise = pure $ HashSet.insert name accumulator
forSpread accumulator (Full.FragmentSpread fragmentName directives' _) forSpread accumulator (Full.FragmentSpread fragmentName directives' _)
| any skip directives' = pure accumulator | any skip directives' = pure accumulator
| otherwise = do | otherwise = do
@ -166,14 +174,8 @@ singleFieldSubscriptionsRule = OperationDefinitionRule $ \case
| any skip directives' = pure accumulator | any skip directives' = pure accumulator
| Just typeCondition <- maybeType = | Just typeCondition <- maybeType =
collectFromFragment typeCondition selections accumulator collectFromFragment typeCondition selections accumulator
| otherwise = HashSet.union accumulator | otherwise = accumulateFields accumulator selections
<$> collectFields selections collectFromFragment typeCondition selectionSet' accumulator = do
skip (Full.Directive "skip" [Full.Argument "if" (Full.Node argumentValue _) _] _) =
Full.Boolean True == argumentValue
skip (Full.Directive "include" [Full.Argument "if" (Full.Node argumentValue _) _] _) =
Full.Boolean False == argumentValue
skip _ = False
collectFromFragment typeCondition selectionSet accumulator = do
types' <- lift $ asks $ Schema.types . schema types' <- lift $ asks $ Schema.types . schema
schema' <- lift $ asks schema schema' <- lift $ asks schema
case Type.lookupTypeCondition typeCondition types' of case Type.lookupTypeCondition typeCondition types' of
@ -181,15 +183,20 @@ singleFieldSubscriptionsRule = OperationDefinitionRule $ \case
Just compositeType Just compositeType
| Just objectType <- Schema.subscription schema' | Just objectType <- Schema.subscription schema'
, True <- Type.doesFragmentTypeApply compositeType objectType -> , True <- Type.doesFragmentTypeApply compositeType objectType ->
HashSet.union accumulator <$> collectFields selectionSet accumulateFields accumulator selectionSet'
| otherwise -> pure accumulator | otherwise -> pure accumulator
collectFromSpread fragmentName accumulator = do collectFromSpread fragmentName accumulator = do
modify $ HashSet.insert fragmentName modify $ HashSet.insert fragmentName
ast' <- lift $ asks ast ast' <- lift $ asks ast
case findFragmentDefinition fragmentName ast' of case findFragmentDefinition fragmentName ast' of
Nothing -> pure accumulator Nothing -> pure accumulator
Just (Full.FragmentDefinition _ typeCondition _ selectionSet _) -> Just (Full.FragmentDefinition _ typeCondition _ selectionSet' _) ->
collectFromFragment typeCondition selectionSet accumulator collectFromFragment typeCondition selectionSet' accumulator
skip (Full.Directive "skip" [Full.Argument "if" (Full.Node argumentValue _) _] _) =
Full.Boolean True == argumentValue
skip (Full.Directive "include" [Full.Argument "if" (Full.Node argumentValue _) _] _) =
Full.Boolean False == argumentValue
skip _ = False
-- | GraphQL allows a shorthand form for defining query operations when only -- | GraphQL allows a shorthand form for defining query operations when only
-- that one operation exists in the document. -- that one operation exists in the document.
@ -451,8 +458,7 @@ filterSelections applyFilter selections
noFragmentCyclesRule :: forall m. Rule m noFragmentCyclesRule :: forall m. Rule m
noFragmentCyclesRule = FragmentDefinitionRule $ \case noFragmentCyclesRule = FragmentDefinitionRule $ \case
Full.FragmentDefinition fragmentName _ _ selections location' -> do Full.FragmentDefinition fragmentName _ _ selections location' -> do
state <- evalStateT (collectFields selections) state <- evalStateT (collectCycles selections) (0, fragmentName)
(0, fragmentName)
let spreadPath = fst <$> sortBy (comparing snd) (HashMap.toList state) let spreadPath = fst <$> sortBy (comparing snd) (HashMap.toList state)
case reverse spreadPath of case reverse spreadPath of
x : _ | x == fragmentName -> pure $ Error x : _ | x == fragmentName -> pure $ Error
@ -467,10 +473,10 @@ noFragmentCyclesRule = FragmentDefinitionRule $ \case
} }
_ -> lift mempty _ -> lift mempty
where where
collectFields :: Traversable t collectCycles :: Traversable t
=> t Full.Selection => t Full.Selection
-> StateT (Int, Full.Name) (ReaderT (Validation m) Seq) (HashMap Full.Name Int) -> StateT (Int, Full.Name) (ReaderT (Validation m) Seq) (HashMap Full.Name Int)
collectFields selectionSet = foldM forEach HashMap.empty selectionSet collectCycles selectionSet = foldM forEach HashMap.empty selectionSet
forEach accumulator = \case forEach accumulator = \case
Full.FieldSelection fieldSelection -> forField accumulator fieldSelection Full.FieldSelection fieldSelection -> forField accumulator fieldSelection
Full.InlineFragmentSelection fragmentSelection -> Full.InlineFragmentSelection fragmentSelection ->
@ -487,15 +493,15 @@ noFragmentCyclesRule = FragmentDefinitionRule $ \case
then pure newAccumulator then pure newAccumulator
else collectFromSpread fragmentName newAccumulator else collectFromSpread fragmentName newAccumulator
forInline accumulator (Full.InlineFragment _ _ selections _) = forInline accumulator (Full.InlineFragment _ _ selections _) =
(accumulator <>) <$> collectFields selections (accumulator <>) <$> collectCycles selections
forField accumulator (Full.Field _ _ _ _ selections _) = forField accumulator (Full.Field _ _ _ _ selections _) =
(accumulator <>) <$> collectFields selections (accumulator <>) <$> collectCycles selections
collectFromSpread fragmentName accumulator = do collectFromSpread fragmentName accumulator = do
ast' <- lift $ asks ast ast' <- lift $ asks ast
case findFragmentDefinition fragmentName ast' of case findFragmentDefinition fragmentName ast' of
Nothing -> pure accumulator Nothing -> pure accumulator
Just (Full.FragmentDefinition _ _ _ selections _) -> Just (Full.FragmentDefinition _ _ _ selections _) ->
(accumulator <>) <$> collectFields selections (accumulator <>) <$> collectCycles selections
findFragmentDefinition :: Text findFragmentDefinition :: Text
-> NonEmpty Full.Definition -> NonEmpty Full.Definition
@ -531,15 +537,22 @@ uniqueDirectiveNamesRule = DirectivesRule
extract (Full.Directive directiveName _ location') = extract (Full.Directive directiveName _ location') =
(directiveName, location') (directiveName, location')
filterDuplicates :: (a -> (Text, Full.Location)) -> String -> [a] -> Seq Error groupSorted :: forall a. (a -> Text) -> [a] -> [[a]]
groupSorted getName = groupBy equalByName . sortOn getName
where
equalByName lhs rhs = getName lhs == getName rhs
filterDuplicates :: forall a
. (a -> (Text, Full.Location))
-> String
-> [a]
-> Seq Error
filterDuplicates extract nodeType = Seq.fromList filterDuplicates extract nodeType = Seq.fromList
. fmap makeError . fmap makeError
. filter ((> 1) . length) . filter ((> 1) . length)
. groupBy equalByName . groupSorted getName
. sortOn getName
where where
getName = fst . extract getName = fst . extract
equalByName lhs rhs = getName lhs == getName rhs
makeError directives' = Error makeError directives' = Error
{ message = makeMessage $ head directives' { message = makeMessage $ head directives'
, locations = snd . extract <$> directives' , locations = snd . extract <$> directives'

View File

@ -1,4 +1,4 @@
resolver: lts-16.20 resolver: lts-16.21
packages: packages:
- . - .