diff --git a/dav_events/forms.py b/dav_events/forms.py index b1fc8ba..f0027c7 100644 --- a/dav_events/forms.py +++ b/dav_events/forms.py @@ -76,6 +76,7 @@ class SetPasswordForm(forms.Form): class ChainedForm(forms.Form): + _initial_form_name = None _next_form_name = None def __init__(self, *args, **kwargs): @@ -126,21 +127,33 @@ class ChainedForm(forms.Form): @property def form_title(self): - if hasattr(self, '_form_title'): - return self._form_title - n = self.form_name - if n.endswith('Form'): - n = n[:-len('Form')] - return n + return self.__class__.get_form_title() + + @property + def initial_form_name(self): + return self.__class__.get_initial_form_name() @property def next_form_name(self): - return self._next_form_name + return self.__class__.get_next_form_name() @classmethod def get_form_name(cls): return cls.__name__ + @classmethod + def get_form_title(cls): + if hasattr(cls, '_form_title'): + return cls._form_title + n = cls.get_form_name() + if n.endswith('Form'): + n = n[:-len('Form')] + return n + + @classmethod + def get_initial_form_name(cls): + return cls._initial_form_name or cls.__name__ + @classmethod def get_next_form_name(cls): return cls._next_form_name @@ -172,6 +185,7 @@ class ChainedForm(forms.Form): class EventCreateForm(ChainedForm): _model = models.Event + _initial_form_name = 'ModeForm' class ModeForm(EventCreateForm): diff --git a/dav_events/views.py b/dav_events/views.py index f5b17d6..fef9263 100644 --- a/dav_events/views.py +++ b/dav_events/views.py @@ -137,7 +137,7 @@ class EventAcceptView(EventDetailView): class EventCreateView(generic.FormView): - initial_form_class = forms.ModeForm + form_class = forms.EventCreateForm template_dir = os.path.join('dav_events', 'event_create') default_template_name = 'default.html' abort_url = reverse_lazy('dav_events:home') @@ -160,7 +160,9 @@ class EventCreateView(generic.FormView): if not issubclass(form_class, forms.ChainedForm): raise SuspiciousOperation('Invalid next form: {}'.format(form_name)) else: - form_class = self.initial_form_class + base_form_class = self.form_class + initial_form_name = base_form_class.get_initial_form_name() + form_class = getattr(forms, initial_form_name) return form_class