Add assert_valid_state so I can stop copypasting this code.

This commit is contained in:
voussoir 2021-10-25 13:20:05 -07:00
parent 787cec38aa
commit d08415eaad
No known key found for this signature in database
GPG key ID: 5F7554F8C26DACCB
2 changed files with 9 additions and 6 deletions

View file

@ -215,8 +215,7 @@ class Channel(ObjectBase):
@worms.transaction @worms.transaction
def set_automark(self, state): def set_automark(self, state):
if state not in constants.VIDEO_STATES: self.ycdldb.assert_valid_state(state)
raise exceptions.InvalidVideoState(state)
pairs = { pairs = {
'id': self.id, 'id': self.id,
@ -339,8 +338,7 @@ class Video(ObjectBase):
Note: Marking as downloaded will not create the queue file, this only Note: Marking as downloaded will not create the queue file, this only
updates the database. See yclddb.download_video. updates the database. See yclddb.download_video.
''' '''
if state not in constants.VIDEO_STATES: self.ycdldb.assert_valid_state(state)
raise exceptions.InvalidVideoState(state)
log.info('Marking %s as %s.', self, state) log.info('Marking %s as %s.', self, state)

View file

@ -45,8 +45,7 @@ class YCDLDBChannelMixin:
except exceptions.NoSuchChannel: except exceptions.NoSuchChannel:
pass pass
if automark not in constants.VIDEO_STATES: self.assert_valid_state(automark)
raise exceptions.InvalidVideoState(automark)
name = objects.Channel.normalize_name(name) name = objects.Channel.normalize_name(name)
if name is None: if name is None:
@ -223,6 +222,7 @@ class YCDLDBVideoMixin:
bindings.append(channel_id) bindings.append(channel_id)
if state is not None: if state is not None:
self.assert_valid_state(state)
wheres.append('state') wheres.append('state')
bindings.append(state) bindings.append(state)
@ -471,6 +471,11 @@ class YCDLDB(
log.debug('Found closest YCDLDB at %s.', path) log.debug('Found closest YCDLDB at %s.', path)
return ycdldb return ycdldb
@staticmethod
def assert_valid_state(state):
if state not in constants.VIDEO_STATES:
raise exceptions.InvalidVideoState(state)
def get_all_states(self): def get_all_states(self):
''' '''
Get a list of all the different states that are currently in use in Get a list of all the different states that are currently in use in