# Upper bound on how many items CycleIteratorVariable may buffer from its
# input iterator; pulling past this limit is reported via ``unimplemented``.
MAX_CYCLE = 3_000

from typing import List, Optional

from ..exc import unimplemented

from .base import VariableTracker
from .constant import ConstantVariable


class IteratorVariable(VariableTracker):
    """Base class for trackers that model a Python iterator symbolically.

    Subclasses override :meth:`next_variables`, which receives the current
    translator ``tx`` and returns a ``(value, next_iterator)`` pair, or raises
    ``StopIteration`` when there is nothing left to produce.

    Construction keyword arguments go straight to :class:`VariableTracker`.
    """

    def next_variables(self, tx):
        """Advance the iterator by one step; subclasses must override this.

        The base implementation reports the call through ``unimplemented``
        (presumably raising so that tracing is abandoned -- confirm in ``..exc``).
        """
        unimplemented("abstract method, must implement")


class RepeatIteratorVariable(IteratorVariable):
    """Iterator that yields ``item`` on every step and never terminates.

    Args:
        item: the tracker handed out (as a fresh clone) on each step.
        **kwargs: forwarded to :class:`VariableTracker`.
    """

    def __init__(self, item: VariableTracker, **kwargs):
        super().__init__(**kwargs)
        self.item = item

    def next_variables(self, tx):
        """Return ``(clone of item, self)``.

        This iterator carries no position, so advancing it needs no mutation:
        ``self`` is returned unchanged and ``tx`` is not used. Only the yielded
        value is cloned.
        """
        return self.item.clone(), self


class CountIteratorVariable(IteratorVariable):
    """Infinite arithmetic-progression iterator: ``item, item + step, ...``.

    Plain Python values passed as ``item`` / ``step`` are wrapped with
    ``ConstantVariable.create``; arguments that are already trackers are
    stored unchanged.
    """

    def __init__(self, item: int = 0, step: int = 1, **kwargs):
        super().__init__(**kwargs)
        self.item = (
            item
            if isinstance(item, VariableTracker)
            else ConstantVariable.create(item)
        )
        self.step = (
            step
            if isinstance(step, VariableTracker)
            else ConstantVariable.create(step)
        )

    def next_variables(self, tx):
        """Return the current value and an iterator advanced by ``step``.

        The successor is built symbolically via ``__add__``, and ``tx`` is
        asked (through ``replace_all``) to use it in place of ``self``.
        """
        assert self.mutable_local
        current = self.item
        advanced = self.clone(
            item=current.call_method(tx, "__add__", [self.step], {})
        )
        tx.replace_all(self, advanced)
        return current, advanced


class CycleIteratorVariable(IteratorVariable):
    """Tracker for an ``itertools.cycle`` iterator.

    While the wrapped ``iterator`` still produces values, each one is appended
    to ``saved``. Once it raises ``StopIteration``, the tracker replays
    ``saved`` round-robin, starting from ``saved_index``.

    The tracker runs one element ahead: ``item`` holds a value that has
    already been pulled and will be returned by the *next* call to
    :meth:`next_variables` (``None`` before the first pull).

    Args:
        iterator: the input iterator, or ``None`` once it has been exhausted.
        saved: every value pulled from ``iterator`` so far (``None`` -> ``[]``).
        saved_index: index in ``saved`` of the value to buffer on the next
            replay step.
        item: the buffered look-ahead value described above.
        **kwargs: forwarded to :class:`VariableTracker`.
    """

    def __init__(
        self,
        iterator: Optional[IteratorVariable],
        saved: Optional[List[VariableTracker]] = None,
        saved_index: int = 0,
        item: Optional[VariableTracker] = None,
        **kwargs,
    ):
        if saved is None:
            saved = []
        super().__init__(**kwargs)
        self.iterator = iterator
        self.saved = saved
        self.saved_index = saved_index
        self.item = item

    def next_variables(self, tx):
        """Advance the cycle by one step and return ``(value, next_iterator)``.

        Raises:
            StopIteration: the input iterator is exhausted and ``saved`` is
                empty, so there is nothing to replay.
        """
        assert self.mutable_local

        if self.iterator is not None:
            # Only the pull from the input iterator signals exhaustion, so only
            # that call is guarded.
            try:
                new_item, next_inner_iter = self.iterator.next_variables(tx)
            except StopIteration:
                # Input exhausted: drop it and continue by replaying ``saved``.
                next_iter = self.clone(iterator=None)
                # this is redundant as next_iter will do the same
                # but we do it anyway for safety
                tx.replace_all(self, next_iter)
                return next_iter.next_variables(tx)

            tx.replace_all(self.iterator, next_inner_iter)
            if len(self.saved) > MAX_CYCLE:
                unimplemented(
                    "input iterator to itertools.cycle has too many items"
                )
            next_iter = self.clone(
                iterator=next_inner_iter,
                # Build a new list rather than appending: ``self`` must not
                # change, since ``next_iter`` is derived from it.
                saved=self.saved + [new_item],
                item=new_item,
            )
            tx.replace_all(self, next_iter)
            if self.item is None:
                # First pull: nothing was buffered yet, so step once more to
                # produce a value while keeping one element of look-ahead.
                return next_iter.next_variables(tx)
            return self.item, next_iter

        if not self.saved:
            raise StopIteration

        # Replay phase: buffer saved[saved_index] for the next call and emit
        # the value buffered by the previous call.
        next_iter = self.clone(
            saved_index=(self.saved_index + 1) % len(self.saved),
            item=self.saved[self.saved_index],
        )
        tx.replace_all(self, next_iter)
        return self.item, next_iter
