# test_recipes.py
from doctest import DocTestSuite
from itertools import combinations
import random
from unittest import TestCase

from six.moves import range

import more_itertools as mi
def load_tests(loader, tests, ignore):
    """Extend the collected suite with the doctests of ``more_itertools.recipes``.

    This is the ``load_tests`` hook that ``unittest`` invokes while loading
    this module. ``loader`` and ``ignore`` are accepted to satisfy the hook's
    signature but are not used.

    Returns the same ``tests`` suite object, now including the doctests.
    """
    recipe_doctests = DocTestSuite('more_itertools.recipes')
    tests.addTests(recipe_doctests)
    return tests
class AccumulateTests(TestCase):
    """Tests for ``accumulate()``"""

    def test_empty(self):
        """An empty iterable accumulates to an empty result"""
        accumulated = mi.accumulate([])
        self.assertEqual(list(accumulated), [])

    def test_default(self):
        """Without ``func`` the running values are sums"""
        running_totals = list(mi.accumulate([1, 2, 3]))
        self.assertEqual(running_totals, [1, 3, 6])

    def test_bogus_function(self):
        """A ``func`` that cannot take two arguments raises TypeError"""
        def one_arg(x):
            return x

        self.assertRaises(
            TypeError, lambda: list(mi.accumulate([1, 2, 3], func=one_arg))
        )

    def test_custom_function(self):
        """``func=max`` produces a running maximum"""
        running_max = list(mi.accumulate((1, 2, 3, 2, 1), func=max))
        self.assertEqual(running_max, [1, 2, 3, 3, 3])
class TakeTests(TestCase):
    """Tests for ``take()``"""

    def test_simple_take(self):
        """The first n items are returned as a list"""
        self.assertEqual(mi.take(5, range(10)), [0, 1, 2, 3, 4])

    def test_null_take(self):
        """Taking zero items gives an empty list"""
        self.assertEqual(mi.take(0, range(10)), [])

    def test_negative_take(self):
        """A negative count is rejected with ValueError"""
        with self.assertRaises(ValueError):
            mi.take(-3, range(10))

    def test_take_too_much(self):
        """Asking for more items than remain yields only what remains."""
        self.assertEqual(mi.take(10, range(5)), [0, 1, 2, 3, 4])
class TabulateTests(TestCase):
    """Tests for ``tabulate()``"""

    def test_simple_tabulate(self):
        """With no start, the function is applied to 0, 1, 2, ..."""
        table = mi.tabulate(lambda x: x)
        first_three = tuple([next(table) for _ in range(3)])
        self.assertEqual(first_three, (0, 1, 2))

    def test_count(self):
        """An explicit start value is used as the first input"""
        doubled = mi.tabulate(lambda x: 2 * x, -1)
        first_three = (next(doubled), next(doubled), next(doubled))
        self.assertEqual(first_three, (-2, 0, 2))
class TailTests(TestCase):
    """Tests for ``tail()``"""

    def test_greater(self):
        """Length of iterable is greater than the requested tail"""
        self.assertEqual(list(mi.tail(3, 'ABCDEFG')), list('EFG'))

    def test_equal(self):
        """Length of iterable is equal to the requested tail"""
        self.assertEqual(list(mi.tail(7, 'ABCDEFG')), list('ABCDEFG'))

    def test_less(self):
        """Length of iterable is less than the requested tail"""
        self.assertEqual(list(mi.tail(8, 'ABCDEFG')), list('ABCDEFG'))
class ConsumeTests(TestCase):
    """Tests for ``consume()``"""

    @staticmethod
    def _ten():
        """Return a fresh one-shot generator over 0..9."""
        return (x for x in range(10))

    def test_sanity(self):
        """Consuming n items leaves the iterator at item n"""
        numbers = self._ten()
        mi.consume(numbers, 3)
        self.assertEqual(3, next(numbers))

    def test_null_consume(self):
        """Consuming zero items leaves the iterator untouched"""
        numbers = self._ten()
        mi.consume(numbers, 0)
        self.assertEqual(0, next(numbers))

    def test_negative_consume(self):
        """Consuming a negative number of items raises ValueError"""
        numbers = self._ten()
        with self.assertRaises(ValueError):
            mi.consume(numbers, -1)

    def test_total_consume(self):
        """Without a count the iterator is exhausted completely"""
        numbers = self._ten()
        mi.consume(numbers)
        with self.assertRaises(StopIteration):
            next(numbers)
class NthTests(TestCase):
    """Tests for ``nth()``"""

    def test_basic(self):
        """Every index maps to the item at that position"""
        values = range(10)
        for index, expected in enumerate(values):
            self.assertEqual(mi.nth(values, index), expected)

    def test_default(self):
        """The default is returned when the index is out of range"""
        self.assertEqual(mi.nth(range(3), 100, "zebra"), "zebra")

    def test_negative_item_raises(self):
        """A negative index raises ValueError"""
        with self.assertRaises(ValueError):
            mi.nth(range(10), -3)
class AllEqualTests(TestCase):
    """Tests for ``all_equal()``"""

    def test_true(self):
        """Every element equals every other"""
        for sequence in ('aaaaaa', [0, 0, 0, 0]):
            self.assertTrue(mi.all_equal(sequence))

    def test_false(self):
        """At least one element differs"""
        for sequence in ('aaaaab', [0, 0, 0, 1]):
            self.assertFalse(mi.all_equal(sequence))

    def test_tricky(self):
        """Elements of different types that compare equal count as equal"""
        self.assertTrue(mi.all_equal([1, complex(1, 0), 1.0]))

    def test_empty(self):
        """An empty iterable is trivially all-equal"""
        for sequence in ('', []):
            self.assertTrue(mi.all_equal(sequence))

    def test_one(self):
        """A single-element iterable is trivially all-equal"""
        for sequence in ('0', [0]):
            self.assertTrue(mi.all_equal(sequence))
class QuantifyTests(TestCase):
    """Tests for ``quantify()``"""

    def test_happy_path(self):
        """With no predicate, the number of truthy items is returned"""
        flags = [True, False, True]
        self.assertEqual(mi.quantify(flags), 2)

    def test_custom_predicate(self):
        """A custom predicate decides which items are counted"""
        def is_even(x):
            return x % 2 == 0

        self.assertEqual(mi.quantify(range(10), is_even), 5)
class PadnoneTests(TestCase):
    """Tests for ``padnone()``"""

    def test_happy_path(self):
        """Once the source runs out, None is produced indefinitely"""
        padded = mi.padnone(range(2))
        first_four = [next(padded) for _ in range(4)]
        self.assertEqual([0, 1, None, None], first_four)
class NcyclesTests(TestCase):
    """Tests for ``ncycles()``"""

    def test_happy_path(self):
        """A sequence cycled three times repeats it three times"""
        letters = ["a", "b", "c"]
        cycled = mi.ncycles(letters, 3)
        self.assertEqual(letters * 3, list(cycled))

    def test_null_case(self):
        """Zero cycles gives an empty iterator"""
        cycled = mi.ncycles(range(100), 0)
        with self.assertRaises(StopIteration):
            next(cycled)

    def test_pathalogical_case(self):
        """A negative cycle count also gives an empty iterator"""
        cycled = mi.ncycles(range(100), -10)
        with self.assertRaises(StopIteration):
            next(cycled)
class DotproductTests(TestCase):
    """Tests for ``dotproduct()``"""

    def test_happy_path(self):
        """10*20 + 10*20 == 400"""
        product = mi.dotproduct([10, 10], [20, 20])
        self.assertEqual(400, product)
class FlattenTests(TestCase):
    """Tests for ``flatten()``"""

    def test_basic_usage(self):
        """A list of lists is flattened by one level"""
        nested = [[0, 1, 2], [3, 4, 5]]
        self.assertEqual(list(mi.flatten(nested)), list(range(6)))

    def test_single_level(self):
        """Only the outermost level of nesting is removed"""
        nested = [[0, [1, 2]], [[3, 4], 5]]
        self.assertEqual(list(mi.flatten(nested)), [0, [1, 2], [3, 4], 5])
class RepeatfuncTests(TestCase):
    """Tests for ``repeatfunc()``"""

    def test_simple_repeat(self):
        """Without ``times`` the function is called indefinitely"""
        fives = mi.repeatfunc(lambda: 5)
        self.assertEqual([next(fives) for _ in range(5)], [5] * 5)

    def test_finite_repeat(self):
        """``times`` limits the number of calls"""
        fives = mi.repeatfunc(lambda: 5, times=5)
        self.assertEqual(list(fives), [5] * 5)

    def test_added_arguments(self):
        """Extra positional arguments are passed to the function"""
        echoed = mi.repeatfunc(lambda x: x, 2, 3)
        self.assertEqual(list(echoed), [3, 3])

    def test_null_times(self):
        """``times=0`` gives an empty iterator"""
        calls = mi.repeatfunc(range, 0, 3)
        with self.assertRaises(StopIteration):
            next(calls)
class PairwiseTests(TestCase):
    """Tests for ``pairwise()``"""

    def test_base_case(self):
        """Successive overlapping pairs are produced"""
        pairs = list(mi.pairwise([1, 2, 3]))
        self.assertEqual(pairs, [(1, 2), (2, 3)])

    def test_short_case(self):
        """Fewer than two items gives an empty iterator"""
        pairs = mi.pairwise("a")
        with self.assertRaises(StopIteration):
            next(pairs)
class GrouperTests(TestCase):
    """Tests for ``grouper()``"""

    def test_even(self):
        """Group size divides the iterable's length evenly."""
        groups = list(mi.grouper(3, 'ABCDEF'))
        self.assertEqual(groups, [('A', 'B', 'C'), ('D', 'E', 'F')])

    def test_odd(self):
        """Group size does not divide the length; None pads the last group."""
        groups = list(mi.grouper(3, 'ABCDE'))
        self.assertEqual(groups, [('A', 'B', 'C'), ('D', 'E', None)])

    def test_fill_value(self):
        """An explicit fill value pads the final group"""
        groups = list(mi.grouper(3, 'ABCDE', 'x'))
        self.assertEqual(groups, [('A', 'B', 'C'), ('D', 'E', 'x')])
class RoundrobinTests(TestCase):
    """Tests for ``roundrobin()``"""

    def test_even_groups(self):
        """Equal-length inputs are interleaved one item at a time"""
        merged = list(mi.roundrobin('ABC', [1, 2, 3], range(3)))
        self.assertEqual(merged, ['A', 1, 0, 'B', 2, 1, 'C', 3, 2])

    def test_uneven_groups(self):
        """Exhausted inputs drop out while the rest keep interleaving"""
        merged = list(mi.roundrobin('ABCD', [1, 2], range(0)))
        self.assertEqual(merged, ['A', 1, 'B', 2, 'C', 'D'])
class PartitionTests(TestCase):
    """Tests for ``partition()``"""

    def test_bool(self):
        """pred() returns a boolean: falsy items first, truthy second"""
        low, high = mi.partition(lambda x: x > 5, range(10))
        self.assertEqual(list(low), [0, 1, 2, 3, 4, 5])
        self.assertEqual(list(high), [6, 7, 8, 9])

    def test_arbitrary(self):
        """pred() returns an integer, which is judged by truthiness"""
        multiples, others = mi.partition(lambda x: x % 3, range(10))
        self.assertEqual(list(multiples), [0, 3, 6, 9])
        self.assertEqual(list(others), [1, 2, 4, 5, 7, 8])
class PowersetTests(TestCase):
    """Tests for ``powerset()``"""

    def test_combinatorics(self):
        """Subsets are enumerated by increasing size"""
        expected = [(), (1,), (2,), (3,), (1, 2), (1, 3), (2, 3), (1, 2, 3)]
        self.assertEqual(list(mi.powerset([1, 2, 3])), expected)
class UniqueEverseenTests(TestCase):
    """Tests for ``unique_everseen()``"""

    def test_everseen(self):
        """Any element already seen is skipped"""
        unique = mi.unique_everseen('AAAABBBBCCDAABBB')
        self.assertEqual(list('ABCD'), list(unique))

    def test_custom_key(self):
        """Uniqueness is judged on the key, the first original is kept"""
        unique = mi.unique_everseen('aAbACCc', key=str.lower)
        self.assertEqual(list('abC'), list(unique))

    def test_unhashable(self):
        """Unhashable items are still de-duplicated"""
        mixed = ['a', [1, 2, 3], [1, 2, 3], 'a']
        self.assertEqual(list(mi.unique_everseen(mixed)), ['a', [1, 2, 3]])

    def test_unhashable_key(self):
        """Unhashable keys from a custom key function are supported"""
        mixed = ['a', [1, 2, 3], [1, 2, 3], 'a']
        unique = mi.unique_everseen(mixed, key=lambda x: x)
        self.assertEqual(list(unique), ['a', [1, 2, 3]])
class UniqueJustseenTests(TestCase):
    """Tests for ``unique_justseen()``"""

    def test_justseen(self):
        """Only consecutive duplicates are collapsed"""
        collapsed = mi.unique_justseen('AAAABBBCCDABB')
        self.assertEqual(list(collapsed), list('ABCDAB'))

    def test_custom_key(self):
        """Consecutive items with equal keys are collapsed"""
        collapsed = mi.unique_justseen('AABCcAD', str.lower)
        self.assertEqual(list(collapsed), list('ABCAD'))
class IterExceptTests(TestCase):
    """Tests for ``iter_except()``"""

    def test_exact_exception(self):
        """The named exception ends iteration quietly"""
        stack = [1, 2, 3]
        popped = mi.iter_except(stack.pop, IndexError)
        self.assertEqual(list(popped), [3, 2, 1])

    def test_generic_exception(self):
        """A base class of the raised exception also ends iteration"""
        stack = [1, 2]
        popped = mi.iter_except(stack.pop, Exception)
        self.assertEqual(list(popped), [2, 1])

    def test_uncaught_exception_is_raised(self):
        """An exception other than the named one propagates"""
        stack = [1, 2, 3]
        popped = mi.iter_except(stack.pop, KeyError)
        with self.assertRaises(IndexError):
            list(popped)

    def test_first(self):
        """The ``first`` callable's result is yielded before the others"""
        def first():
            return 25

        stack = [1, 2, 3]
        popped = mi.iter_except(stack.pop, IndexError, first)
        self.assertEqual(list(popped), [25, 3, 2, 1])
class FirstTrueTests(TestCase):
    """Tests for ``first_true()``"""

    def test_something_true(self):
        """With no keywords the first truthy item is returned"""
        self.assertEqual(mi.first_true(range(10)), 1)

    def test_nothing_true(self):
        """With nothing truthy the default value (False) is returned."""
        self.assertEqual(mi.first_true([0, 0, 0]), False)

    def test_default(self):
        """An explicit default is returned when nothing is truthy"""
        result = mi.first_true([0, 0, 0], default='!')
        self.assertEqual(result, '!')

    def test_pred(self):
        """A custom predicate selects the item"""
        def divisible_by_three(x):
            return x % 3 == 0

        result = mi.first_true([2, 4, 6], pred=divisible_by_three)
        self.assertEqual(result, 6)
class RandomProductTests(TestCase):
    """Tests for ``random_product()``

    Since random.choice() has different results with the same seed across
    python versions 2.x and 3.x, these tests use highly probable events to
    create predictable outcomes across platforms. On top of that, ``setUp``
    seeds the module-level random generator (restoring its previous state
    afterwards), so each run on a given interpreter is reproducible.
    """

    def setUp(self):
        # Save the global RNG state and restore it once the test finishes,
        # so seeding here does not leak into other tests.
        self.addCleanup(random.setstate, random.getstate())
        # Arbitrary fixed seed; it only needs to be stable between runs.
        random.seed(22)

    def test_simple_lists(self):
        """Ensure that one item is chosen from each list in each pair.
        Also ensure that each item from each list eventually appears in
        the chosen combinations.

        Odds are roughly 1 in 7 * 10^16 that one item from either list will
        not be chosen after 100 samplings of one item from each list. The
        seed fixed in ``setUp`` additionally makes the outcome reproducible.
        """
        nums = [1, 2, 3]
        lets = ['a', 'b', 'c']
        n, m = zip(*[mi.random_product(nums, lets) for _ in range(100)])
        # Set equality also pins the number of distinct values seen.
        self.assertEqual(set(n), set(nums))
        self.assertEqual(set(m), set(lets))

    def test_list_with_repeat(self):
        """ensure multiple items are chosen, and that they appear to be chosen
        from one list then the next, in proper order.
        """
        nums = [1, 2, 3]
        lets = ['a', 'b', 'c']
        r = list(mi.random_product(nums, lets, repeat=100))
        self.assertEqual(2 * 100, len(r))
        # Even positions must come from ``nums``, odd ones from ``lets``.
        self.assertEqual(set(r[::2]), set(nums))
        self.assertEqual(set(r[1::2]), set(lets))
class RandomPermutationTests(TestCase):
    """Tests for ``random_permutation()``"""

    def setUp(self):
        # Save the global RNG state and restore it once the test finishes,
        # so seeding here does not leak into other tests.
        # NOTE(review): assumes random_permutation() draws from the
        # module-level ``random`` generator; confirm in the recipe.
        self.addCleanup(random.setstate, random.getstate())
        random.seed(22)

    def test_full_permutation(self):
        """ensure every item from the iterable is returned in a new ordering

        15 elements have a 1 in 15! (about 1.3 * 10^12) chance of appearing
        in sorted order, so we fix a seed value (in ``setUp``) just to be sure.
        """
        items = range(15)
        permutation = mi.random_permutation(items)
        # Each item must appear exactly once. A plain set comparison would
        # not notice duplicates or a short result.
        self.assertEqual(sorted(permutation), list(items))
        # Compare as tuples: a ``range`` never compares equal to a tuple, so
        # comparing against ``items`` directly could never detect an
        # unshuffled result.
        self.assertNotEqual(
            tuple(permutation), tuple(items), "Values were not permuted"
        )

    def test_partial_permutation(self):
        """ensure all returned items are from the iterable, that the returned
        permutation is of the desired length, and that all items eventually
        get returned.

        Sampling 100 permutations of length 5 from a set of 15 leaves a
        (2/3)^100 chance that an item will not be chosen. Multiplied by 15
        items, there is a 1 in 2.6 * 10^16 chance that at least 1 item will
        not show up in the resulting output. The seed fixed in ``setUp``
        makes the outcome reproducible.
        """
        items = range(15)
        item_set = set(items)
        all_items = set()
        for _ in range(100):
            permutation = mi.random_permutation(items, 5)
            self.assertEqual(len(permutation), 5)
            permutation_set = set(permutation)
            self.assertLessEqual(permutation_set, item_set)
            all_items |= permutation_set
        self.assertEqual(all_items, item_set)
class RandomCombinationTests(TestCase):
    """Tests for ``random_combination()``"""

    def test_psuedorandomness(self):
        """ensure that, over many samplings, every element of the iterable
        eventually shows up in some random combination"""
        items = range(15)
        seen = set()
        for _ in range(50):
            seen.update(mi.random_combination(items, 5))
        self.assertEqual(seen, set(items))

    def test_no_replacement(self):
        """ensure elements are sampled without replacement, and that asking
        for more elements than exist raises ValueError"""
        items = range(15)
        size = len(items)
        for _ in range(50):
            combination = mi.random_combination(items, size)
            self.assertEqual(len(combination), len(set(combination)))
        with self.assertRaises(ValueError):
            mi.random_combination(items, size + 1)
class RandomCombinationWithReplacementTests(TestCase):
    """Tests for ``random_combination_with_replacement()``"""

    def test_replacement(self):
        """ensure elements are sampled with replacement: drawing twice as
        many elements as exist must succeed and must contain repeats"""
        items = range(5)
        wanted = 2 * len(items)
        combo = mi.random_combination_with_replacement(items, wanted)
        self.assertEqual(wanted, len(combo))
        has_duplicates = len(set(combo)) < len(combo)
        if not has_duplicates:
            raise AssertionError("Combination contained no duplicates")

    def test_pseudorandomness(self):
        """ensure that, over many samplings, every element of the iterable
        eventually shows up in some random combination"""
        items = range(15)
        seen = set()
        for _ in range(50):
            seen.update(mi.random_combination_with_replacement(items, 5))
        self.assertEqual(seen, set(items))
class NthCombinationTests(TestCase):
    """Tests for ``nth_combination()``"""

    def test_basic(self):
        """Every index agrees with the ordering of itertools.combinations"""
        pool, size = 'abcdefg', 4
        for position, combo in enumerate(combinations(pool, size)):
            self.assertEqual(mi.nth_combination(pool, size, position), combo)

    def test_long(self):
        """A large index into a large pool gives the expected combination"""
        result = mi.nth_combination(range(180), 4, 2000000)
        self.assertEqual(result, (2, 12, 35, 126))
class PrependTests(TestCase):
    """Tests for ``prepend()``"""

    def test_basic(self):
        """A single value is placed before the iterator's items"""
        result = list(mi.prepend('a', iter('bcdefg')))
        self.assertEqual(result, list('abcdefg'))

    def test_multiple(self):
        """A string value is prepended as one item, not split into chars"""
        result = tuple(mi.prepend('ab', iter('cdefg')))
        self.assertEqual(result, ('ab', 'c', 'd', 'e', 'f', 'g'))
|