diff --git a/src/Language/GraphQL/Validate/Rules.hs b/src/Language/GraphQL/Validate/Rules.hs index dad98ef..68950b5 100644 --- a/src/Language/GraphQL/Validate/Rules.hs +++ b/src/Language/GraphQL/Validate/Rules.hs @@ -52,7 +52,7 @@ import Data.HashMap.Strict (HashMap) import Data.HashSet (HashSet) import qualified Data.HashSet as HashSet import Data.List (groupBy, sortBy, sortOn) -import Data.Maybe (catMaybes, fromMaybe, isJust, isNothing, mapMaybe) +import Data.Maybe (fromMaybe, isJust, isNothing, mapMaybe) import Data.List.NonEmpty (NonEmpty(..)) import Data.Ord (comparing) import Data.Sequence (Seq(..), (|>)) @@ -359,9 +359,9 @@ fragmentSpreadTypeExistenceRule = SelectionRule $ const $ \case , "\" which doesn't exist in the schema." ] -maybeToSeq :: forall a m. Maybe a -> ReaderT (Validation m) Seq a -maybeToSeq (Just x) = lift $ pure x -maybeToSeq Nothing = lift mempty +maybeToSeq :: forall a. Maybe a -> Seq a +maybeToSeq (Just x) = pure x +maybeToSeq Nothing = mempty -- | Fragments can only be declared on unions, interfaces, and objects. They are -- invalid on scalars. They can only be applied on nonā€leaf fields. This rule @@ -377,7 +377,7 @@ fragmentsOnCompositeTypesRule = FragmentRule definitionRule inlineRule check typeCondition location' = do types' <- asks $ Schema.types . schema -- Skip unknown types, they are checked by another rule. - _ <- maybeToSeq $ HashMap.lookup typeCondition types' + _ <- lift $ maybeToSeq $ HashMap.lookup typeCondition types' case Type.lookupTypeCondition typeCondition types' of Nothing -> pure $ Error { message = errorMessage typeCondition @@ -1240,9 +1240,11 @@ possibleFragmentSpreadsRule = SelectionRule go go _ _ = lift mempty compareTypes typeCondition parentType = do types' <- asks $ Schema.types . schema - fragmentType <- maybeToSeq + fragmentType <- lift + $ maybeToSeq $ Type.lookupTypeCondition typeCondition types' - parentComposite <- maybeToSeq + parentComposite <- lift + $ maybeToSeq $ Type.outToComposite parentType possibleFragments <- getPossibleTypes fragmentType possibleParents <- getPossibleTypes parentComposite @@ -1279,7 +1281,7 @@ findSpreadTarget :: Full.Name -> ReaderT (Validation m1) Seq Full.TypeCondition findSpreadTarget fragmentName = do ast' <- asks ast let target = find (isSpreadTarget fragmentName) ast' - maybeToSeq $ target >>= extractTypeCondition + lift $ maybeToSeq $ target >>= extractTypeCondition where extractTypeCondition (viewFragment -> Just fragmentDefinition) = let Full.FragmentDefinition _ typeCondition _ _ _ = fragmentDefinition @@ -1397,19 +1399,30 @@ variablesInAllowedPositionRule = OperationDefinitionRule $ \case Just (Schema.Directive _ _ directiveArguments) -> mapArguments variables directiveArguments arguments Nothing -> pure mempty - mapArguments variables argumentTypes - = fmap (Seq.fromList . catMaybes) + mapArguments variables argumentTypes = fmap fold . traverse (findArgumentVariables variables argumentTypes) mapDirectives variables = fmap fold <$> traverse (findDirectiveVariables variables) + findArgumentVariables :: [Full.VariableDefinition] + -> HashMap Full.Name In.Argument + -> Full.Argument + -> ValidationState m (Seq Error) findArgumentVariables variables argumentTypes argument | Full.Argument argumentName argumentValue _ <- argument , Full.Node{ node = Full.Variable variableName, ..} <- argumentValue , Just expectedType <- HashMap.lookup argumentName argumentTypes , findVariableDefinition' <- findVariableDefinition variableName , Just variableDefinition <- find findVariableDefinition' variables = - isVariableUsageAllowed expectedType variableDefinition - | otherwise = pure Nothing + maybeToSeq <$> isVariableUsageAllowed expectedType variableDefinition + | Full.Argument argumentName argumentValue _ <- argument + , Full.Node{ node = Full.Object objectFields, ..} <- argumentValue + , Just typeArgument <- HashMap.lookup argumentName argumentTypes + , In.Argument _ expectedType _ <- typeArgument + , In.InputObjectBaseType inputObjectType <- expectedType + , In.InputObjectType _ _ fieldTypes <- inputObjectType = + fold <$> traverse (traverseObjectField fieldTypes) objectFields + | otherwise = pure mempty + traverseObjectField _fieldTypes = const $ pure mempty findVariableDefinition variableName variableDefinition = let Full.VariableDefinition variableName' _ _ _ = variableDefinition in variableName == variableName' @@ -1418,12 +1431,14 @@ variablesInAllowedPositionRule = OperationDefinitionRule $ \case , Full.TypeNonNull _ <- variableType = typesCompatibleOrError variableDefinition locationType | Just nullableLocationType <- unwrapInType locationType - , Full.VariableDefinition _ _ variableDefaultValue _ <- variableDefinition + , Full.VariableDefinition _ variableType variableDefaultValue _ <- + variableDefinition , hasNonNullVariableDefaultValue' <- hasNonNullVariableDefaultValue variableDefaultValue , hasLocationDefaultValue <- isJust locationDefaultValue = - if hasNonNullVariableDefaultValue' || hasLocationDefaultValue - then typesCompatibleOrError variableDefinition nullableLocationType + if (hasNonNullVariableDefaultValue' || hasLocationDefaultValue) + && areTypesCompatible variableType nullableLocationType + then pure Nothing else pure $ makeError variableDefinition locationType | otherwise = typesCompatibleOrError variableDefinition locationType typesCompatibleOrError variableDefinition locationType