#!/usr/bin/env python3
'''
Script to automate disk SMART testing

Copyright (C) 2010 Canonical Ltd.

Authors
  Jeff Lane <jeffrey.lane@canonical.com>
  Brendan Donegan <brendan.donegan@canonical.com>

This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License version 2,
as published by the Free Software Foundation.

This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License
along with this program.  If not, see <http://www.gnu.org/licenses/>.

The purpose of this script is to simply interact with an onboard hard disk and
check for SMART capability and then do a little bit of interaction to make sure
we can at least do some limited interaction with the hard disk's SMART
functions.

In this case, we probe to see if SMART is available and enabled, then we run
the short self test.  Return 0 if it's all good, return 1 if it fails.

NOTE: This may not work correctly on systems where the onboard storage is
controlled by a hardware RAID controller, on external RAID systems, SAN, and
USB/eSATA/eSAS attached storage devices.

Changelog:

v1.1: Put delay before first attempt to access log, rather than after
v1.0: added debugger class and code to allow for verbose debug output if needed

v0.4: corrected some minor things
      added option parsing to allow for many disks, or disks other than
      "/dev/sda"

V0.3: Removed the arbitrary wait time and implemented a polling method
    to shorten the test time.
    Added in Pass/Fail criteria for the final outcome.
    Added in documentation.

V0.2: added minor debug routine

V0.1: Fixed some minor bugs and added the SmartEnabled() function

V0: First draft

'''

import os
import sys
import time
import logging

from subprocess import Popen, PIPE
from argparse import ArgumentParser


class ListHandler(logging.StreamHandler):
    """Stream handler that emits each item of a list/tuple message separately.

    A record whose ``msg`` is a list or tuple is expanded into one new record
    per item; each new record is formatted and written on its own.  Records
    with any other kind of message are emitted unchanged.
    """

    def emit(self, record):
        """Emit ``record``, fanning it out when its message is a sequence.

        ``bytes``/``bytearray`` items are decoded to text before being
        emitted.  The new records reuse the original record's name, level,
        location, args and exception info.
        """
        if not isinstance(record.msg, (list, tuple)):
            super().emit(record)
            return

        logger = logging.getLogger(record.name)
        for msg in record.msg:
            if isinstance(msg, (bytes, bytearray)):
                # This decode runs outside StreamHandler.emit's own error
                # handling, so a strict decode failure would propagate out of
                # the logging call and abort the caller.  Substitute
                # undecodable bytes instead of raising.
                msg = msg.decode(errors='replace')
            new_record = logger.makeRecord(
                record.name, record.levelno, record.pathname, record.lineno,
                msg, record.args, record.exc_info, record.funcName)
            super().emit(new_record)


def is_smart_enabled(disk):
    """Return True if smartctl reports SMART as available and enabled.

    Runs ``smartctl -i <disk>`` through the shell and inspects its standard
    output.  The "SMART support is" lines are located wherever they appear,
    rather than assuming they are the second- and third-to-last lines of the
    output, which breaks as soon as smartctl prints anything after them or
    omits the trailing blank line.

    Args:
        disk: device path handed to smartctl (e.g. ``/dev/sda``).

    Returns:
        True only if one "SMART support is" line contains ``Available`` and
        one contains ``Enabled``; False otherwise, including when smartctl
        produces no output at all.
    """
    command = 'smartctl -i %s' % disk
    diskinfo_bytes = (Popen(command, stdout=PIPE, shell=True)
                      .communicate()[0])
    # The output echoes device-supplied strings, which are not guaranteed to
    # be valid UTF-8; never let a stray byte crash the check.
    diskinfo = diskinfo_bytes.decode(errors='replace').splitlines()

    logging.debug('SMART Info for disk %s', disk)
    logging.debug(diskinfo)

    support_lines = [line for line in diskinfo if 'SMART support is' in line]
    return (any('Available' in line for line in support_lines)
            and any('Enabled' in line for line in support_lines))


def run_smart_test(disk, type='short'):
    """Ask smartctl to start a self-test on ``disk``.

    Runs ``smartctl -t <type> <disk>`` through the shell, captures both
    output streams as text and writes them (stderr first) to the debug log.

    Args:
        disk: device path handed to smartctl.
        type: self-test type passed to ``-t`` (default ``'short'``).

    Returns:
        The exit status of the smartctl command.
    """
    command_line = 'smartctl -t {} {}'.format(type, disk)
    logging.debug('Beginning test with %s', command_line)

    with Popen(command_line, shell=True, universal_newlines=True,
               stdout=PIPE, stderr=PIPE) as process:
        out_text, err_text = process.communicate()

    logging.debug(err_text + out_text)
    return process.returncode


def get_smart_entries(disk, type='selftest'):
    """Parse a smartctl log table for ``disk`` into a list of dicts.

    Runs ``smartctl -l <type> <disk>`` through the shell, skips everything up
    to and including the first line starting with ``SMART``, reads the
    fixed-width column offsets from the ``Num ...`` header line that follows,
    and converts every subsequent line starting with ``#`` into an entry.

    Args:
        disk: device path handed to smartctl.
        type: name of the smartctl log to read (default ``'selftest'``).

    Returns:
        A list of dicts, in the order smartctl printed the rows, with keys
        ``number`` (int), ``description``, ``status``, ``remaining``,
        ``lifetime`` (int) and ``lba``.  An empty list if the line after the
        ``SMART`` line is not a ``Num`` header.

    Raises:
        Exception: if the output contains no line starting with ``SMART``.
    """
    columns = ['number', 'description', 'status',
               'remaining', 'lifetime', 'lba']
    command = 'smartctl -l %s %s' % (type, disk)

    # Keep the Popen object alive in a context manager so the pipe is closed
    # and the child is waited for on every exit path.  This function is
    # called repeatedly while polling, and dropping the Popen object right
    # after taking its stdout would leave an unreaped child per call.
    with Popen(command, stdout=PIPE, shell=True) as proc:
        # Device-supplied text is not guaranteed to be valid UTF-8.
        lines = (raw.decode(errors='replace') for raw in proc.stdout)

        # Skip intro lines
        for line in lines:
            if line.startswith('SMART'):
                break
        else:
            raise Exception('Failed to parse SMART log entries')

        # Get column offsets from header
        header = next(lines, '')
        if not header.startswith('Num'):
            return []
        offsets = [header.index(title) for title in header.split()]
        # Start the 'remaining' column later than its header does, by the
        # difference in width between 'Remaining' and '100%' -- presumably
        # because the values are right-aligned under the header and the
        # status text can run underneath its first characters.
        offsets[columns.index('remaining')] += len('Remaining') - len('100%')
        offsets.append(len(header))

        # Get remaining lines
        entries = []
        for line in lines:
            if not line.startswith('#'):
                continue
            entry = {
                column: line[offsets[i]:offsets[i + 1]].strip()
                for i, column in enumerate(columns)
            }

            # Convert some columns to integers ('number' looks like '# 1')
            entry['number'] = int(entry['number'][1:])
            entry['lifetime'] = int(entry['lifetime'])
            entries.append(entry)

    return entries

def main():
    """Run a short SMART self-test on one disk and report the outcome.

    Parses the command line, configures logging, checks for root privileges,
    verifies that SMART is usable on the disk, starts a self-test and polls
    the self-test log until the newest entry changes and is no longer in
    progress, or until the optional timeout runs out.

    Returns:
        0 if the self-test completed without error, or if SMART is not
        available/enabled on the disk (nothing to test).  1 if smartctl
        reported an error when starting the test, if polling timed out, or
        if the test finished with any other status.
    """
    description = ('Tests that SMART capabilities on disks that support '
                   'SMART function.')
    parser = ArgumentParser(description=description)
    parser.add_argument('-b', '--block-dev',
                        metavar='DISK',
                        default='/dev/sda',
                        help=('the DISK to run this test against '
                              '[default: %(default)s]'))
    parser.add_argument('-d', '--debug',
                        action='store_true',
                        default=False,
                        help='prints some debug info')
    parser.add_argument('-s', '--sleep',
                        type=int,
                        default=5,
                        help=('number of seconds to sleep between checks '
                              '[default: %(default)s].'))
    parser.add_argument('-t', '--timeout',
                        type=int,
                        help='number of seconds to timeout from sleeping.')
    args = parser.parse_args()

    # Set logging
    log_format = '%(levelname)-8s %(message)s'
    handler = ListHandler()
    handler.setFormatter(logging.Formatter(log_format))
    logger = logging.getLogger()
    logger.addHandler(handler)
    logger.setLevel(logging.DEBUG if args.debug else logging.INFO)

    # Make sure we're root, because smartctl doesn't work otherwise.
    if os.geteuid() != 0:
        parser.error("You must be root to run this program")

    # If SMART is available and enabled, we proceed.  Otherwise, we exit as
    # the test is pointless in this case.
    disk = args.block_dev
    if not is_smart_enabled(disk):
        logging.warning('SMART not available on %s', disk)
        return 0

    # Snapshot the log before starting so a change can be detected later.
    previous_entries = get_smart_entries(disk)
    logging.info("Starting SMART self-test on %s", disk)
    if run_smart_test(disk) != 0:
        logging.error("Error reported during smartctl test")
        return 1

    if len(previous_entries) > 20:
        # Start the test again, which is meant to abort the previous
        # instance, and take a fresh baseline so that polling can identify
        # the difference.
        # NOTE(review): the 20-entry threshold is not explained in this
        # file -- presumably the log's capacity; confirm.
        run_smart_test(disk)
        previous_entries = get_smart_entries(disk)

    logging.debug('Polling selftest.log for status')

    # Count down on a local copy instead of mutating the parsed arguments.
    # The budget is only reduced by the sleep interval, after each check.
    remaining_timeout = args.timeout
    while True:
        # Poll every sleep seconds until test is complete.  The delay comes
        # first, before the first attempt to read the log.
        time.sleep(args.sleep)

        current_entries = get_smart_entries(disk)
        if current_entries:
            latest = current_entries[0]
            logging.debug('%s %s %s %s', latest['number'],
                          latest['description'], latest['status'],
                          latest['remaining'])
            if (current_entries != previous_entries
                    and latest['status'] != 'Self-test routine in progress'):
                break
        else:
            # The log can be empty (no table was parsed); keep polling
            # instead of failing with an IndexError on the first entry.
            logging.debug('No self-test log entries reported yet')

        if remaining_timeout is not None:
            if remaining_timeout <= 0:
                # This is a failure exit, so make it visible without --debug.
                logging.error('Polling timed out')
                return 1
            remaining_timeout -= args.sleep

    status = current_entries[0]['status']

    if status != 'Completed without error':
        logging.error("FAIL: SMART Self-Test appears to have failed for some "
                      "reason. Run 'sudo smartctl -l selftest %s' to see the "
                      "SMART log", disk)
        logging.debug("Last self-test run status: %s", status)
        return 1

    logging.info("PASS: SMART Self-Test completed without error")
    return 0


# Script entry point: main()'s return value becomes the process exit status.
if __name__ == '__main__':
    raise SystemExit(main())
