diff options
Diffstat (limited to 'src/Language/GraphQL/Validate/Rules.hs')
| -rw-r--r-- | src/Language/GraphQL/Validate/Rules.hs | 32 |
1 files changed, 15 insertions, 17 deletions
diff --git a/src/Language/GraphQL/Validate/Rules.hs b/src/Language/GraphQL/Validate/Rules.hs index 2d7adba..e60d39d 100644 --- a/src/Language/GraphQL/Validate/Rules.hs +++ b/src/Language/GraphQL/Validate/Rules.hs @@ -50,14 +50,15 @@ import Control.Monad.Trans.Class (MonadTrans(..)) import Control.Monad.Trans.Reader (ReaderT(..), ask, asks, mapReaderT) import Control.Monad.Trans.State (StateT, evalStateT, gets, modify) import Data.Bifunctor (first) -import Data.Foldable (find, fold, foldl', toList) +import Data.Foldable (Foldable(..), find) import qualified Data.HashMap.Strict as HashMap import Data.HashMap.Strict (HashMap) import Data.HashSet (HashSet) import qualified Data.HashSet as HashSet -import Data.List (groupBy, sortBy, sortOn) +import Data.List (sortBy) import Data.Maybe (fromMaybe, isJust, isNothing, mapMaybe) import Data.List.NonEmpty (NonEmpty(..)) +import qualified Data.List.NonEmpty as NonEmpty import Data.Ord (comparing) import Data.Sequence (Seq(..), (|>)) import qualified Data.Sequence as Seq @@ -253,14 +254,16 @@ findDuplicates :: (Full.Definition -> [Full.Location] -> [Full.Location]) -> Full.Location -> String -> RuleT m -findDuplicates filterByName thisLocation errorMessage = do - ast' <- asks ast - let locations' = foldr filterByName [] ast' - if length locations' > 1 && head locations' == thisLocation - then pure $ error' locations' - else lift mempty +findDuplicates filterByName thisLocation errorMessage = + asks ast >>= go . foldr filterByName [] where - error' locations' = Error + go locations' = + case locations' of + headLocation : otherLocations -- length locations' > 1 + | not $ null otherLocations + , headLocation == thisLocation -> pure $ makeError locations' + _ -> lift mempty + makeError locations' = Error { message = errorMessage , locations = locations' } @@ -536,11 +539,6 @@ uniqueDirectiveNamesRule = DirectivesRule extract (Full.Directive directiveName _ location') = (directiveName, location') -groupSorted :: forall a. (a -> Text) -> [a] -> [[a]] -groupSorted getName = groupBy equalByName . sortOn getName - where - equalByName lhs rhs = getName lhs == getName rhs - filterDuplicates :: forall a . (a -> (Text, Full.Location)) -> String @@ -549,12 +547,12 @@ filterDuplicates :: forall a filterDuplicates extract nodeType = Seq.fromList . fmap makeError . filter ((> 1) . length) - . groupSorted getName + . NonEmpty.groupAllWith getName where getName = fst . extract makeError directives' = Error - { message = makeMessage $ head directives' - , locations = snd . extract <$> directives' + { message = makeMessage $ NonEmpty.head directives' + , locations = snd . extract <$> toList directives' } makeMessage directive = concat [ "There can be only one " |
