summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--CHANGELOG.md3
-rw-r--r--src/Language/GraphQL/Type/Internal.hs27
-rw-r--r--src/Language/GraphQL/Validate/Rules.hs237
-rw-r--r--tests/Language/GraphQL/ValidateSpec.hs78
4 files changed, 310 insertions, 35 deletions
diff --git a/CHANGELOG.md b/CHANGELOG.md
index bfbd79f..9854006 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -7,6 +7,9 @@ and this project adheres to
[Haskell Package Versioning Policy](https://pvp.haskell.org/).
## [Unreleased]
+### Added
+- `Validate.Rules`:
+ - `overlappingFieldsCanBeMergedRule`
## [0.11.0.0] - 2020-11-07
### Changed
diff --git a/src/Language/GraphQL/Type/Internal.hs b/src/Language/GraphQL/Type/Internal.hs
index eb8489c..3861365 100644
--- a/src/Language/GraphQL/Type/Internal.hs
+++ b/src/Language/GraphQL/Type/Internal.hs
@@ -15,10 +15,12 @@ module Language.GraphQL.Type.Internal
, directives
, doesFragmentTypeApply
, instanceOf
+ , lookupCompositeField
, lookupInputType
, lookupTypeCondition
, lookupTypeField
, mutation
+ , outToComposite
, subscription
, query
, types
@@ -160,12 +162,16 @@ lookupInputType (Full.TypeNonNull (Full.NonNullTypeList nonNull)) types'
<$> lookupInputType nonNull types'
lookupTypeField :: forall a. Full.Name -> Out.Type a -> Maybe (Out.Field a)
-lookupTypeField fieldName = \case
- Out.ObjectBaseType objectType ->
- objectChild objectType
- Out.InterfaceBaseType interfaceType ->
- interfaceChild interfaceType
- Out.ListBaseType listType -> lookupTypeField fieldName listType
+lookupTypeField fieldName outputType =
+ outToComposite outputType >>= lookupCompositeField fieldName
+
+lookupCompositeField :: forall a
+ . Full.Name
+ -> CompositeType a
+ -> Maybe (Out.Field a)
+lookupCompositeField fieldName = \case
+ CompositeObjectType objectType -> objectChild objectType
+ CompositeInterfaceType interfaceType -> interfaceChild interfaceType
_ -> Nothing
where
objectChild (Out.ObjectType _ _ _ resolvers) =
@@ -174,3 +180,12 @@ lookupTypeField fieldName = \case
HashMap.lookup fieldName fields
resolverType (Out.ValueResolver objectField _) = objectField
resolverType (Out.EventStreamResolver objectField _ _) = objectField
+
+outToComposite :: forall a. Out.Type a -> Maybe (CompositeType a)
+outToComposite = \case
+ Out.ObjectBaseType objectType -> Just $ CompositeObjectType objectType
+ Out.InterfaceBaseType interfaceType ->
+ Just $ CompositeInterfaceType interfaceType
+ Out.UnionBaseType unionType -> Just $ CompositeUnionType unionType
+ Out.ListBaseType listType -> outToComposite listType
+ _ -> Nothing
diff --git a/src/Language/GraphQL/Validate/Rules.hs b/src/Language/GraphQL/Validate/Rules.hs
index 740131a..d17c1bc 100644
--- a/src/Language/GraphQL/Validate/Rules.hs
+++ b/src/Language/GraphQL/Validate/Rules.hs
@@ -24,6 +24,7 @@ module Language.GraphQL.Validate.Rules
, noUndefinedVariablesRule
, noUnusedFragmentsRule
, noUnusedVariablesRule
+ , overlappingFieldsCanBeMergedRule
, providedRequiredInputFieldsRule
, providedRequiredArgumentsRule
, scalarLeafsRule
@@ -40,17 +41,17 @@ module Language.GraphQL.Validate.Rules
import Control.Monad ((>=>), foldM)
import Control.Monad.Trans.Class (MonadTrans(..))
-import Control.Monad.Trans.Reader (ReaderT(..), asks, mapReaderT)
+import Control.Monad.Trans.Reader (ReaderT(..), ask, asks, mapReaderT)
import Control.Monad.Trans.State (StateT, evalStateT, gets, modify)
import Data.Bifunctor (first)
-import Data.Foldable (find, toList)
+import Data.Foldable (find, foldl', toList)
import qualified Data.HashMap.Strict as HashMap
import Data.HashMap.Strict (HashMap)
import Data.HashSet (HashSet)
import qualified Data.HashSet as HashSet
import Data.List (groupBy, sortBy, sortOn)
-import Data.Maybe (isNothing, mapMaybe)
-import Data.List.NonEmpty (NonEmpty)
+import Data.Maybe (fromMaybe, isNothing, mapMaybe)
+import Data.List.NonEmpty (NonEmpty(..))
import Data.Ord (comparing)
import Data.Sequence (Seq(..), (|>))
import qualified Data.Sequence as Seq
@@ -80,6 +81,7 @@ specifiedRules =
-- Fields
, fieldsOnCorrectTypeRule
, scalarLeafsRule
+ , overlappingFieldsCanBeMergedRule
-- Arguments.
, knownArgumentNamesRule
, uniqueArgumentNamesRule
@@ -123,8 +125,8 @@ executableDefinitionsRule = DefinitionRule $ \case
singleFieldSubscriptionsRule :: forall m. Rule m
singleFieldSubscriptionsRule = OperationDefinitionRule $ \case
Full.OperationDefinition Full.Subscription name' _ _ rootFields location' -> do
- groupedFieldSet <- collectFields rootFields
- case length groupedFieldSet of
+ groupedFieldSet <- evalStateT (collectFields rootFields) HashSet.empty
+ case HashSet.size groupedFieldSet of
1 -> lift mempty
_
| Just name <- name' -> pure $ Error
@@ -143,26 +145,18 @@ singleFieldSubscriptionsRule = OperationDefinitionRule $ \case
where
errorMessage =
"Anonymous Subscription must select only one top level field."
-
-collectFields :: forall m
- . NonEmpty Full.Selection
- -> ReaderT (Validation m) Seq [[Full.Field]]
-collectFields selectionSet = evalStateT go HashSet.empty
- where
- go = groupSorted getFieldName <$> accumulateFields [] selectionSet
- getFieldName (Full.Field alias name _ _ _ _)
- | Just aliasedName <- alias = aliasedName
- | otherwise = name
- accumulateFields = foldM forEach
+ collectFields selectionSet = foldM forEach HashSet.empty selectionSet
forEach accumulator = \case
Full.FieldSelection fieldSelection -> forField accumulator fieldSelection
Full.FragmentSpreadSelection fragmentSelection ->
forSpread accumulator fragmentSelection
Full.InlineFragmentSelection fragmentSelection ->
forInline accumulator fragmentSelection
- forField accumulator field@(Full.Field _ _ _ directives' _ _)
+ forField accumulator (Full.Field alias name _ directives' _ _)
| any skip directives' = pure accumulator
- | otherwise = pure $ field : accumulator
+ | Just aliasedName <- alias = pure
+ $ HashSet.insert aliasedName accumulator
+ | otherwise = pure $ HashSet.insert name accumulator
forSpread accumulator (Full.FragmentSpread fragmentName directives' _)
| any skip directives' = pure accumulator
| otherwise = do
@@ -174,8 +168,14 @@ collectFields selectionSet = evalStateT go HashSet.empty
| any skip directives' = pure accumulator
| Just typeCondition <- maybeType =
collectFromFragment typeCondition selections accumulator
- | otherwise = accumulateFields accumulator selections
- collectFromFragment typeCondition selectionSet' accumulator = do
+ | otherwise = HashSet.union accumulator
+ <$> collectFields selections
+ skip (Full.Directive "skip" [Full.Argument "if" (Full.Node argumentValue _) _] _) =
+ Full.Boolean True == argumentValue
+ skip (Full.Directive "include" [Full.Argument "if" (Full.Node argumentValue _) _] _) =
+ Full.Boolean False == argumentValue
+ skip _ = False
+ collectFromFragment typeCondition selectionSet accumulator = do
types' <- lift $ asks $ Schema.types . schema
schema' <- lift $ asks schema
case Type.lookupTypeCondition typeCondition types' of
@@ -183,20 +183,15 @@ collectFields selectionSet = evalStateT go HashSet.empty
Just compositeType
| Just objectType <- Schema.subscription schema'
, True <- Type.doesFragmentTypeApply compositeType objectType ->
- accumulateFields accumulator selectionSet'
+ HashSet.union accumulator <$> collectFields selectionSet
| otherwise -> pure accumulator
collectFromSpread fragmentName accumulator = do
modify $ HashSet.insert fragmentName
ast' <- lift $ asks ast
case findFragmentDefinition fragmentName ast' of
Nothing -> pure accumulator
- Just (Full.FragmentDefinition _ typeCondition _ selectionSet' _) ->
- collectFromFragment typeCondition selectionSet' accumulator
- skip (Full.Directive "skip" [Full.Argument "if" (Full.Node argumentValue _) _] _) =
- Full.Boolean True == argumentValue
- skip (Full.Directive "include" [Full.Argument "if" (Full.Node argumentValue _) _] _) =
- Full.Boolean False == argumentValue
- skip _ = False
+ Just (Full.FragmentDefinition _ typeCondition _ selectionSet _) ->
+ collectFromFragment typeCondition selectionSet accumulator
-- | GraphQL allows a short‐hand form for defining query operations when only
-- that one operation exists in the document.
@@ -1026,3 +1021,187 @@ providedRequiredInputFieldsRule = ValueRule go constGo
, Text.unpack typeName
, "\" is required, but it was not provided."
]
+
+-- | If multiple field selections with the same response names are encountered
+-- during execution, the field and arguments to execute and the resulting value
+-- should be unambiguous. Therefore any two field selections which might both be
+-- encountered for the same object are only valid if they are equivalent.
+--
+-- For simple hand‐written GraphQL, this rule is obviously a clear developer
+-- error, however nested fragments can make this difficult to detect manually.
+overlappingFieldsCanBeMergedRule :: Rule m
+overlappingFieldsCanBeMergedRule = OperationDefinitionRule $ \case
+ Full.SelectionSet selectionSet _ -> do
+ schema' <- asks schema
+ go (toList selectionSet)
+ $ Type.CompositeObjectType
+ $ Schema.query schema'
+ Full.OperationDefinition operationType _ _ _ selectionSet _ -> do
+ schema' <- asks schema
+ let root = go (toList selectionSet) . Type.CompositeObjectType
+ case operationType of
+ Full.Query -> root $ Schema.query schema'
+ Full.Mutation
+ | Just objectType <- Schema.mutation schema' -> root objectType
+ Full.Subscription
+ | Just objectType <- Schema.mutation schema' -> root objectType
+ _ -> lift mempty
+ where
+ go selectionSet selectionType = do
+ fieldTuples <- evalStateT (collectFields selectionType selectionSet) HashSet.empty
+ fieldsInSetCanMerge fieldTuples
+ fieldsInSetCanMerge :: forall m
+ . HashMap Full.Name (NonEmpty (Full.Field, Type.CompositeType m))
+ -> ReaderT (Validation m) Seq Error
+ fieldsInSetCanMerge fieldTuples = do
+ validation <- ask
+ let (lonely, paired) = flattenPairs fieldTuples
+ let reader = flip runReaderT validation
+ lift $ foldMap (reader . visitLonelyFields) lonely
+ <> foldMap (reader . forEachFieldTuple) paired
+ forEachFieldTuple :: forall m
+ . (FieldInfo m, FieldInfo m)
+ -> ReaderT (Validation m) Seq Error
+ forEachFieldTuple (fieldA, fieldB) =
+ case (parent fieldA, parent fieldB) of
+ (parentA@Type.CompositeObjectType{}, parentB@Type.CompositeObjectType{})
+ | parentA /= parentB -> sameResponseShape fieldA fieldB
+ _ -> mapReaderT (checkEquality (node fieldA) (node fieldB))
+ $ sameResponseShape fieldA fieldB
+ checkEquality fieldA fieldB Seq.Empty
+ | Full.Field _ fieldNameA _ _ _ _ <- fieldA
+ , Full.Field _ fieldNameB _ _ _ _ <- fieldB
+ , fieldNameA /= fieldNameB = pure $ makeError fieldA fieldB
+ | Full.Field _ fieldNameA argumentsA _ _ locationA <- fieldA
+ , Full.Field _ _ argumentsB _ _ locationB <- fieldB
+ , argumentsA /= argumentsB =
+ let message = concat
+ [ "Fields \""
+ , Text.unpack fieldNameA
+ , "\" conflict because they have different arguments. Use "
+ , "different aliases on the fields to fetch both if this "
+ , "was intentional."
+ ]
+ in pure $ Error message [locationB, locationA]
+ checkEquality _ _ previousErrors = previousErrors
+ visitLonelyFields FieldInfo{..} =
+ let Full.Field _ _ _ _ subSelections _ = node
+ compositeFieldType = Type.outToComposite type'
+ in maybe (lift Seq.empty) (go subSelections) compositeFieldType
+ sameResponseShape :: forall m
+ . FieldInfo m
+ -> FieldInfo m
+ -> ReaderT (Validation m) Seq Error
+ sameResponseShape fieldA fieldB =
+ let Full.Field _ _ _ _ selectionsA _ = node fieldA
+ Full.Field _ _ _ _ selectionsB _ = node fieldB
+ in case unwrapTypes (type' fieldA) (type' fieldB) of
+ Left True -> lift mempty
+ Right (compositeA, compositeB) -> do
+ validation <- ask
+ let collectFields' composite = flip runReaderT validation
+ . flip evalStateT HashSet.empty
+ . collectFields composite
+ let collectA = collectFields' compositeA selectionsA
+ let collectB = collectFields' compositeB selectionsB
+ fieldsInSetCanMerge
+ $ foldl' (HashMap.unionWith (<>)) HashMap.empty
+ $ collectA <> collectB
+ _ -> pure $ makeError (node fieldA) (node fieldB)
+ makeError fieldA fieldB =
+ let Full.Field aliasA fieldNameA _ _ _ locationA = fieldA
+ Full.Field _ fieldNameB _ _ _ locationB = fieldB
+ message = concat
+ [ "Fields \""
+ , Text.unpack (fromMaybe fieldNameA aliasA)
+ , "\" conflict because \""
+ , Text.unpack fieldNameB
+ , "\" and \""
+ , Text.unpack fieldNameA
+ , "\" are different fields. Use different aliases on the fields "
+ , "to fetch both if this was intentional."
+ ]
+ in Error message [locationB, locationA]
+ unwrapTypes typeA@Out.ScalarBaseType{} typeB@Out.ScalarBaseType{} =
+ Left $ typeA == typeB
+ unwrapTypes typeA@Out.EnumBaseType{} typeB@Out.EnumBaseType{} =
+ Left $ typeA == typeB
+ unwrapTypes (Out.ListType listA) (Out.ListType listB) =
+ unwrapTypes listA listB
+ unwrapTypes (Out.NonNullListType listA) (Out.NonNullListType listB) =
+ unwrapTypes listA listB
+ unwrapTypes typeA typeB
+ | Out.isNonNullType typeA == Out.isNonNullType typeB
+ , Just compositeA <- Type.outToComposite typeA
+ , Just compositeB <- Type.outToComposite typeB =
+ Right (compositeA, compositeB)
+ | otherwise = Left False
+ flattenPairs :: forall m
+ . HashMap Full.Name (NonEmpty (Full.Field, Type.CompositeType m))
+ -> (Seq (FieldInfo m), Seq (FieldInfo m, FieldInfo m))
+ flattenPairs xs = HashMap.foldr splitSingleFields (Seq.empty, Seq.empty)
+ $ foldr lookupTypeField [] <$> xs
+ splitSingleFields :: forall m
+ . [FieldInfo m]
+ -> (Seq (FieldInfo m), Seq (FieldInfo m, FieldInfo m))
+ -> (Seq (FieldInfo m), Seq (FieldInfo m, FieldInfo m))
+ splitSingleFields [head'] (fields, pairList) = (fields |> head', pairList)
+ splitSingleFields xs (fields, pairList) = (fields, pairs pairList xs)
+ lookupTypeField (field, parentType) accumulator =
+ let Full.Field _ fieldName _ _ _ _ = field
+ in case Type.lookupCompositeField fieldName parentType of
+ Nothing -> accumulator
+ Just (Out.Field _ typeField _) ->
+ FieldInfo field typeField parentType : accumulator
+ pairs :: forall m
+ . Seq (FieldInfo m, FieldInfo m)
+ -> [FieldInfo m]
+ -> Seq (FieldInfo m, FieldInfo m)
+ pairs accumulator [] = accumulator
+ pairs accumulator (fieldA : fields) =
+ pair fieldA (pairs accumulator fields) fields
+ pair _ accumulator [] = accumulator
+ pair field accumulator (fieldA : fields) =
+ pair field accumulator fields |> (field, fieldA)
+ collectFields objectType = accumulateFields objectType mempty
+ accumulateFields = foldM . forEach
+ forEach parentType accumulator = \case
+ Full.FieldSelection fieldSelection ->
+ forField parentType accumulator fieldSelection
+ Full.FragmentSpreadSelection fragmentSelection ->
+ forSpread accumulator fragmentSelection
+ Full.InlineFragmentSelection fragmentSelection ->
+ forInline parentType accumulator fragmentSelection
+ forField parentType accumulator field@(Full.Field alias fieldName _ _ _ _) =
+ let key = fromMaybe fieldName alias
+ value = (field, parentType) :| []
+ in pure $ HashMap.insertWith (<>) key value accumulator
+ forSpread accumulator (Full.FragmentSpread fragmentName _ _) = do
+ inVisitetFragments <- gets $ HashSet.member fragmentName
+ if inVisitetFragments
+ then pure accumulator
+ else collectFromSpread fragmentName accumulator
+ forInline parentType accumulator = \case
+ Full.InlineFragment maybeType _ selections _
+ | Just typeCondition <- maybeType ->
+ collectFromFragment typeCondition selections accumulator
+ | otherwise -> accumulateFields parentType accumulator $ toList selections
+ collectFromFragment typeCondition selectionSet' accumulator = do
+ types' <- lift $ asks $ Schema.types . schema
+ case Type.lookupTypeCondition typeCondition types' of
+ Nothing -> pure accumulator
+ Just compositeType ->
+ accumulateFields compositeType accumulator $ toList selectionSet'
+ collectFromSpread fragmentName accumulator = do
+ modify $ HashSet.insert fragmentName
+ ast' <- lift $ asks ast
+ case findFragmentDefinition fragmentName ast' of
+ Nothing -> pure accumulator
+ Just (Full.FragmentDefinition _ typeCondition _ selectionSet' _) ->
+ collectFromFragment typeCondition selectionSet' accumulator
+
+data FieldInfo m = FieldInfo
+ { node :: Full.Field
+ , type' :: Out.Type m
+ , parent :: Type.CompositeType m
+ }
diff --git a/tests/Language/GraphQL/ValidateSpec.hs b/tests/Language/GraphQL/ValidateSpec.hs
index 318045c..97ca2e9 100644
--- a/tests/Language/GraphQL/ValidateSpec.hs
+++ b/tests/Language/GraphQL/ValidateSpec.hs
@@ -663,3 +663,81 @@ spec =
, locations = [AST.Location 2 15]
}
in validate queryString `shouldBe` [expected]
+
+ it "fails to merge fields of mismatching types" $
+ let queryString = [r|
+ {
+ dog {
+ name: nickname
+ name
+ }
+ }
+ |]
+ expected = Error
+ { message =
+ "Fields \"name\" conflict because \"nickname\" and \
+ \\"name\" are different fields. Use different aliases \
+ \on the fields to fetch both if this was intentional."
+ , locations = [AST.Location 4 19, AST.Location 5 19]
+ }
+ in validate queryString `shouldBe` [expected]
+
+ it "fails if the arguments of the same field don't match" $
+ let queryString = [r|
+ {
+ dog {
+ doesKnowCommand(dogCommand: SIT)
+ doesKnowCommand(dogCommand: HEEL)
+ }
+ }
+ |]
+ expected = Error
+ { message =
+ "Fields \"doesKnowCommand\" conflict because they have \
+ \different arguments. Use different aliases on the \
+ \fields to fetch both if this was intentional."
+ , locations = [AST.Location 4 19, AST.Location 5 19]
+ }
+ in validate queryString `shouldBe` [expected]
+
+ it "fails to merge same-named field and alias" $
+ let queryString = [r|
+ {
+ dog {
+ doesKnowCommand(dogCommand: SIT)
+ doesKnowCommand: isHousetrained(atOtherHomes: true)
+ }
+ }
+ |]
+ expected = Error
+ { message =
+ "Fields \"doesKnowCommand\" conflict because \
+ \\"doesKnowCommand\" and \"isHousetrained\" are \
+ \different fields. Use different aliases on the fields \
+ \to fetch both if this was intentional."
+ , locations = [AST.Location 4 19, AST.Location 5 19]
+ }
+ in validate queryString `shouldBe` [expected]
+
+ it "looks for fields after a successfully merged field pair" $
+ let queryString = [r|
+ {
+ dog {
+ name
+ doesKnowCommand(dogCommand: SIT)
+ }
+ dog {
+ name
+ doesKnowCommand: isHousetrained(atOtherHomes: true)
+ }
+ }
+ |]
+ expected = Error
+ { message =
+ "Fields \"doesKnowCommand\" conflict because \
+ \\"doesKnowCommand\" and \"isHousetrained\" are \
+ \different fields. Use different aliases on the fields \
+ \to fetch both if this was intentional."
+ , locations = [AST.Location 5 19, AST.Location 9 19]
+ }
+ in validate queryString `shouldBe` [expected]