From 5a5f265fe4bf291c1ef58f5fe452f1e8c69c9ed6 Mon Sep 17 00:00:00 2001 From: Eugen Wissner Date: Thu, 6 May 2021 22:23:16 +0200 Subject: Validate non-nullable values inside lists --- src/Language/GraphQL/Validate.hs | 19 ---------------- src/Language/GraphQL/Validate/Rules.hs | 41 +++++++++++++++++++++++++--------- 2 files changed, 30 insertions(+), 30 deletions(-) (limited to 'src') diff --git a/src/Language/GraphQL/Validate.hs b/src/Language/GraphQL/Validate.hs index 2e86e53..f929b98 100644 --- a/src/Language/GraphQL/Validate.hs +++ b/src/Language/GraphQL/Validate.hs @@ -314,9 +314,6 @@ constValue (Validation.ValueRule _ rule) valueType = go valueType go inputObjectType value'@(Full.Node (Full.ConstObject fields) _) = foldMap (forEach inputObjectType) (Seq.fromList fields) |> rule inputObjectType value' - go listType value'@(Full.Node (Full.ConstList values) _location) - = embedListLocation go listType values - |> rule listType value' go anotherValue value' = pure $ rule anotherValue value' forEach inputObjectType Full.ObjectField{value = value', ..} = go (valueTypeByName name inputObjectType) value' @@ -420,19 +417,6 @@ argument rule argumentType (Full.Argument _ value' _) = where valueType (In.Argument _ valueType' _) = valueType' --- Applies a validation rule to each list value and merges returned errors. -embedListLocation :: forall a m - . (Maybe In.Type -> Full.Node a -> Seq m) - -> Maybe In.Type - -> [Full.Node a] - -> Seq m -embedListLocation go listType - = foldMap (go $ valueTypeFromList listType) - . Seq.fromList - where - valueTypeFromList (Just (In.ListBaseType baseType)) = Just baseType - valueTypeFromList _ = Nothing - value :: forall m . Validation.Rule m -> Maybe In.Type @@ -443,9 +427,6 @@ value (Validation.ValueRule rule _) valueType = go valueType go inputObjectType value'@(Full.Node (Full.Object fields) _) = foldMap (forEach inputObjectType) (Seq.fromList fields) |> rule inputObjectType value' - go listType value'@(Full.Node (Full.List values) _location) - = embedListLocation go listType values - |> rule listType value' go anotherValue value' = pure $ rule anotherValue value' forEach inputObjectType Full.ObjectField{value = value', ..} = go (valueTypeByName name inputObjectType) value' diff --git a/src/Language/GraphQL/Validate/Rules.hs b/src/Language/GraphQL/Validate/Rules.hs index a6bc43b..46a14b7 100644 --- a/src/Language/GraphQL/Validate/Rules.hs +++ b/src/Language/GraphQL/Validate/Rules.hs @@ -1502,15 +1502,6 @@ variablesInAllowedPositionRule = OperationDefinitionRule $ \case hasNonNullVariableDefaultValue (Just (Full.Node Full.ConstNull _)) = False hasNonNullVariableDefaultValue Nothing = False hasNonNullVariableDefaultValue _ = True - unwrapInType (In.NonNullScalarType nonNullType) = - Just $ In.NamedScalarType nonNullType - unwrapInType (In.NonNullEnumType nonNullType) = - Just $ In.NamedEnumType nonNullType - unwrapInType (In.NonNullInputObjectType nonNullType) = - Just $ In.NamedInputObjectType nonNullType - unwrapInType (In.NonNullListType nonNullType) = - Just $ In.ListType nonNullType - unwrapInType _ = Nothing makeError variableDefinition expectedType = let Full.VariableDefinition variableName variableType _ location' = variableDefinition @@ -1527,6 +1518,17 @@ variablesInAllowedPositionRule = OperationDefinitionRule $ \case , locations = [location'] } +unwrapInType :: In.Type -> Maybe In.Type +unwrapInType (In.NonNullScalarType nonNullType) = + Just $ In.NamedScalarType nonNullType +unwrapInType (In.NonNullEnumType nonNullType) = + Just $ In.NamedEnumType nonNullType +unwrapInType (In.NonNullInputObjectType nonNullType) = + Just $ In.NamedInputObjectType nonNullType +unwrapInType (In.NonNullListType nonNullType) = + Just $ In.ListType nonNullType +unwrapInType _ = Nothing + -- | Literal values must be compatible with the type expected in the position -- they are found as per the coercion rules. -- @@ -1586,8 +1588,8 @@ valuesOfCorrectTypeRule = ValueRule go constGo | In.InputObjectType{} <- objectType , Full.ConstObject{} <- node = mempty check (In.ListBaseType listType) constValue@Full.Node{ .. } - -- Skip, lists are checked recursively by the validation traverser. - | Full.ConstList{} <- node = mempty + | Full.ConstList values <- node = + foldMap (checkNull listType) values | otherwise = check listType constValue check inputType Full.Node{ .. } = pure $ Error { message = concat @@ -1599,3 +1601,20 @@ valuesOfCorrectTypeRule = ValueRule go constGo ] , locations = [location] } + checkNull inputType constValue = + let checkResult = check inputType constValue + in case null checkResult of + True + | Just unwrappedType <- unwrapInType inputType + , Full.Node{ node = Full.ConstNull, .. } <- constValue -> + pure $ Error + { message = concat + [ "List of non-null values of type \"" + , show unwrappedType + , "\" cannot contain null values." + ] + , locations = [location] + } + | otherwise -> mempty + _ -> checkResult + -- cgit v1.2.3