/* Copyright (C) 2014 Open Information Security Foundation
 *
 * You can copy, redistribute or modify this Program under the terms of
 * the GNU General Public License version 2 as published by the Free
 * Software Foundation.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * version 2 along with this program; if not, write to the Free Software
 * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
 * 02110-1301, USA.
 */

/**
 * \file
 *
 * \author Roliers Jean-Paul <popof.fpn@gmail.co>
 * \author Eric Leblond <eric@regit.org>
 * \author Victor Julien <victor@inliniac.net>
 *
 * Implements TLS store portion of the engine.
 *
 */

#include "suricata-common.h"
#include "debug.h"
#include "detect.h"
#include "pkt-var.h"
#include "conf.h"

#include "threads.h"
#include "threadvars.h"
#include "tm-threads.h"

#include "util-print.h"
#include "util-unittest.h"

#include "util-debug.h"

#include "output.h"
#include "log-tlslog.h"
#include "app-layer-ssl.h"
#include "app-layer.h"
#include "app-layer-parser.h"
#include "util-privs.h"
#include "util-buffer.h"

#include "util-logopenfile.h"
#include "util-crypt.h"
#include "util-time.h"

/** Name under which this module registers itself
 *  (see TmModuleLogTlsStoreRegister()). */
#define MODULE_NAME "LogTlsStoreLog"

/** Directory the .pem/.meta files are created in. Starts out as "/tmp" and
 *  is overwritten by LogTlsStoreLogInitCtx(). */
static char tls_logfile_base_dir[PATH_MAX] = "/tmp";
/** Incrementing id that CreateFileName() puts into every generated name. */
SC_ATOMIC_DECLARE(unsigned int, cert_id);
/** Number of file open/write warnings LogTlsLogPem() has emitted so far. */
static char logging_dir_not_writable;

/** Once logging_dir_not_writable reaches this value, LogTlsLogPem() stops
 *  emitting its file open/write warnings. */
#define LOGGING_WRITE_ISSUE_LIMIT 6

/**
 * \brief Per-thread data of the TLS store logger.
 *
 * Allocated and zeroed by LogTlsStoreLogThreadInit().
 */
typedef struct LogTlsStoreLogThread_ {
    /** value reported as "certificates extracted" when the thread exits.
     *  NOTE(review): nothing in this file increments it -- confirm whether
     *  it is updated elsewhere. */
    uint32_t tls_cnt;

    /** scratch buffer LogTlsLogPem() base64-encodes certificates into;
     *  grown on demand */
    uint8_t *enc_buf;
    /** current size of enc_buf in bytes */
    size_t enc_buf_len;
} LogTlsStoreLogThread;

/**
 * \internal
 * \brief Build the path of the PEM file for a certificate chain.
 *
 * The name has the form "<base dir>/<sec>.<usec>-<id>.pem", built from the
 * packet timestamp and an incrementing id.
 * When replaying the same pcap the names repeat, so files are overwritten;
 * on a live device the timestamps differ, so nothing gets overwritten.
 *
 * \param p        packet whose timestamp goes into the name
 * \param state    TLS state (not used at the moment)
 * \param filename output buffer; the size check below assumes it holds
 *                 at least PATH_MAX bytes
 *
 * \retval 1 filename holds the generated path
 * \retval 0 the path does not fit; filename is set to the empty string
 */
static int CreateFileName(const Packet *p, SSLState *state, char *filename)
{
#define FILELEN 64  /* filename len + extension + ending path / + some space */

    size_t filenamelen = FILELEN + strlen(tls_logfile_base_dir);

    if (filenamelen + 1 > PATH_MAX) {
        /* make the failure visible even if the caller did not clear the buffer */
        filename[0] = '\0';
        return 0;
    }

    /* take an id only once we know a name is going to be produced */
    int file_id = SC_ATOMIC_ADD(cert_id, 1);

    int written = snprintf(filename, filenamelen, "%s/%ld.%ld-%d.pem",
                           tls_logfile_base_dir,
                           (long int)p->ts.tv_sec,
                           (long int)p->ts.tv_usec,
                           file_id);
    if (written < 0 || (size_t)written >= filenamelen) {
        /* never hand out a truncated path: it would lack the ".pem" suffix */
        filename[0] = '\0';
        return 0;
    }
    return 1;
}

/**
 * \internal
 * \brief Store the server certificate chain of a TLS session on disk.
 *
 * Every certificate of the chain is base64 encoded and written, wrapped in
 * PEM header/footer lines and split into lines of at most 64 characters, to
 * a ".pem" file. A ".meta" file with the same base name then receives the
 * packet time, addresses, protocol, ports and the subject, issuer and
 * fingerprint kept in the TLS state.
 *
 * The SSL_TLS_LOG_PEM flag of the state is cleared only when both files
 * were written; on any failure it stays set.
 *
 * File open/write warnings are limited to LOGGING_WRITE_ISSUE_LIMIT through
 * logging_dir_not_writable.
 *
 * \param aft     thread data providing the reusable base64 buffer
 * \param p       packet supplying the timestamp and address information
 * \param state   TLS state holding the certificate chain
 * \param ipproto address family handed to TLSGetIPInformations()
 */
static void LogTlsLogPem(LogTlsStoreLogThread *aft, const Packet *p, SSLState *state, int ipproto)
{
#define PEMHEADER "-----BEGIN CERTIFICATE-----\n"
#define PEMFOOTER "-----END CERTIFICATE-----\n"
    /* zero filled, so an untouched buffer reads as the empty string */
    char filename[PATH_MAX] = "";
    FILE *fp = NULL;
    FILE *fpmeta = NULL;
    unsigned long pemlen;
    unsigned char *pembase64ptr = NULL;
    int ret;
    uint8_t *ptmp;
    SSLCertsChain *cert;

    if ((state->server_connp.cert_input == NULL) || (state->server_connp.cert_input_len == 0))
        SCReturn;

    if (!CreateFileName(p, state, filename) || strlen(filename) == 0) {
        SCLogWarning(SC_ERR_FOPEN, "Can't create PEM filename");
        SCReturn;
    }

    fp = fopen(filename, "w");
    if (fp == NULL) {
        if (logging_dir_not_writable < LOGGING_WRITE_ISSUE_LIMIT) {
            SCLogWarning(SC_ERR_FOPEN,
                         "Can't create PEM file '%s' in '%s' directory",
                         filename, tls_logfile_base_dir);
            logging_dir_not_writable++;
        }
        SCReturn;
    }

    /* write one PEM section per certificate of the chain */
    TAILQ_FOREACH(cert, &state->server_connp.certs, next) {
        /* upper bound for the base64 text of this certificate */
        pemlen = (4 * (cert->cert_len + 2) / 3) +1;
        if (pemlen > aft->enc_buf_len) {
            ptmp = (uint8_t*) SCRealloc(aft->enc_buf, sizeof(uint8_t) * pemlen);
            if (ptmp == NULL) {
                SCFree(aft->enc_buf);
                aft->enc_buf = NULL;
                /* keep the recorded size in sync with the released buffer,
                 * otherwise a later call would skip the allocation and use
                 * the NULL pointer */
                aft->enc_buf_len = 0;
                SCLogWarning(SC_ERR_MEM_ALLOC, "Can't allocate data for base64 encoding");
                goto end_fp;
            }
            aft->enc_buf = ptmp;
            aft->enc_buf_len = pemlen;
        }

        memset(aft->enc_buf, 0, aft->enc_buf_len);

        /* pemlen goes in as the available size and is read back below as
         * the number of bytes to write */
        ret = Base64Encode((unsigned char*) cert->cert_data, cert->cert_len, aft->enc_buf, &pemlen);
        if (ret != SC_BASE64_OK) {
            SCLogWarning(SC_ERR_INVALID_ARGUMENTS, "Invalid return of Base64Encode function");
            goto end_fwrite_fp;
        }

        if (fprintf(fp, PEMHEADER) < 0)
            goto end_fwrite_fp;

        /* emit the base64 text in lines of at most 64 characters */
        pembase64ptr = aft->enc_buf;
        while (pemlen > 0) {
            size_t loffset = pemlen >= 64 ? 64 : pemlen;
            if (fwrite(pembase64ptr, 1, loffset, fp) != loffset)
                goto end_fwrite_fp;
            if (fwrite("\n", 1, 1, fp) != 1)
                goto end_fwrite_fp;
            /* advance by what was written so the pointer never leaves
             * the buffer on the last, shorter line */
            pembase64ptr += loffset;
            pemlen -= loffset;
        }

        if (fprintf(fp, PEMFOOTER) < 0)
            goto end_fwrite_fp;
    }

    /* buffered data may only reach the disk on close, so a failing close
     * means the certificate was not written */
    if (fclose(fp) != 0) {
        if (logging_dir_not_writable < LOGGING_WRITE_ISSUE_LIMIT) {
            SCLogWarning(SC_ERR_FWRITE, "Unable to write certificate");
            logging_dir_not_writable++;
        }
        SCReturn;
    }

    /* Log the certificate information: turn "<name>.pem" into "<name>.meta".
     * The terminating NUL is copied as well, so this does not depend on the
     * buffer having been zero filled. The name grows by one character;
     * CreateFileName() keeps its output short enough for that to fit. */
    memcpy(filename + (strlen(filename) - 3), "meta", sizeof("meta"));
    fpmeta = fopen(filename, "w");
    if (fpmeta != NULL) {
        #define PRINT_BUF_LEN 46
        char srcip[PRINT_BUF_LEN], dstip[PRINT_BUF_LEN];
        char timebuf[64];
        Port sp, dp;
        CreateTimeString(&p->ts, timebuf, sizeof(timebuf));
        if (!TLSGetIPInformations(p, srcip, PRINT_BUF_LEN, &sp, dstip, PRINT_BUF_LEN, &dp, ipproto))
            goto end_fwrite_fpmeta;
        if (fprintf(fpmeta, "TIME:              %s\n", timebuf) < 0)
            goto end_fwrite_fpmeta;
        if (p->pcap_cnt > 0) {
            if (fprintf(fpmeta, "PCAP PKT NUM:      %"PRIu64"\n", p->pcap_cnt) < 0)
                goto end_fwrite_fpmeta;
        }
        if (fprintf(fpmeta, "SRC IP:            %s\n", srcip) < 0)
            goto end_fwrite_fpmeta;
        if (fprintf(fpmeta, "DST IP:            %s\n", dstip) < 0)
            goto end_fwrite_fpmeta;
        if (fprintf(fpmeta, "PROTO:             %" PRIu32 "\n", p->proto) < 0)
            goto end_fwrite_fpmeta;
        if (PKT_IS_TCP(p) || PKT_IS_UDP(p)) {
            if (fprintf(fpmeta, "SRC PORT:          %" PRIu16 "\n", sp) < 0)
                goto end_fwrite_fpmeta;
            if (fprintf(fpmeta, "DST PORT:          %" PRIu16 "\n", dp) < 0)
                goto end_fwrite_fpmeta;
        }

        if (fprintf(fpmeta, "TLS SUBJECT:       %s\n"
                    "TLS ISSUERDN:      %s\n"
                    "TLS FINGERPRINT:   %s\n",
                state->server_connp.cert0_subject,
                state->server_connp.cert0_issuerdn,
                state->server_connp.cert0_fingerprint) < 0)
            goto end_fwrite_fpmeta;

        fclose(fpmeta);
    } else {
        if (logging_dir_not_writable < LOGGING_WRITE_ISSUE_LIMIT) {
            SCLogWarning(SC_ERR_FOPEN,
                         "Can't create meta file '%s' in '%s' directory",
                         filename, tls_logfile_base_dir);
            logging_dir_not_writable++;
        }
        SCReturn;
    }

    /* Reset the store flag */
    state->server_connp.cert_log_flag &= ~SSL_TLS_LOG_PEM;
    SCReturn;

end_fwrite_fp:
    /* PEM file failure; fpmeta is still NULL here, so the block below
     * is skipped */
    fclose(fp);
    if (logging_dir_not_writable < LOGGING_WRITE_ISSUE_LIMIT) {
        SCLogWarning(SC_ERR_FWRITE, "Unable to write certificate");
        logging_dir_not_writable++;
    }
end_fwrite_fpmeta:
    if (fpmeta) {
        fclose(fpmeta);
        if (logging_dir_not_writable < LOGGING_WRITE_ISSUE_LIMIT) {
            SCLogWarning(SC_ERR_FWRITE, "Unable to write certificate metafile");
            logging_dir_not_writable++;
        }
    }
    SCReturn;
end_fp:
    /* allocation failure: already reported, only the file needs closing */
    fclose(fp);
    SCReturn;
}

/** \internal
 *  \brief Condition function for the TLS store logger
 *
 *  A packet qualifies when it is TCP, belongs to a flow whose app protocol
 *  is TLS, the flow has a TLS state with both issuer and subject of the
 *  first certificate set, and either a PEM store was requested or the state
 *  has not been marked as stored yet.
 *
 *  \param tv thread vars (unused)
 *  \param p  packet to evaluate
 *  \retval bool true or false -- log now?
 */
static int LogTlsStoreCondition(ThreadVars *tv, const Packet *p)
{
    int log_now = FALSE;

    if (p->flow == NULL || !(PKT_IS_TCP(p))) {
        return FALSE;
    }

    FLOWLOCK_RDLOCK(p->flow);

    uint16_t alproto = FlowGetAppProtocol(p->flow);
    if (alproto == ALPROTO_TLS) {
        SSLState *tls = (SSLState *)FlowGetAppState(p->flow);
        if (tls == NULL) {
            SCLogDebug("no tls state, so no request logging");
        } else {
            int pem_requested =
                (tls->server_connp.cert_log_flag & SSL_TLS_LOG_PEM) != 0;
            int already_stored =
                (tls->flags & SSL_AL_FLAG_STATE_STORED) != 0;
            int have_names =
                (tls->server_connp.cert0_issuerdn != NULL) &&
                (tls->server_connp.cert0_subject != NULL);

            /* the state is handled only once, unless the tls.store
             * keyword asked for the cert to be written */
            if ((pem_requested || !already_stored) && have_names) {
                log_now = TRUE;
            }
        }
    }

    FLOWLOCK_UNLOCK(p->flow);
    return log_now;
}

/** \internal
 *  \brief Packet logger callback of the TLS store module
 *
 *  With the flow write-locked: if the flow is TLS and has a TLS state, the
 *  certificate chain is written through LogTlsLogPem() when
 *  SSL_TLS_LOG_PEM is set, and the state is flagged
 *  SSL_AL_FLAG_STATE_STORED.
 *
 *  \param tv          thread vars (unused)
 *  \param thread_data the thread's LogTlsStoreLogThread
 *  \param p           packet being logged; p->flow is dereferenced
 *                     without a NULL check
 *  \retval 0 always
 */
static int LogTlsStoreLogger(ThreadVars *tv, void *thread_data, const Packet *p)
{
    LogTlsStoreLogThread *store_td = thread_data;
    int ipproto;

    if (PKT_IS_IPV4(p)) {
        ipproto = AF_INET;
    } else {
        ipproto = AF_INET6;
    }

    FLOWLOCK_WRLOCK(p->flow);

    /* check if we have TLS state or not */
    uint16_t alproto = FlowGetAppProtocol(p->flow);
    if (alproto == ALPROTO_TLS) {
        SSLState *tls = (SSLState *)FlowGetAppState(p->flow);
        if (!unlikely(tls == NULL)) {
            if (tls->server_connp.cert_log_flag & SSL_TLS_LOG_PEM) {
                LogTlsLogPem(store_td, p, tls, ipproto);
            }

            /* we only store the state once */
            tls->flags |= SSL_AL_FLAG_STATE_STORED;
        }
    }

    FLOWLOCK_UNLOCK(p->flow);
    return 0;
}

/** \internal
 *  \brief Thread init of the TLS store logger
 *
 *  Allocates the zeroed per-thread data and, when stat() fails on
 *  tls_logfile_base_dir, tries to create that directory. A mkdir() failure
 *  other than EEXIST terminates the process.
 *
 *  \param t        thread vars (unused)
 *  \param initdata output context; must not be NULL
 *  \param data     receives the new LogTlsStoreLogThread on success
 *
 *  \retval TM_ECODE_OK     on success
 *  \retval TM_ECODE_FAILED on allocation failure or NULL initdata
 */
static TmEcode LogTlsStoreLogThreadInit(ThreadVars *t, void *initdata, void **data)
{
    LogTlsStoreLogThread *td = SCMalloc(sizeof(LogTlsStoreLogThread));
    if (unlikely(td == NULL))
        return TM_ECODE_FAILED;
    memset(td, 0, sizeof(LogTlsStoreLogThread));

    if (initdata == NULL) {
        SCLogDebug("Error getting context for LogTlsStore. \"initdata\" argument NULL");
        SCFree(td);
        return TM_ECODE_FAILED;
    }

    struct stat dir_stat;
    if (stat(tls_logfile_base_dir, &dir_stat) != 0) {
        /* stat failed: try to create the drop directory */
        if (mkdir(tls_logfile_base_dir, S_IRWXU|S_IXGRP|S_IRGRP) == 0) {
            SCLogInfo("Created certs drop directory %s",
                    tls_logfile_base_dir);
        } else {
            /* capture errno before any other call can change it */
            int mkdir_errno = errno;
            /* EEXIST is not treated as an error */
            if (mkdir_errno != EEXIST) {
                SCLogError(SC_ERR_LOGDIR_CONFIG,
                        "Cannot create certs drop directory %s: %s",
                        tls_logfile_base_dir, strerror(mkdir_errno));
                exit(EXIT_FAILURE);
            }
        }
    }

    *data = (void *)td;
    return TM_ECODE_OK;
}

/** \internal
 *  \brief Thread deinit of the TLS store logger
 *
 *  Releases the base64 scratch buffer and the per-thread data itself.
 *
 *  \param t    thread vars (unused)
 *  \param data the thread's LogTlsStoreLogThread, may be NULL
 *
 *  \retval TM_ECODE_OK always
 */
static TmEcode LogTlsStoreLogThreadDeinit(ThreadVars *t, void *data)
{
    LogTlsStoreLogThread *aft = (LogTlsStoreLogThread *)data;
    if (aft == NULL) {
        return TM_ECODE_OK;
    }

    /* the scratch buffer grown in LogTlsLogPem() is owned by this struct;
     * free it before the memset below wipes the only pointer to it */
    if (aft->enc_buf != NULL) {
        SCFree(aft->enc_buf);
    }

    /* clear memory */
    memset(aft, 0, sizeof(LogTlsStoreLogThread));

    SCFree(aft);
    return TM_ECODE_OK;
}

/** \internal
 *  \brief Print the per-thread counter of the TLS store logger
 *
 *  \param tv   thread vars; the thread name is part of the message
 *  \param data the thread's LogTlsStoreLogThread; nothing is printed
 *              when NULL
 */
static void LogTlsStoreLogExitPrintStats(ThreadVars *tv, void *data)
{
    const LogTlsStoreLogThread *store_td = data;

    if (store_td != NULL) {
        SCLogInfo("(%s) certificates extracted %" PRIu32 "", tv->name, store_td->tls_cnt);
    }
}

/**
 *  \internal
 *
 *  \brief Deinit the log ctx.
 *
 *  Frees the OutputCtx itself; nothing else is released or written here.
 *
 *  \param output_ctx output context to deinit
 */
static void LogTlsStoreLogDeInitCtx(OutputCtx *output_ctx)
{
    SCFree(output_ctx);
}

/** \brief Create the output context of the TLS store logger.
 *
 *  Also selects the directory the certificates are stored in and saves it
 *  in tls_logfile_base_dir:
 *   - "certs-log-dir" missing or empty: the default log directory
 *   - "certs-log-dir" absolute: that path
 *   - "certs-log-dir" relative: that path below the default log directory
 *
 *  \param conf Pointer to ConfNode containing this logger's configuration.
 *  \return NULL if the allocation fails, the new OutputCtx if successful
 * */
static OutputCtx *LogTlsStoreLogInitCtx(ConfNode *conf)
{
    OutputCtx *ctx = SCCalloc(1, sizeof(OutputCtx));
    if (unlikely(ctx == NULL))
        return NULL;

    ctx->data = NULL;
    ctx->DeInit = LogTlsStoreLogDeInitCtx;

    /* FIXME we need to implement backward compatibility here */
    char *default_dir = ConfigGetLogDirectory();
    const char *configured_dir = ConfNodeLookupChildValue(conf, "certs-log-dir");

    if (configured_dir == NULL || strlen(configured_dir) == 0) {
        /* nothing configured: store next to the other logs */
        strlcpy(tls_logfile_base_dir,
                default_dir, sizeof(tls_logfile_base_dir));
    } else if (PathIsAbsolute(configured_dir)) {
        strlcpy(tls_logfile_base_dir,
                configured_dir, sizeof(tls_logfile_base_dir));
    } else {
        /* relative paths are taken relative to the default log directory */
        snprintf(tls_logfile_base_dir, sizeof(tls_logfile_base_dir),
                "%s/%s", default_dir, configured_dir);
    }

    SCLogInfo("storing certs in %s", tls_logfile_base_dir);

    SCReturnPtr(ctx, "OutputCtx");
}

/**
 *  \brief Register the TLS store logger.
 *
 *  Fills the TMM_TLSSTORE slot of tmm_modules, registers the packet logger
 *  and its condition function under the "tls-store" name, and initializes
 *  the cert_id counter.
 */
void TmModuleLogTlsStoreRegister (void)
{
    /* identity and attributes */
    tmm_modules[TMM_TLSSTORE].name = MODULE_NAME;
    tmm_modules[TMM_TLSSTORE].flags = TM_FLAG_LOGAPI_TM;
    tmm_modules[TMM_TLSSTORE].cap_flags = 0;
    tmm_modules[TMM_TLSSTORE].priority = 10;

    /* thread lifecycle callbacks */
    tmm_modules[TMM_TLSSTORE].ThreadInit = LogTlsStoreLogThreadInit;
    tmm_modules[TMM_TLSSTORE].ThreadDeinit = LogTlsStoreLogThreadDeinit;
    tmm_modules[TMM_TLSSTORE].ThreadExitPrintStats = LogTlsStoreLogExitPrintStats;

    /* callbacks this module does not provide */
    tmm_modules[TMM_TLSSTORE].Func = NULL;
    tmm_modules[TMM_TLSSTORE].RegisterTests = NULL;

    OutputRegisterPacketModule(MODULE_NAME, "tls-store", LogTlsStoreLogInitCtx,
            LogTlsStoreLogger, LogTlsStoreCondition);

    SC_ATOMIC_INIT(cert_id);

    SCLogDebug("registered");
}
