Implement basic "Field Selection Merging" rule

This commit is contained in:
Eugen Wissner 2020-11-15 10:11:09 +01:00
parent f5209481aa
commit 1f4eb6fb9b
4 changed files with 310 additions and 35 deletions

View File

@ -7,6 +7,9 @@ and this project adheres to
[Haskell Package Versioning Policy](https://pvp.haskell.org/). [Haskell Package Versioning Policy](https://pvp.haskell.org/).
## [Unreleased] ## [Unreleased]
### Added
- `Validate.Rules`:
- `overlappingFieldsCanBeMergedRule`
## [0.11.0.0] - 2020-11-07 ## [0.11.0.0] - 2020-11-07
### Changed ### Changed

View File

@ -15,10 +15,12 @@ module Language.GraphQL.Type.Internal
, directives , directives
, doesFragmentTypeApply , doesFragmentTypeApply
, instanceOf , instanceOf
, lookupCompositeField
, lookupInputType , lookupInputType
, lookupTypeCondition , lookupTypeCondition
, lookupTypeField , lookupTypeField
, mutation , mutation
, outToComposite
, subscription , subscription
, query , query
, types , types
@ -160,12 +162,16 @@ lookupInputType (Full.TypeNonNull (Full.NonNullTypeList nonNull)) types'
<$> lookupInputType nonNull types' <$> lookupInputType nonNull types'
lookupTypeField :: forall a. Full.Name -> Out.Type a -> Maybe (Out.Field a) lookupTypeField :: forall a. Full.Name -> Out.Type a -> Maybe (Out.Field a)
lookupTypeField fieldName = \case lookupTypeField fieldName outputType =
Out.ObjectBaseType objectType -> outToComposite outputType >>= lookupCompositeField fieldName
objectChild objectType
Out.InterfaceBaseType interfaceType -> lookupCompositeField :: forall a
interfaceChild interfaceType . Full.Name
Out.ListBaseType listType -> lookupTypeField fieldName listType -> CompositeType a
-> Maybe (Out.Field a)
lookupCompositeField fieldName = \case
CompositeObjectType objectType -> objectChild objectType
CompositeInterfaceType interfaceType -> interfaceChild interfaceType
_ -> Nothing _ -> Nothing
where where
objectChild (Out.ObjectType _ _ _ resolvers) = objectChild (Out.ObjectType _ _ _ resolvers) =
@ -174,3 +180,12 @@ lookupTypeField fieldName = \case
HashMap.lookup fieldName fields HashMap.lookup fieldName fields
resolverType (Out.ValueResolver objectField _) = objectField resolverType (Out.ValueResolver objectField _) = objectField
resolverType (Out.EventStreamResolver objectField _ _) = objectField resolverType (Out.EventStreamResolver objectField _ _) = objectField
outToComposite :: forall a. Out.Type a -> Maybe (CompositeType a)
outToComposite = \case
Out.ObjectBaseType objectType -> Just $ CompositeObjectType objectType
Out.InterfaceBaseType interfaceType ->
Just $ CompositeInterfaceType interfaceType
Out.UnionBaseType unionType -> Just $ CompositeUnionType unionType
Out.ListBaseType listType -> outToComposite listType
_ -> Nothing

View File

@ -24,6 +24,7 @@ module Language.GraphQL.Validate.Rules
, noUndefinedVariablesRule , noUndefinedVariablesRule
, noUnusedFragmentsRule , noUnusedFragmentsRule
, noUnusedVariablesRule , noUnusedVariablesRule
, overlappingFieldsCanBeMergedRule
, providedRequiredInputFieldsRule , providedRequiredInputFieldsRule
, providedRequiredArgumentsRule , providedRequiredArgumentsRule
, scalarLeafsRule , scalarLeafsRule
@ -40,17 +41,17 @@ module Language.GraphQL.Validate.Rules
import Control.Monad ((>=>), foldM) import Control.Monad ((>=>), foldM)
import Control.Monad.Trans.Class (MonadTrans(..)) import Control.Monad.Trans.Class (MonadTrans(..))
import Control.Monad.Trans.Reader (ReaderT(..), asks, mapReaderT) import Control.Monad.Trans.Reader (ReaderT(..), ask, asks, mapReaderT)
import Control.Monad.Trans.State (StateT, evalStateT, gets, modify) import Control.Monad.Trans.State (StateT, evalStateT, gets, modify)
import Data.Bifunctor (first) import Data.Bifunctor (first)
import Data.Foldable (find, toList) import Data.Foldable (find, foldl', toList)
import qualified Data.HashMap.Strict as HashMap import qualified Data.HashMap.Strict as HashMap
import Data.HashMap.Strict (HashMap) import Data.HashMap.Strict (HashMap)
import Data.HashSet (HashSet) import Data.HashSet (HashSet)
import qualified Data.HashSet as HashSet import qualified Data.HashSet as HashSet
import Data.List (groupBy, sortBy, sortOn) import Data.List (groupBy, sortBy, sortOn)
import Data.Maybe (isNothing, mapMaybe) import Data.Maybe (fromMaybe, isNothing, mapMaybe)
import Data.List.NonEmpty (NonEmpty) import Data.List.NonEmpty (NonEmpty(..))
import Data.Ord (comparing) import Data.Ord (comparing)
import Data.Sequence (Seq(..), (|>)) import Data.Sequence (Seq(..), (|>))
import qualified Data.Sequence as Seq import qualified Data.Sequence as Seq
@ -80,6 +81,7 @@ specifiedRules =
-- Fields -- Fields
, fieldsOnCorrectTypeRule , fieldsOnCorrectTypeRule
, scalarLeafsRule , scalarLeafsRule
, overlappingFieldsCanBeMergedRule
-- Arguments. -- Arguments.
, knownArgumentNamesRule , knownArgumentNamesRule
, uniqueArgumentNamesRule , uniqueArgumentNamesRule
@ -123,8 +125,8 @@ executableDefinitionsRule = DefinitionRule $ \case
singleFieldSubscriptionsRule :: forall m. Rule m singleFieldSubscriptionsRule :: forall m. Rule m
singleFieldSubscriptionsRule = OperationDefinitionRule $ \case singleFieldSubscriptionsRule = OperationDefinitionRule $ \case
Full.OperationDefinition Full.Subscription name' _ _ rootFields location' -> do Full.OperationDefinition Full.Subscription name' _ _ rootFields location' -> do
groupedFieldSet <- collectFields rootFields groupedFieldSet <- evalStateT (collectFields rootFields) HashSet.empty
case length groupedFieldSet of case HashSet.size groupedFieldSet of
1 -> lift mempty 1 -> lift mempty
_ _
| Just name <- name' -> pure $ Error | Just name <- name' -> pure $ Error
@ -143,26 +145,18 @@ singleFieldSubscriptionsRule = OperationDefinitionRule $ \case
where where
errorMessage = errorMessage =
"Anonymous Subscription must select only one top level field." "Anonymous Subscription must select only one top level field."
collectFields selectionSet = foldM forEach HashSet.empty selectionSet
collectFields :: forall m
. NonEmpty Full.Selection
-> ReaderT (Validation m) Seq [[Full.Field]]
collectFields selectionSet = evalStateT go HashSet.empty
where
go = groupSorted getFieldName <$> accumulateFields [] selectionSet
getFieldName (Full.Field alias name _ _ _ _)
| Just aliasedName <- alias = aliasedName
| otherwise = name
accumulateFields = foldM forEach
forEach accumulator = \case forEach accumulator = \case
Full.FieldSelection fieldSelection -> forField accumulator fieldSelection Full.FieldSelection fieldSelection -> forField accumulator fieldSelection
Full.FragmentSpreadSelection fragmentSelection -> Full.FragmentSpreadSelection fragmentSelection ->
forSpread accumulator fragmentSelection forSpread accumulator fragmentSelection
Full.InlineFragmentSelection fragmentSelection -> Full.InlineFragmentSelection fragmentSelection ->
forInline accumulator fragmentSelection forInline accumulator fragmentSelection
forField accumulator field@(Full.Field _ _ _ directives' _ _) forField accumulator (Full.Field alias name _ directives' _ _)
| any skip directives' = pure accumulator | any skip directives' = pure accumulator
| otherwise = pure $ field : accumulator | Just aliasedName <- alias = pure
$ HashSet.insert aliasedName accumulator
| otherwise = pure $ HashSet.insert name accumulator
forSpread accumulator (Full.FragmentSpread fragmentName directives' _) forSpread accumulator (Full.FragmentSpread fragmentName directives' _)
| any skip directives' = pure accumulator | any skip directives' = pure accumulator
| otherwise = do | otherwise = do
@ -174,8 +168,14 @@ collectFields selectionSet = evalStateT go HashSet.empty
| any skip directives' = pure accumulator | any skip directives' = pure accumulator
| Just typeCondition <- maybeType = | Just typeCondition <- maybeType =
collectFromFragment typeCondition selections accumulator collectFromFragment typeCondition selections accumulator
| otherwise = accumulateFields accumulator selections | otherwise = HashSet.union accumulator
collectFromFragment typeCondition selectionSet' accumulator = do <$> collectFields selections
skip (Full.Directive "skip" [Full.Argument "if" (Full.Node argumentValue _) _] _) =
Full.Boolean True == argumentValue
skip (Full.Directive "include" [Full.Argument "if" (Full.Node argumentValue _) _] _) =
Full.Boolean False == argumentValue
skip _ = False
collectFromFragment typeCondition selectionSet accumulator = do
types' <- lift $ asks $ Schema.types . schema types' <- lift $ asks $ Schema.types . schema
schema' <- lift $ asks schema schema' <- lift $ asks schema
case Type.lookupTypeCondition typeCondition types' of case Type.lookupTypeCondition typeCondition types' of
@ -183,20 +183,15 @@ collectFields selectionSet = evalStateT go HashSet.empty
Just compositeType Just compositeType
| Just objectType <- Schema.subscription schema' | Just objectType <- Schema.subscription schema'
, True <- Type.doesFragmentTypeApply compositeType objectType -> , True <- Type.doesFragmentTypeApply compositeType objectType ->
accumulateFields accumulator selectionSet' HashSet.union accumulator <$> collectFields selectionSet
| otherwise -> pure accumulator | otherwise -> pure accumulator
collectFromSpread fragmentName accumulator = do collectFromSpread fragmentName accumulator = do
modify $ HashSet.insert fragmentName modify $ HashSet.insert fragmentName
ast' <- lift $ asks ast ast' <- lift $ asks ast
case findFragmentDefinition fragmentName ast' of case findFragmentDefinition fragmentName ast' of
Nothing -> pure accumulator Nothing -> pure accumulator
Just (Full.FragmentDefinition _ typeCondition _ selectionSet' _) -> Just (Full.FragmentDefinition _ typeCondition _ selectionSet _) ->
collectFromFragment typeCondition selectionSet' accumulator collectFromFragment typeCondition selectionSet accumulator
skip (Full.Directive "skip" [Full.Argument "if" (Full.Node argumentValue _) _] _) =
Full.Boolean True == argumentValue
skip (Full.Directive "include" [Full.Argument "if" (Full.Node argumentValue _) _] _) =
Full.Boolean False == argumentValue
skip _ = False
-- | GraphQL allows a shorthand form for defining query operations when only -- | GraphQL allows a shorthand form for defining query operations when only
-- that one operation exists in the document. -- that one operation exists in the document.
@ -1026,3 +1021,187 @@ providedRequiredInputFieldsRule = ValueRule go constGo
, Text.unpack typeName , Text.unpack typeName
, "\" is required, but it was not provided." , "\" is required, but it was not provided."
] ]
-- | If multiple field selections with the same response names are encountered
-- during execution, the field and arguments to execute and the resulting value
-- should be unambiguous. Therefore any two field selections which might both be
-- encountered for the same object are only valid if they are equivalent.
--
-- For simple handwritten GraphQL, this rule is obviously a clear developer
-- error, however nested fragments can make this difficult to detect manually.
overlappingFieldsCanBeMergedRule :: Rule m
overlappingFieldsCanBeMergedRule = OperationDefinitionRule $ \case
Full.SelectionSet selectionSet _ -> do
schema' <- asks schema
go (toList selectionSet)
$ Type.CompositeObjectType
$ Schema.query schema'
Full.OperationDefinition operationType _ _ _ selectionSet _ -> do
schema' <- asks schema
let root = go (toList selectionSet) . Type.CompositeObjectType
case operationType of
Full.Query -> root $ Schema.query schema'
Full.Mutation
| Just objectType <- Schema.mutation schema' -> root objectType
Full.Subscription
| Just objectType <- Schema.mutation schema' -> root objectType
_ -> lift mempty
where
go selectionSet selectionType = do
fieldTuples <- evalStateT (collectFields selectionType selectionSet) HashSet.empty
fieldsInSetCanMerge fieldTuples
fieldsInSetCanMerge :: forall m
. HashMap Full.Name (NonEmpty (Full.Field, Type.CompositeType m))
-> ReaderT (Validation m) Seq Error
fieldsInSetCanMerge fieldTuples = do
validation <- ask
let (lonely, paired) = flattenPairs fieldTuples
let reader = flip runReaderT validation
lift $ foldMap (reader . visitLonelyFields) lonely
<> foldMap (reader . forEachFieldTuple) paired
forEachFieldTuple :: forall m
. (FieldInfo m, FieldInfo m)
-> ReaderT (Validation m) Seq Error
forEachFieldTuple (fieldA, fieldB) =
case (parent fieldA, parent fieldB) of
(parentA@Type.CompositeObjectType{}, parentB@Type.CompositeObjectType{})
| parentA /= parentB -> sameResponseShape fieldA fieldB
_ -> mapReaderT (checkEquality (node fieldA) (node fieldB))
$ sameResponseShape fieldA fieldB
checkEquality fieldA fieldB Seq.Empty
| Full.Field _ fieldNameA _ _ _ _ <- fieldA
, Full.Field _ fieldNameB _ _ _ _ <- fieldB
, fieldNameA /= fieldNameB = pure $ makeError fieldA fieldB
| Full.Field _ fieldNameA argumentsA _ _ locationA <- fieldA
, Full.Field _ _ argumentsB _ _ locationB <- fieldB
, argumentsA /= argumentsB =
let message = concat
[ "Fields \""
, Text.unpack fieldNameA
, "\" conflict because they have different arguments. Use "
, "different aliases on the fields to fetch both if this "
, "was intentional."
]
in pure $ Error message [locationB, locationA]
checkEquality _ _ previousErrors = previousErrors
visitLonelyFields FieldInfo{..} =
let Full.Field _ _ _ _ subSelections _ = node
compositeFieldType = Type.outToComposite type'
in maybe (lift Seq.empty) (go subSelections) compositeFieldType
sameResponseShape :: forall m
. FieldInfo m
-> FieldInfo m
-> ReaderT (Validation m) Seq Error
sameResponseShape fieldA fieldB =
let Full.Field _ _ _ _ selectionsA _ = node fieldA
Full.Field _ _ _ _ selectionsB _ = node fieldB
in case unwrapTypes (type' fieldA) (type' fieldB) of
Left True -> lift mempty
Right (compositeA, compositeB) -> do
validation <- ask
let collectFields' composite = flip runReaderT validation
. flip evalStateT HashSet.empty
. collectFields composite
let collectA = collectFields' compositeA selectionsA
let collectB = collectFields' compositeB selectionsB
fieldsInSetCanMerge
$ foldl' (HashMap.unionWith (<>)) HashMap.empty
$ collectA <> collectB
_ -> pure $ makeError (node fieldA) (node fieldB)
makeError fieldA fieldB =
let Full.Field aliasA fieldNameA _ _ _ locationA = fieldA
Full.Field _ fieldNameB _ _ _ locationB = fieldB
message = concat
[ "Fields \""
, Text.unpack (fromMaybe fieldNameA aliasA)
, "\" conflict because \""
, Text.unpack fieldNameB
, "\" and \""
, Text.unpack fieldNameA
, "\" are different fields. Use different aliases on the fields "
, "to fetch both if this was intentional."
]
in Error message [locationB, locationA]
unwrapTypes typeA@Out.ScalarBaseType{} typeB@Out.ScalarBaseType{} =
Left $ typeA == typeB
unwrapTypes typeA@Out.EnumBaseType{} typeB@Out.EnumBaseType{} =
Left $ typeA == typeB
unwrapTypes (Out.ListType listA) (Out.ListType listB) =
unwrapTypes listA listB
unwrapTypes (Out.NonNullListType listA) (Out.NonNullListType listB) =
unwrapTypes listA listB
unwrapTypes typeA typeB
| Out.isNonNullType typeA == Out.isNonNullType typeB
, Just compositeA <- Type.outToComposite typeA
, Just compositeB <- Type.outToComposite typeB =
Right (compositeA, compositeB)
| otherwise = Left False
flattenPairs :: forall m
. HashMap Full.Name (NonEmpty (Full.Field, Type.CompositeType m))
-> (Seq (FieldInfo m), Seq (FieldInfo m, FieldInfo m))
flattenPairs xs = HashMap.foldr splitSingleFields (Seq.empty, Seq.empty)
$ foldr lookupTypeField [] <$> xs
splitSingleFields :: forall m
. [FieldInfo m]
-> (Seq (FieldInfo m), Seq (FieldInfo m, FieldInfo m))
-> (Seq (FieldInfo m), Seq (FieldInfo m, FieldInfo m))
splitSingleFields [head'] (fields, pairList) = (fields |> head', pairList)
splitSingleFields xs (fields, pairList) = (fields, pairs pairList xs)
lookupTypeField (field, parentType) accumulator =
let Full.Field _ fieldName _ _ _ _ = field
in case Type.lookupCompositeField fieldName parentType of
Nothing -> accumulator
Just (Out.Field _ typeField _) ->
FieldInfo field typeField parentType : accumulator
pairs :: forall m
. Seq (FieldInfo m, FieldInfo m)
-> [FieldInfo m]
-> Seq (FieldInfo m, FieldInfo m)
pairs accumulator [] = accumulator
pairs accumulator (fieldA : fields) =
pair fieldA (pairs accumulator fields) fields
pair _ accumulator [] = accumulator
pair field accumulator (fieldA : fields) =
pair field accumulator fields |> (field, fieldA)
collectFields objectType = accumulateFields objectType mempty
accumulateFields = foldM . forEach
forEach parentType accumulator = \case
Full.FieldSelection fieldSelection ->
forField parentType accumulator fieldSelection
Full.FragmentSpreadSelection fragmentSelection ->
forSpread accumulator fragmentSelection
Full.InlineFragmentSelection fragmentSelection ->
forInline parentType accumulator fragmentSelection
forField parentType accumulator field@(Full.Field alias fieldName _ _ _ _) =
let key = fromMaybe fieldName alias
value = (field, parentType) :| []
in pure $ HashMap.insertWith (<>) key value accumulator
forSpread accumulator (Full.FragmentSpread fragmentName _ _) = do
inVisitetFragments <- gets $ HashSet.member fragmentName
if inVisitetFragments
then pure accumulator
else collectFromSpread fragmentName accumulator
forInline parentType accumulator = \case
Full.InlineFragment maybeType _ selections _
| Just typeCondition <- maybeType ->
collectFromFragment typeCondition selections accumulator
| otherwise -> accumulateFields parentType accumulator $ toList selections
collectFromFragment typeCondition selectionSet' accumulator = do
types' <- lift $ asks $ Schema.types . schema
case Type.lookupTypeCondition typeCondition types' of
Nothing -> pure accumulator
Just compositeType ->
accumulateFields compositeType accumulator $ toList selectionSet'
collectFromSpread fragmentName accumulator = do
modify $ HashSet.insert fragmentName
ast' <- lift $ asks ast
case findFragmentDefinition fragmentName ast' of
Nothing -> pure accumulator
Just (Full.FragmentDefinition _ typeCondition _ selectionSet' _) ->
collectFromFragment typeCondition selectionSet' accumulator
data FieldInfo m = FieldInfo
{ node :: Full.Field
, type' :: Out.Type m
, parent :: Type.CompositeType m
}

View File

@ -663,3 +663,81 @@ spec =
, locations = [AST.Location 2 15] , locations = [AST.Location 2 15]
} }
in validate queryString `shouldBe` [expected] in validate queryString `shouldBe` [expected]
it "fails to merge fields of mismatching types" $
let queryString = [r|
{
dog {
name: nickname
name
}
}
|]
expected = Error
{ message =
"Fields \"name\" conflict because \"nickname\" and \
\\"name\" are different fields. Use different aliases \
\on the fields to fetch both if this was intentional."
, locations = [AST.Location 4 19, AST.Location 5 19]
}
in validate queryString `shouldBe` [expected]
it "fails if the arguments of the same field don't match" $
let queryString = [r|
{
dog {
doesKnowCommand(dogCommand: SIT)
doesKnowCommand(dogCommand: HEEL)
}
}
|]
expected = Error
{ message =
"Fields \"doesKnowCommand\" conflict because they have \
\different arguments. Use different aliases on the \
\fields to fetch both if this was intentional."
, locations = [AST.Location 4 19, AST.Location 5 19]
}
in validate queryString `shouldBe` [expected]
it "fails to merge same-named field and alias" $
let queryString = [r|
{
dog {
doesKnowCommand(dogCommand: SIT)
doesKnowCommand: isHousetrained(atOtherHomes: true)
}
}
|]
expected = Error
{ message =
"Fields \"doesKnowCommand\" conflict because \
\\"doesKnowCommand\" and \"isHousetrained\" are \
\different fields. Use different aliases on the fields \
\to fetch both if this was intentional."
, locations = [AST.Location 4 19, AST.Location 5 19]
}
in validate queryString `shouldBe` [expected]
it "looks for fields after a successfully merged field pair" $
let queryString = [r|
{
dog {
name
doesKnowCommand(dogCommand: SIT)
}
dog {
name
doesKnowCommand: isHousetrained(atOtherHomes: true)
}
}
|]
expected = Error
{ message =
"Fields \"doesKnowCommand\" conflict because \
\\"doesKnowCommand\" and \"isHousetrained\" are \
\different fields. Use different aliases on the fields \
\to fetch both if this was intentional."
, locations = [AST.Location 5 19, AST.Location 9 19]
}
in validate queryString `shouldBe` [expected]