diff options
| author | Eugen Wissner <belka@caraus.de> | 2020-11-15 10:11:09 +0100 |
|---|---|---|
| committer | Eugen Wissner <belka@caraus.de> | 2020-11-15 10:11:09 +0100 |
| commit | 1f4eb6fb9bf847401b158d83516bd07650353f25 (patch) | |
| tree | 9b6cb45a501a4df4815b925f3ba05e11e4eb6d55 /src/Language/GraphQL/Validate | |
| parent | f5209481aa28fdb5dcd92589839dab0f1cb8d1b9 (diff) | |
| download | graphql-1f4eb6fb9bf847401b158d83516bd07650353f25.tar.gz | |
Implement basic "Field Selection Merging" rule
Diffstat (limited to 'src/Language/GraphQL/Validate')
| -rw-r--r-- | src/Language/GraphQL/Validate/Rules.hs | 237 |
1 files changed, 208 insertions, 29 deletions
diff --git a/src/Language/GraphQL/Validate/Rules.hs b/src/Language/GraphQL/Validate/Rules.hs index 740131a..d17c1bc 100644 --- a/src/Language/GraphQL/Validate/Rules.hs +++ b/src/Language/GraphQL/Validate/Rules.hs @@ -24,6 +24,7 @@ module Language.GraphQL.Validate.Rules , noUndefinedVariablesRule , noUnusedFragmentsRule , noUnusedVariablesRule + , overlappingFieldsCanBeMergedRule , providedRequiredInputFieldsRule , providedRequiredArgumentsRule , scalarLeafsRule @@ -40,17 +41,17 @@ module Language.GraphQL.Validate.Rules import Control.Monad ((>=>), foldM) import Control.Monad.Trans.Class (MonadTrans(..)) -import Control.Monad.Trans.Reader (ReaderT(..), asks, mapReaderT) +import Control.Monad.Trans.Reader (ReaderT(..), ask, asks, mapReaderT) import Control.Monad.Trans.State (StateT, evalStateT, gets, modify) import Data.Bifunctor (first) -import Data.Foldable (find, toList) +import Data.Foldable (find, foldl', toList) import qualified Data.HashMap.Strict as HashMap import Data.HashMap.Strict (HashMap) import Data.HashSet (HashSet) import qualified Data.HashSet as HashSet import Data.List (groupBy, sortBy, sortOn) -import Data.Maybe (isNothing, mapMaybe) -import Data.List.NonEmpty (NonEmpty) +import Data.Maybe (fromMaybe, isNothing, mapMaybe) +import Data.List.NonEmpty (NonEmpty(..)) import Data.Ord (comparing) import Data.Sequence (Seq(..), (|>)) import qualified Data.Sequence as Seq @@ -80,6 +81,7 @@ specifiedRules = -- Fields , fieldsOnCorrectTypeRule , scalarLeafsRule + , overlappingFieldsCanBeMergedRule -- Arguments. , knownArgumentNamesRule , uniqueArgumentNamesRule @@ -123,8 +125,8 @@ executableDefinitionsRule = DefinitionRule $ \case singleFieldSubscriptionsRule :: forall m. Rule m singleFieldSubscriptionsRule = OperationDefinitionRule $ \case Full.OperationDefinition Full.Subscription name' _ _ rootFields location' -> do - groupedFieldSet <- collectFields rootFields - case length groupedFieldSet of + groupedFieldSet <- evalStateT (collectFields rootFields) HashSet.empty + case HashSet.size groupedFieldSet of 1 -> lift mempty _ | Just name <- name' -> pure $ Error @@ -143,26 +145,18 @@ singleFieldSubscriptionsRule = OperationDefinitionRule $ \case where errorMessage = "Anonymous Subscription must select only one top level field." - -collectFields :: forall m - . NonEmpty Full.Selection - -> ReaderT (Validation m) Seq [[Full.Field]] -collectFields selectionSet = evalStateT go HashSet.empty - where - go = groupSorted getFieldName <$> accumulateFields [] selectionSet - getFieldName (Full.Field alias name _ _ _ _) - | Just aliasedName <- alias = aliasedName - | otherwise = name - accumulateFields = foldM forEach + collectFields selectionSet = foldM forEach HashSet.empty selectionSet forEach accumulator = \case Full.FieldSelection fieldSelection -> forField accumulator fieldSelection Full.FragmentSpreadSelection fragmentSelection -> forSpread accumulator fragmentSelection Full.InlineFragmentSelection fragmentSelection -> forInline accumulator fragmentSelection - forField accumulator field@(Full.Field _ _ _ directives' _ _) + forField accumulator (Full.Field alias name _ directives' _ _) | any skip directives' = pure accumulator - | otherwise = pure $ field : accumulator + | Just aliasedName <- alias = pure + $ HashSet.insert aliasedName accumulator + | otherwise = pure $ HashSet.insert name accumulator forSpread accumulator (Full.FragmentSpread fragmentName directives' _) | any skip directives' = pure accumulator | otherwise = do @@ -174,8 +168,14 @@ collectFields selectionSet = evalStateT go HashSet.empty | any skip directives' = pure accumulator | Just typeCondition <- maybeType = collectFromFragment typeCondition selections accumulator - | otherwise = accumulateFields accumulator selections - collectFromFragment typeCondition selectionSet' accumulator = do + | otherwise = HashSet.union accumulator + <$> collectFields selections + skip (Full.Directive "skip" [Full.Argument "if" (Full.Node argumentValue _) _] _) = + Full.Boolean True == argumentValue + skip (Full.Directive "include" [Full.Argument "if" (Full.Node argumentValue _) _] _) = + Full.Boolean False == argumentValue + skip _ = False + collectFromFragment typeCondition selectionSet accumulator = do types' <- lift $ asks $ Schema.types . schema schema' <- lift $ asks schema case Type.lookupTypeCondition typeCondition types' of @@ -183,20 +183,15 @@ collectFields selectionSet = evalStateT go HashSet.empty Just compositeType | Just objectType <- Schema.subscription schema' , True <- Type.doesFragmentTypeApply compositeType objectType -> - accumulateFields accumulator selectionSet' + HashSet.union accumulator <$> collectFields selectionSet | otherwise -> pure accumulator collectFromSpread fragmentName accumulator = do modify $ HashSet.insert fragmentName ast' <- lift $ asks ast case findFragmentDefinition fragmentName ast' of Nothing -> pure accumulator - Just (Full.FragmentDefinition _ typeCondition _ selectionSet' _) -> - collectFromFragment typeCondition selectionSet' accumulator - skip (Full.Directive "skip" [Full.Argument "if" (Full.Node argumentValue _) _] _) = - Full.Boolean True == argumentValue - skip (Full.Directive "include" [Full.Argument "if" (Full.Node argumentValue _) _] _) = - Full.Boolean False == argumentValue - skip _ = False + Just (Full.FragmentDefinition _ typeCondition _ selectionSet _) -> + collectFromFragment typeCondition selectionSet accumulator -- | GraphQL allows a short‐hand form for defining query operations when only -- that one operation exists in the document. @@ -1026,3 +1021,187 @@ providedRequiredInputFieldsRule = ValueRule go constGo , Text.unpack typeName , "\" is required, but it was not provided." ] + +-- | If multiple field selections with the same response names are encountered +-- during execution, the field and arguments to execute and the resulting value +-- should be unambiguous. Therefore any two field selections which might both be +-- encountered for the same object are only valid if they are equivalent. +-- +-- For simple hand‐written GraphQL, this rule is obviously a clear developer +-- error, however nested fragments can make this difficult to detect manually. +overlappingFieldsCanBeMergedRule :: Rule m +overlappingFieldsCanBeMergedRule = OperationDefinitionRule $ \case + Full.SelectionSet selectionSet _ -> do + schema' <- asks schema + go (toList selectionSet) + $ Type.CompositeObjectType + $ Schema.query schema' + Full.OperationDefinition operationType _ _ _ selectionSet _ -> do + schema' <- asks schema + let root = go (toList selectionSet) . Type.CompositeObjectType + case operationType of + Full.Query -> root $ Schema.query schema' + Full.Mutation + | Just objectType <- Schema.mutation schema' -> root objectType + Full.Subscription + | Just objectType <- Schema.mutation schema' -> root objectType + _ -> lift mempty + where + go selectionSet selectionType = do + fieldTuples <- evalStateT (collectFields selectionType selectionSet) HashSet.empty + fieldsInSetCanMerge fieldTuples + fieldsInSetCanMerge :: forall m + . HashMap Full.Name (NonEmpty (Full.Field, Type.CompositeType m)) + -> ReaderT (Validation m) Seq Error + fieldsInSetCanMerge fieldTuples = do + validation <- ask + let (lonely, paired) = flattenPairs fieldTuples + let reader = flip runReaderT validation + lift $ foldMap (reader . visitLonelyFields) lonely + <> foldMap (reader . forEachFieldTuple) paired + forEachFieldTuple :: forall m + . (FieldInfo m, FieldInfo m) + -> ReaderT (Validation m) Seq Error + forEachFieldTuple (fieldA, fieldB) = + case (parent fieldA, parent fieldB) of + (parentA@Type.CompositeObjectType{}, parentB@Type.CompositeObjectType{}) + | parentA /= parentB -> sameResponseShape fieldA fieldB + _ -> mapReaderT (checkEquality (node fieldA) (node fieldB)) + $ sameResponseShape fieldA fieldB + checkEquality fieldA fieldB Seq.Empty + | Full.Field _ fieldNameA _ _ _ _ <- fieldA + , Full.Field _ fieldNameB _ _ _ _ <- fieldB + , fieldNameA /= fieldNameB = pure $ makeError fieldA fieldB + | Full.Field _ fieldNameA argumentsA _ _ locationA <- fieldA + , Full.Field _ _ argumentsB _ _ locationB <- fieldB + , argumentsA /= argumentsB = + let message = concat + [ "Fields \"" + , Text.unpack fieldNameA + , "\" conflict because they have different arguments. Use " + , "different aliases on the fields to fetch both if this " + , "was intentional." + ] + in pure $ Error message [locationB, locationA] + checkEquality _ _ previousErrors = previousErrors + visitLonelyFields FieldInfo{..} = + let Full.Field _ _ _ _ subSelections _ = node + compositeFieldType = Type.outToComposite type' + in maybe (lift Seq.empty) (go subSelections) compositeFieldType + sameResponseShape :: forall m + . FieldInfo m + -> FieldInfo m + -> ReaderT (Validation m) Seq Error + sameResponseShape fieldA fieldB = + let Full.Field _ _ _ _ selectionsA _ = node fieldA + Full.Field _ _ _ _ selectionsB _ = node fieldB + in case unwrapTypes (type' fieldA) (type' fieldB) of + Left True -> lift mempty + Right (compositeA, compositeB) -> do + validation <- ask + let collectFields' composite = flip runReaderT validation + . flip evalStateT HashSet.empty + . collectFields composite + let collectA = collectFields' compositeA selectionsA + let collectB = collectFields' compositeB selectionsB + fieldsInSetCanMerge + $ foldl' (HashMap.unionWith (<>)) HashMap.empty + $ collectA <> collectB + _ -> pure $ makeError (node fieldA) (node fieldB) + makeError fieldA fieldB = + let Full.Field aliasA fieldNameA _ _ _ locationA = fieldA + Full.Field _ fieldNameB _ _ _ locationB = fieldB + message = concat + [ "Fields \"" + , Text.unpack (fromMaybe fieldNameA aliasA) + , "\" conflict because \"" + , Text.unpack fieldNameB + , "\" and \"" + , Text.unpack fieldNameA + , "\" are different fields. Use different aliases on the fields " + , "to fetch both if this was intentional." + ] + in Error message [locationB, locationA] + unwrapTypes typeA@Out.ScalarBaseType{} typeB@Out.ScalarBaseType{} = + Left $ typeA == typeB + unwrapTypes typeA@Out.EnumBaseType{} typeB@Out.EnumBaseType{} = + Left $ typeA == typeB + unwrapTypes (Out.ListType listA) (Out.ListType listB) = + unwrapTypes listA listB + unwrapTypes (Out.NonNullListType listA) (Out.NonNullListType listB) = + unwrapTypes listA listB + unwrapTypes typeA typeB + | Out.isNonNullType typeA == Out.isNonNullType typeB + , Just compositeA <- Type.outToComposite typeA + , Just compositeB <- Type.outToComposite typeB = + Right (compositeA, compositeB) + | otherwise = Left False + flattenPairs :: forall m + . HashMap Full.Name (NonEmpty (Full.Field, Type.CompositeType m)) + -> (Seq (FieldInfo m), Seq (FieldInfo m, FieldInfo m)) + flattenPairs xs = HashMap.foldr splitSingleFields (Seq.empty, Seq.empty) + $ foldr lookupTypeField [] <$> xs + splitSingleFields :: forall m + . [FieldInfo m] + -> (Seq (FieldInfo m), Seq (FieldInfo m, FieldInfo m)) + -> (Seq (FieldInfo m), Seq (FieldInfo m, FieldInfo m)) + splitSingleFields [head'] (fields, pairList) = (fields |> head', pairList) + splitSingleFields xs (fields, pairList) = (fields, pairs pairList xs) + lookupTypeField (field, parentType) accumulator = + let Full.Field _ fieldName _ _ _ _ = field + in case Type.lookupCompositeField fieldName parentType of + Nothing -> accumulator + Just (Out.Field _ typeField _) -> + FieldInfo field typeField parentType : accumulator + pairs :: forall m + . Seq (FieldInfo m, FieldInfo m) + -> [FieldInfo m] + -> Seq (FieldInfo m, FieldInfo m) + pairs accumulator [] = accumulator + pairs accumulator (fieldA : fields) = + pair fieldA (pairs accumulator fields) fields + pair _ accumulator [] = accumulator + pair field accumulator (fieldA : fields) = + pair field accumulator fields |> (field, fieldA) + collectFields objectType = accumulateFields objectType mempty + accumulateFields = foldM . forEach + forEach parentType accumulator = \case + Full.FieldSelection fieldSelection -> + forField parentType accumulator fieldSelection + Full.FragmentSpreadSelection fragmentSelection -> + forSpread accumulator fragmentSelection + Full.InlineFragmentSelection fragmentSelection -> + forInline parentType accumulator fragmentSelection + forField parentType accumulator field@(Full.Field alias fieldName _ _ _ _) = + let key = fromMaybe fieldName alias + value = (field, parentType) :| [] + in pure $ HashMap.insertWith (<>) key value accumulator + forSpread accumulator (Full.FragmentSpread fragmentName _ _) = do + inVisitetFragments <- gets $ HashSet.member fragmentName + if inVisitetFragments + then pure accumulator + else collectFromSpread fragmentName accumulator + forInline parentType accumulator = \case + Full.InlineFragment maybeType _ selections _ + | Just typeCondition <- maybeType -> + collectFromFragment typeCondition selections accumulator + | otherwise -> accumulateFields parentType accumulator $ toList selections + collectFromFragment typeCondition selectionSet' accumulator = do + types' <- lift $ asks $ Schema.types . schema + case Type.lookupTypeCondition typeCondition types' of + Nothing -> pure accumulator + Just compositeType -> + accumulateFields compositeType accumulator $ toList selectionSet' + collectFromSpread fragmentName accumulator = do + modify $ HashSet.insert fragmentName + ast' <- lift $ asks ast + case findFragmentDefinition fragmentName ast' of + Nothing -> pure accumulator + Just (Full.FragmentDefinition _ typeCondition _ selectionSet' _) -> + collectFromFragment typeCondition selectionSet' accumulator + +data FieldInfo m = FieldInfo + { node :: Full.Field + , type' :: Out.Type m + , parent :: Type.CompositeType m + } |
