From 1f4eb6fb9bf847401b158d83516bd07650353f25 Mon Sep 17 00:00:00 2001 From: Eugen Wissner Date: Sun, 15 Nov 2020 10:11:09 +0100 Subject: [PATCH] Implement basic "Field Selection Merging" rule --- CHANGELOG.md | 3 + src/Language/GraphQL/Type/Internal.hs | 27 ++- src/Language/GraphQL/Validate/Rules.hs | 237 ++++++++++++++++++++++--- tests/Language/GraphQL/ValidateSpec.hs | 78 ++++++++ 4 files changed, 310 insertions(+), 35 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index bfbd79f..9854006 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,9 @@ and this project adheres to [Haskell Package Versioning Policy](https://pvp.haskell.org/). ## [Unreleased] +### Added +- `Validate.Rules`: + - `overlappingFieldsCanBeMergedRule` ## [0.11.0.0] - 2020-11-07 ### Changed diff --git a/src/Language/GraphQL/Type/Internal.hs b/src/Language/GraphQL/Type/Internal.hs index eb8489c..3861365 100644 --- a/src/Language/GraphQL/Type/Internal.hs +++ b/src/Language/GraphQL/Type/Internal.hs @@ -15,10 +15,12 @@ module Language.GraphQL.Type.Internal , directives , doesFragmentTypeApply , instanceOf + , lookupCompositeField , lookupInputType , lookupTypeCondition , lookupTypeField , mutation + , outToComposite , subscription , query , types @@ -160,12 +162,16 @@ lookupInputType (Full.TypeNonNull (Full.NonNullTypeList nonNull)) types' <$> lookupInputType nonNull types' lookupTypeField :: forall a. Full.Name -> Out.Type a -> Maybe (Out.Field a) -lookupTypeField fieldName = \case - Out.ObjectBaseType objectType -> - objectChild objectType - Out.InterfaceBaseType interfaceType -> - interfaceChild interfaceType - Out.ListBaseType listType -> lookupTypeField fieldName listType +lookupTypeField fieldName outputType = + outToComposite outputType >>= lookupCompositeField fieldName + +lookupCompositeField :: forall a + . Full.Name + -> CompositeType a + -> Maybe (Out.Field a) +lookupCompositeField fieldName = \case + CompositeObjectType objectType -> objectChild objectType + CompositeInterfaceType interfaceType -> interfaceChild interfaceType _ -> Nothing where objectChild (Out.ObjectType _ _ _ resolvers) = @@ -174,3 +180,12 @@ lookupTypeField fieldName = \case HashMap.lookup fieldName fields resolverType (Out.ValueResolver objectField _) = objectField resolverType (Out.EventStreamResolver objectField _ _) = objectField + +outToComposite :: forall a. Out.Type a -> Maybe (CompositeType a) +outToComposite = \case + Out.ObjectBaseType objectType -> Just $ CompositeObjectType objectType + Out.InterfaceBaseType interfaceType -> + Just $ CompositeInterfaceType interfaceType + Out.UnionBaseType unionType -> Just $ CompositeUnionType unionType + Out.ListBaseType listType -> outToComposite listType + _ -> Nothing diff --git a/src/Language/GraphQL/Validate/Rules.hs b/src/Language/GraphQL/Validate/Rules.hs index 740131a..d17c1bc 100644 --- a/src/Language/GraphQL/Validate/Rules.hs +++ b/src/Language/GraphQL/Validate/Rules.hs @@ -24,6 +24,7 @@ module Language.GraphQL.Validate.Rules , noUndefinedVariablesRule , noUnusedFragmentsRule , noUnusedVariablesRule + , overlappingFieldsCanBeMergedRule , providedRequiredInputFieldsRule , providedRequiredArgumentsRule , scalarLeafsRule @@ -40,17 +41,17 @@ module Language.GraphQL.Validate.Rules import Control.Monad ((>=>), foldM) import Control.Monad.Trans.Class (MonadTrans(..)) -import Control.Monad.Trans.Reader (ReaderT(..), asks, mapReaderT) +import Control.Monad.Trans.Reader (ReaderT(..), ask, asks, mapReaderT) import Control.Monad.Trans.State (StateT, evalStateT, gets, modify) import Data.Bifunctor (first) -import Data.Foldable (find, toList) +import Data.Foldable (find, foldl', toList) import qualified Data.HashMap.Strict as HashMap import Data.HashMap.Strict (HashMap) import Data.HashSet (HashSet) import qualified Data.HashSet as HashSet import Data.List (groupBy, sortBy, sortOn) -import Data.Maybe (isNothing, mapMaybe) -import Data.List.NonEmpty (NonEmpty) +import Data.Maybe (fromMaybe, isNothing, mapMaybe) +import Data.List.NonEmpty (NonEmpty(..)) import Data.Ord (comparing) import Data.Sequence (Seq(..), (|>)) import qualified Data.Sequence as Seq @@ -80,6 +81,7 @@ specifiedRules = -- Fields , fieldsOnCorrectTypeRule , scalarLeafsRule + , overlappingFieldsCanBeMergedRule -- Arguments. , knownArgumentNamesRule , uniqueArgumentNamesRule @@ -123,8 +125,8 @@ executableDefinitionsRule = DefinitionRule $ \case singleFieldSubscriptionsRule :: forall m. Rule m singleFieldSubscriptionsRule = OperationDefinitionRule $ \case Full.OperationDefinition Full.Subscription name' _ _ rootFields location' -> do - groupedFieldSet <- collectFields rootFields - case length groupedFieldSet of + groupedFieldSet <- evalStateT (collectFields rootFields) HashSet.empty + case HashSet.size groupedFieldSet of 1 -> lift mempty _ | Just name <- name' -> pure $ Error @@ -143,26 +145,18 @@ singleFieldSubscriptionsRule = OperationDefinitionRule $ \case where errorMessage = "Anonymous Subscription must select only one top level field." - -collectFields :: forall m - . NonEmpty Full.Selection - -> ReaderT (Validation m) Seq [[Full.Field]] -collectFields selectionSet = evalStateT go HashSet.empty - where - go = groupSorted getFieldName <$> accumulateFields [] selectionSet - getFieldName (Full.Field alias name _ _ _ _) - | Just aliasedName <- alias = aliasedName - | otherwise = name - accumulateFields = foldM forEach + collectFields selectionSet = foldM forEach HashSet.empty selectionSet forEach accumulator = \case Full.FieldSelection fieldSelection -> forField accumulator fieldSelection Full.FragmentSpreadSelection fragmentSelection -> forSpread accumulator fragmentSelection Full.InlineFragmentSelection fragmentSelection -> forInline accumulator fragmentSelection - forField accumulator field@(Full.Field _ _ _ directives' _ _) + forField accumulator (Full.Field alias name _ directives' _ _) | any skip directives' = pure accumulator - | otherwise = pure $ field : accumulator + | Just aliasedName <- alias = pure + $ HashSet.insert aliasedName accumulator + | otherwise = pure $ HashSet.insert name accumulator forSpread accumulator (Full.FragmentSpread fragmentName directives' _) | any skip directives' = pure accumulator | otherwise = do @@ -174,8 +168,14 @@ collectFields selectionSet = evalStateT go HashSet.empty | any skip directives' = pure accumulator | Just typeCondition <- maybeType = collectFromFragment typeCondition selections accumulator - | otherwise = accumulateFields accumulator selections - collectFromFragment typeCondition selectionSet' accumulator = do + | otherwise = HashSet.union accumulator + <$> collectFields selections + skip (Full.Directive "skip" [Full.Argument "if" (Full.Node argumentValue _) _] _) = + Full.Boolean True == argumentValue + skip (Full.Directive "include" [Full.Argument "if" (Full.Node argumentValue _) _] _) = + Full.Boolean False == argumentValue + skip _ = False + collectFromFragment typeCondition selectionSet accumulator = do types' <- lift $ asks $ Schema.types . schema schema' <- lift $ asks schema case Type.lookupTypeCondition typeCondition types' of @@ -183,20 +183,15 @@ collectFields selectionSet = evalStateT go HashSet.empty Just compositeType | Just objectType <- Schema.subscription schema' , True <- Type.doesFragmentTypeApply compositeType objectType -> - accumulateFields accumulator selectionSet' + HashSet.union accumulator <$> collectFields selectionSet | otherwise -> pure accumulator collectFromSpread fragmentName accumulator = do modify $ HashSet.insert fragmentName ast' <- lift $ asks ast case findFragmentDefinition fragmentName ast' of Nothing -> pure accumulator - Just (Full.FragmentDefinition _ typeCondition _ selectionSet' _) -> - collectFromFragment typeCondition selectionSet' accumulator - skip (Full.Directive "skip" [Full.Argument "if" (Full.Node argumentValue _) _] _) = - Full.Boolean True == argumentValue - skip (Full.Directive "include" [Full.Argument "if" (Full.Node argumentValue _) _] _) = - Full.Boolean False == argumentValue - skip _ = False + Just (Full.FragmentDefinition _ typeCondition _ selectionSet _) -> + collectFromFragment typeCondition selectionSet accumulator -- | GraphQL allows a short‐hand form for defining query operations when only -- that one operation exists in the document. @@ -1026,3 +1021,187 @@ providedRequiredInputFieldsRule = ValueRule go constGo , Text.unpack typeName , "\" is required, but it was not provided." ] + +-- | If multiple field selections with the same response names are encountered +-- during execution, the field and arguments to execute and the resulting value +-- should be unambiguous. Therefore any two field selections which might both be +-- encountered for the same object are only valid if they are equivalent. +-- +-- For simple hand‐written GraphQL, this rule is obviously a clear developer +-- error, however nested fragments can make this difficult to detect manually. +overlappingFieldsCanBeMergedRule :: Rule m +overlappingFieldsCanBeMergedRule = OperationDefinitionRule $ \case + Full.SelectionSet selectionSet _ -> do + schema' <- asks schema + go (toList selectionSet) + $ Type.CompositeObjectType + $ Schema.query schema' + Full.OperationDefinition operationType _ _ _ selectionSet _ -> do + schema' <- asks schema + let root = go (toList selectionSet) . Type.CompositeObjectType + case operationType of + Full.Query -> root $ Schema.query schema' + Full.Mutation + | Just objectType <- Schema.mutation schema' -> root objectType + Full.Subscription + | Just objectType <- Schema.mutation schema' -> root objectType + _ -> lift mempty + where + go selectionSet selectionType = do + fieldTuples <- evalStateT (collectFields selectionType selectionSet) HashSet.empty + fieldsInSetCanMerge fieldTuples + fieldsInSetCanMerge :: forall m + . HashMap Full.Name (NonEmpty (Full.Field, Type.CompositeType m)) + -> ReaderT (Validation m) Seq Error + fieldsInSetCanMerge fieldTuples = do + validation <- ask + let (lonely, paired) = flattenPairs fieldTuples + let reader = flip runReaderT validation + lift $ foldMap (reader . visitLonelyFields) lonely + <> foldMap (reader . forEachFieldTuple) paired + forEachFieldTuple :: forall m + . (FieldInfo m, FieldInfo m) + -> ReaderT (Validation m) Seq Error + forEachFieldTuple (fieldA, fieldB) = + case (parent fieldA, parent fieldB) of + (parentA@Type.CompositeObjectType{}, parentB@Type.CompositeObjectType{}) + | parentA /= parentB -> sameResponseShape fieldA fieldB + _ -> mapReaderT (checkEquality (node fieldA) (node fieldB)) + $ sameResponseShape fieldA fieldB + checkEquality fieldA fieldB Seq.Empty + | Full.Field _ fieldNameA _ _ _ _ <- fieldA + , Full.Field _ fieldNameB _ _ _ _ <- fieldB + , fieldNameA /= fieldNameB = pure $ makeError fieldA fieldB + | Full.Field _ fieldNameA argumentsA _ _ locationA <- fieldA + , Full.Field _ _ argumentsB _ _ locationB <- fieldB + , argumentsA /= argumentsB = + let message = concat + [ "Fields \"" + , Text.unpack fieldNameA + , "\" conflict because they have different arguments. Use " + , "different aliases on the fields to fetch both if this " + , "was intentional." + ] + in pure $ Error message [locationB, locationA] + checkEquality _ _ previousErrors = previousErrors + visitLonelyFields FieldInfo{..} = + let Full.Field _ _ _ _ subSelections _ = node + compositeFieldType = Type.outToComposite type' + in maybe (lift Seq.empty) (go subSelections) compositeFieldType + sameResponseShape :: forall m + . FieldInfo m + -> FieldInfo m + -> ReaderT (Validation m) Seq Error + sameResponseShape fieldA fieldB = + let Full.Field _ _ _ _ selectionsA _ = node fieldA + Full.Field _ _ _ _ selectionsB _ = node fieldB + in case unwrapTypes (type' fieldA) (type' fieldB) of + Left True -> lift mempty + Right (compositeA, compositeB) -> do + validation <- ask + let collectFields' composite = flip runReaderT validation + . flip evalStateT HashSet.empty + . collectFields composite + let collectA = collectFields' compositeA selectionsA + let collectB = collectFields' compositeB selectionsB + fieldsInSetCanMerge + $ foldl' (HashMap.unionWith (<>)) HashMap.empty + $ collectA <> collectB + _ -> pure $ makeError (node fieldA) (node fieldB) + makeError fieldA fieldB = + let Full.Field aliasA fieldNameA _ _ _ locationA = fieldA + Full.Field _ fieldNameB _ _ _ locationB = fieldB + message = concat + [ "Fields \"" + , Text.unpack (fromMaybe fieldNameA aliasA) + , "\" conflict because \"" + , Text.unpack fieldNameB + , "\" and \"" + , Text.unpack fieldNameA + , "\" are different fields. Use different aliases on the fields " + , "to fetch both if this was intentional." + ] + in Error message [locationB, locationA] + unwrapTypes typeA@Out.ScalarBaseType{} typeB@Out.ScalarBaseType{} = + Left $ typeA == typeB + unwrapTypes typeA@Out.EnumBaseType{} typeB@Out.EnumBaseType{} = + Left $ typeA == typeB + unwrapTypes (Out.ListType listA) (Out.ListType listB) = + unwrapTypes listA listB + unwrapTypes (Out.NonNullListType listA) (Out.NonNullListType listB) = + unwrapTypes listA listB + unwrapTypes typeA typeB + | Out.isNonNullType typeA == Out.isNonNullType typeB + , Just compositeA <- Type.outToComposite typeA + , Just compositeB <- Type.outToComposite typeB = + Right (compositeA, compositeB) + | otherwise = Left False + flattenPairs :: forall m + . HashMap Full.Name (NonEmpty (Full.Field, Type.CompositeType m)) + -> (Seq (FieldInfo m), Seq (FieldInfo m, FieldInfo m)) + flattenPairs xs = HashMap.foldr splitSingleFields (Seq.empty, Seq.empty) + $ foldr lookupTypeField [] <$> xs + splitSingleFields :: forall m + . [FieldInfo m] + -> (Seq (FieldInfo m), Seq (FieldInfo m, FieldInfo m)) + -> (Seq (FieldInfo m), Seq (FieldInfo m, FieldInfo m)) + splitSingleFields [head'] (fields, pairList) = (fields |> head', pairList) + splitSingleFields xs (fields, pairList) = (fields, pairs pairList xs) + lookupTypeField (field, parentType) accumulator = + let Full.Field _ fieldName _ _ _ _ = field + in case Type.lookupCompositeField fieldName parentType of + Nothing -> accumulator + Just (Out.Field _ typeField _) -> + FieldInfo field typeField parentType : accumulator + pairs :: forall m + . Seq (FieldInfo m, FieldInfo m) + -> [FieldInfo m] + -> Seq (FieldInfo m, FieldInfo m) + pairs accumulator [] = accumulator + pairs accumulator (fieldA : fields) = + pair fieldA (pairs accumulator fields) fields + pair _ accumulator [] = accumulator + pair field accumulator (fieldA : fields) = + pair field accumulator fields |> (field, fieldA) + collectFields objectType = accumulateFields objectType mempty + accumulateFields = foldM . forEach + forEach parentType accumulator = \case + Full.FieldSelection fieldSelection -> + forField parentType accumulator fieldSelection + Full.FragmentSpreadSelection fragmentSelection -> + forSpread accumulator fragmentSelection + Full.InlineFragmentSelection fragmentSelection -> + forInline parentType accumulator fragmentSelection + forField parentType accumulator field@(Full.Field alias fieldName _ _ _ _) = + let key = fromMaybe fieldName alias + value = (field, parentType) :| [] + in pure $ HashMap.insertWith (<>) key value accumulator + forSpread accumulator (Full.FragmentSpread fragmentName _ _) = do + inVisitetFragments <- gets $ HashSet.member fragmentName + if inVisitetFragments + then pure accumulator + else collectFromSpread fragmentName accumulator + forInline parentType accumulator = \case + Full.InlineFragment maybeType _ selections _ + | Just typeCondition <- maybeType -> + collectFromFragment typeCondition selections accumulator + | otherwise -> accumulateFields parentType accumulator $ toList selections + collectFromFragment typeCondition selectionSet' accumulator = do + types' <- lift $ asks $ Schema.types . schema + case Type.lookupTypeCondition typeCondition types' of + Nothing -> pure accumulator + Just compositeType -> + accumulateFields compositeType accumulator $ toList selectionSet' + collectFromSpread fragmentName accumulator = do + modify $ HashSet.insert fragmentName + ast' <- lift $ asks ast + case findFragmentDefinition fragmentName ast' of + Nothing -> pure accumulator + Just (Full.FragmentDefinition _ typeCondition _ selectionSet' _) -> + collectFromFragment typeCondition selectionSet' accumulator + +data FieldInfo m = FieldInfo + { node :: Full.Field + , type' :: Out.Type m + , parent :: Type.CompositeType m + } diff --git a/tests/Language/GraphQL/ValidateSpec.hs b/tests/Language/GraphQL/ValidateSpec.hs index 318045c..97ca2e9 100644 --- a/tests/Language/GraphQL/ValidateSpec.hs +++ b/tests/Language/GraphQL/ValidateSpec.hs @@ -663,3 +663,81 @@ spec = , locations = [AST.Location 2 15] } in validate queryString `shouldBe` [expected] + + it "fails to merge fields of mismatching types" $ + let queryString = [r| + { + dog { + name: nickname + name + } + } + |] + expected = Error + { message = + "Fields \"name\" conflict because \"nickname\" and \ + \\"name\" are different fields. Use different aliases \ + \on the fields to fetch both if this was intentional." + , locations = [AST.Location 4 19, AST.Location 5 19] + } + in validate queryString `shouldBe` [expected] + + it "fails if the arguments of the same field don't match" $ + let queryString = [r| + { + dog { + doesKnowCommand(dogCommand: SIT) + doesKnowCommand(dogCommand: HEEL) + } + } + |] + expected = Error + { message = + "Fields \"doesKnowCommand\" conflict because they have \ + \different arguments. Use different aliases on the \ + \fields to fetch both if this was intentional." + , locations = [AST.Location 4 19, AST.Location 5 19] + } + in validate queryString `shouldBe` [expected] + + it "fails to merge same-named field and alias" $ + let queryString = [r| + { + dog { + doesKnowCommand(dogCommand: SIT) + doesKnowCommand: isHousetrained(atOtherHomes: true) + } + } + |] + expected = Error + { message = + "Fields \"doesKnowCommand\" conflict because \ + \\"doesKnowCommand\" and \"isHousetrained\" are \ + \different fields. Use different aliases on the fields \ + \to fetch both if this was intentional." + , locations = [AST.Location 4 19, AST.Location 5 19] + } + in validate queryString `shouldBe` [expected] + + it "looks for fields after a successfully merged field pair" $ + let queryString = [r| + { + dog { + name + doesKnowCommand(dogCommand: SIT) + } + dog { + name + doesKnowCommand: isHousetrained(atOtherHomes: true) + } + } + |] + expected = Error + { message = + "Fields \"doesKnowCommand\" conflict because \ + \\"doesKnowCommand\" and \"isHousetrained\" are \ + \different fields. Use different aliases on the fields \ + \to fetch both if this was intentional." + , locations = [AST.Location 5 19, AST.Location 9 19] + } + in validate queryString `shouldBe` [expected]