diff --git a/src/Language/GraphQL/AST/Transform.hs b/src/Language/GraphQL/AST/Transform.hs index 93fb557..d70a163 100644 --- a/src/Language/GraphQL/AST/Transform.hs +++ b/src/Language/GraphQL/AST/Transform.hs @@ -94,17 +94,21 @@ collectFragments = do _ <- fragmentDefinition nextValue collectFragments -fragmentDefinition :: Full.FragmentDefinition -> TransformT (NonEmpty Core.Selection) +fragmentDefinition :: + Full.FragmentDefinition -> + TransformT (NonEmpty Core.Selection) fragmentDefinition (Full.FragmentDefinition name _tc _dirs sels) = do + modify deleteFragmentDefinition selections <- traverse selection sels let newValue = either id pure =<< selections - modify $ moveFragment newValue + modify $ insertFragment newValue liftJust newValue where - moveFragment newValue (Replacement fullFragments emptyFragDefs) = - let newFragments = HashMap.insert name newValue fullFragments - newDefinitions = HashMap.delete name emptyFragDefs - in Replacement newFragments newDefinitions + deleteFragmentDefinition (Replacement fragments' fragmentDefinitions') = + Replacement fragments' $ HashMap.delete name fragmentDefinitions' + insertFragment newValue (Replacement fragments' fragmentDefinitions') = + let newFragments = HashMap.insert name newValue fragments' + in Replacement newFragments fragmentDefinitions' field :: Full.Field -> TransformT Core.Field field (Full.Field a n args _dirs sels) = do diff --git a/tests/Test/FragmentSpec.hs b/tests/Test/FragmentSpec.hs index a102104..69f1344 100644 --- a/tests/Test/FragmentSpec.hs +++ b/tests/Test/FragmentSpec.hs @@ -10,7 +10,13 @@ import Data.List.NonEmpty (NonEmpty(..)) import Data.Text (Text) import Language.GraphQL import qualified Language.GraphQL.Schema as Schema -import Test.Hspec (Spec, describe, it, shouldBe, shouldNotSatisfy) +import Test.Hspec ( Spec + , describe + , it + , shouldBe + , shouldSatisfy + , shouldNotSatisfy + ) import Text.RawString.QQ (r) size :: Schema.Resolver IO @@ -37,6 +43,10 @@ inlineQuery = [r|{ } }|] +hasErrors :: Value -> Bool +hasErrors (Object object') = HashMap.member "errors" object' +hasErrors _ = True + spec :: Spec spec = describe "Inline fragment executor" $ do it "chooses the first selection if the type matches" $ do @@ -91,9 +101,7 @@ spec = describe "Inline fragment executor" $ do }|] actual <- graphql (size :| []) query - let hasErrors (Object object') = HashMap.member "errors" object' - hasErrors _ = True - in actual `shouldNotSatisfy` hasErrors + actual `shouldNotSatisfy` hasErrors it "evaluates nested fragments" $ do let query = [r| @@ -140,3 +148,17 @@ spec = describe "Inline fragment executor" $ do ] ] in actual `shouldBe` expected + + it "rejects recursive" $ do + let query = [r| + { + ...circumferenceFragment + } + + fragment circumferenceFragment on Hat { + ...circumferenceFragment + } + |] + + actual <- graphql (circumference :| []) query + actual `shouldSatisfy` hasErrors