// Kernel.cpp (Oclgrind)
// Copyright (c) 2013-2015, James Price and Simon McIntosh-Smith,
// University of Bristol. All rights reserved.
//
// This program is provided under a three-clause BSD license. For full
// license terms please see the LICENSE file distributed with this
// source code.

#include "common.h"
#include <sstream>

#include "llvm/IR/Constants.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Module.h"
#include "llvm/Support/raw_os_ostream.h"

#include "Kernel.h"
#include "Program.h"
#include "Memory.h"

using namespace oclgrind;
using namespace std;

Kernel::Kernel(const Program *program,
               const llvm::Function *function, const llvm::Module *module)
 : m_program(program), m_function(function), m_name(function->getName())
{
  m_localMemory = new Memory(AddrSpaceLocal, program->getContext());
  m_privateMemory = new Memory(AddrSpacePrivate, program->getContext());

  // Set-up module-scope variables according to their address space
  llvm::Module::const_global_iterator itr;
  for (itr = module->global_begin(); itr != module->global_end(); itr++)
  {
    llvm::PointerType *type = itr->getType();
    switch (type->getPointerAddressSpace())
    {
    case AddrSpacePrivate:
    {
      const llvm::Constant *init = itr->getInitializer();

      // Allocate private memory for variable
      unsigned size = getTypeSize(init->getType());
      size_t address = m_privateMemory->allocateBuffer(size);

      // Initialize variable
      void *ptr = m_privateMemory->getPointer(address);
      getConstantData(static_cast<unsigned char*>(ptr), init);

      // The variable's value is the address of its private buffer
      TypedValue value =
      {
        sizeof(size_t),
        1,
        new unsigned char[sizeof(size_t)]
      };
      value.setPointer(address);
      m_arguments[itr] = value;

      break;
    }
    case AddrSpaceConstant:
      // Buffers for these are created by allocateConstants()
      m_constants.push_back(itr);
      break;
    case AddrSpaceLocal:
    {
      // Size the buffer from the variable's pointee type rather than from
      // its initializer, so a variable without an initializer is handled
      unsigned size = getTypeSize(type->getPointerElementType());
      TypedValue v = {
        sizeof(size_t),
        1,
        new unsigned char[sizeof(size_t)]
      };
      v.setPointer(m_localMemory->allocateBuffer(size));
      m_arguments[itr] = v;

      break;
    }
    default:
      FATAL_ERROR("Unsupported GlobalVariable address space: %d",
                  type->getPointerAddressSpace());
    }
  }

  // Find this kernel's node in the "opencl.kernels" named metadata: the
  // node whose first operand refers to a function with this kernel's name
  m_metadata = NULL;
  llvm::NamedMDNode *md = module->getNamedMetadata("opencl.kernels");
  if (md)
  {
    for (unsigned i = 0; i < md->getNumOperands(); i++)
    {
      llvm::MDNode *node = md->getOperand(i);
      if (!node || node->getNumOperands() == 0)
        continue;

      llvm::ConstantAsMetadata *cam =
        llvm::dyn_cast_or_null<llvm::ConstantAsMetadata>(
          node->getOperand(0).get());
      if (!cam)
        continue;

      // Skip entries whose constant is not a Function instead of
      // reinterpreting an arbitrary constant as one
      const llvm::Function *kernelFunction =
        llvm::dyn_cast<llvm::Function>(cam->getValue());
      if (kernelFunction && kernelFunction->getName() == m_name)
      {
        m_metadata = node;
        break;
      }
    }
  }
}

Kernel::Kernel(const Kernel& kernel)
 : m_program(kernel.m_program)
{
  // LLVM objects, names and metadata are shared with the source kernel
  m_function = kernel.m_function;
  m_constants = kernel.m_constants;
  m_constantBuffers = kernel.m_constantBuffers;

  // The memory pools get their own copies via Memory::clone()
  m_localMemory = kernel.m_localMemory->clone();
  m_privateMemory = kernel.m_privateMemory->clone();

  m_name = kernel.m_name;
  m_metadata = kernel.m_metadata;

  // Argument values own heap storage (freed in the destructor), so each one
  // is cloned rather than copied shallowly
  TypedValueMap::const_iterator source = kernel.m_arguments.begin();
  while (source != kernel.m_arguments.end())
  {
    m_arguments[source->first] = source->second.clone();
    ++source;
  }
}

Kernel::~Kernel()
{
  // Release the per-kernel local and private memory pools
  delete m_localMemory;
  delete m_privateMemory;

  // Free the data buffer owned by every stored argument/variable value
  for (TypedValueMap::iterator entry = m_arguments.begin();
       entry != m_arguments.end(); ++entry)
  {
    delete[] entry->second.data;
  }
}

bool Kernel::allArgumentsSet() const
{
  llvm::Function::const_arg_iterator itr;
  for (itr = m_function->arg_begin(); itr != m_function->arg_end(); itr++)
  {
    if (!m_arguments.count(itr))
    {
      return false;
    }
  }
  return true;
}

// Allocates and initialises a buffer in `memory` for each __constant
// module-scope variable, and records the buffer address as that variable's
// value. Buffer addresses are remembered so deallocateConstants() can free
// them.
void Kernel::allocateConstants(Memory *memory)
{
  list<const llvm::GlobalVariable*>::const_iterator itr;
  for (itr = m_constants.begin(); itr != m_constants.end(); itr++)
  {
    const llvm::Constant *initializer = (*itr)->getInitializer();
    const llvm::Type *type = initializer->getType();

    // Allocate buffer
    unsigned size = getTypeSize(type);
    size_t address = memory->allocateBuffer(size);
    m_constantBuffers.push_back(address);

    // deallocateConstants() frees the buffers but keeps these map entries,
    // so release the pointer storage from a previous call before
    // overwriting it (otherwise it leaks on every re-allocation)
    TypedValueMap::iterator existing = m_arguments.find(*itr);
    if (existing != m_arguments.end())
    {
      delete[] existing->second.data;
    }

    // Record the buffer address as the variable's value
    TypedValue v = {
      sizeof(size_t),
      1,
      new unsigned char[sizeof(size_t)]
    };
    v.setPointer(address);
    m_arguments[*itr] = v;

    // Initialise buffer contents
    unsigned char *data = new unsigned char[size];
    getConstantData(data, initializer);
    memory->store(data, address, size);
    delete[] data;
  }
}

void Kernel::deallocateConstants(Memory *memory)
{
  // Hand every buffer recorded by allocateConstants() back to `memory`
  for (list<size_t>::const_iterator buffer = m_constantBuffers.begin();
       buffer != m_constantBuffers.end(); ++buffer)
  {
    const size_t address = *buffer;
    memory->deallocateBuffer(address);
  }

  // Nothing is outstanding any more
  m_constantBuffers.clear();
}

const llvm::Argument* Kernel::getArgument(unsigned int index) const
{
  assert(index < getNumArguments());

  // Step through the function's argument list to position `index`
  llvm::Function::const_arg_iterator arg = m_function->arg_begin();
  for (unsigned int remaining = index; remaining > 0; --remaining)
  {
    ++arg;
  }
  return arg;
}

// Returns the CL_KERNEL_ARG_ACCESS_* value for argument `index`, read from
// the "kernel_arg_access_qual" metadata node. Returns -1 (as unsigned) when
// that node is missing or the entry for this argument is not a string.
unsigned int Kernel::getArgumentAccessQualifier(unsigned int index) const
{
  assert(index < getNumArguments());

  // Get metadata node
  const llvm::MDNode *node = getArgumentMetadata("kernel_arg_access_qual");
  if (!node)
  {
    return -1;
  }

  // Get qualifier string (operand 0 is the node's name, so argument i
  // lives at operand i+1)
  llvm::MDString *str =
    llvm::dyn_cast_or_null<llvm::MDString>(node->getOperand(index+1).get());
  if (!str)
  {
    return -1;
  }

  // Compare the StringRef directly; no std::string copy needed
  llvm::StringRef access = str->getString();
  if (access == "read_only")
  {
    return CL_KERNEL_ARG_ACCESS_READ_ONLY;
  }
  else if (access == "write_only")
  {
    return CL_KERNEL_ARG_ACCESS_WRITE_ONLY;
  }
  else if (access == "read_write")
  {
    return CL_KERNEL_ARG_ACCESS_READ_WRITE;
  }
  return CL_KERNEL_ARG_ACCESS_NONE;
}

unsigned int Kernel::getArgumentAddressQualifier(unsigned int index) const
{
  assert(index < getNumArguments());

  // Get metadata node
  const llvm::MDNode *node = getArgumentMetadata("kernel_arg_addr_space");
  if (!node)
  {
    return -1;
  }

  // Get address space
  switch(getMDOpAsConstInt(node->getOperand(index+1))->getZExtValue())
  {
    case AddrSpacePrivate:
      return CL_KERNEL_ARG_ADDRESS_PRIVATE;
    case AddrSpaceGlobal:
      return CL_KERNEL_ARG_ADDRESS_GLOBAL;
    case AddrSpaceConstant:
      return CL_KERNEL_ARG_ADDRESS_CONSTANT;
    case AddrSpaceLocal:
      return CL_KERNEL_ARG_ADDRESS_LOCAL;
    default:
      return -1;
  }
}

// Returns the child node of this kernel's metadata whose first operand is
// the string `name` (e.g. "kernel_arg_type"), or NULL if the kernel has no
// metadata or no such node exists.
const llvm::MDNode* Kernel::getArgumentMetadata(string name) const
{
  if (!m_metadata)
  {
    return NULL;
  }

  // Loop over all metadata nodes for this kernel
  for (unsigned i = 0; i < m_metadata->getNumOperands(); i++)
  {
    // Operands may be null or non-node metadata (e.g. the function itself)
    llvm::MDNode *node =
      llvm::dyn_cast_or_null<llvm::MDNode>(m_metadata->getOperand(i).get());
    if (!node || node->getNumOperands() == 0)
    {
      continue;
    }

    // Only a string first operand can name the node; checking avoids
    // treating arbitrary metadata as an MDString
    llvm::MDString *nodeName =
      llvm::dyn_cast_or_null<llvm::MDString>(node->getOperand(0).get());
    if (nodeName && nodeName->getString() == name)
    {
      return node;
    }
  }
  return NULL;
}

const llvm::StringRef Kernel::getArgumentName(unsigned int index) const
{
  // The name comes straight from the LLVM argument at `index`
  const llvm::Argument *arg = getArgument(index);
  return arg->getName();
}

// Returns the type name recorded for argument `index` in the
// "kernel_arg_type" metadata node, or "" when that node is missing or the
// entry for this argument is not a string.
const llvm::StringRef Kernel::getArgumentTypeName(unsigned int index) const
{
  assert(index < getNumArguments());

  // Get metadata node
  const llvm::MDNode *node = getArgumentMetadata("kernel_arg_type");
  if (!node)
  {
    return "";
  }

  // Operand 0 is the node's name; argument i lives at operand i+1
  llvm::MDString *str =
    llvm::dyn_cast_or_null<llvm::MDString>(node->getOperand(index+1).get());
  if (!str)
  {
    return "";
  }

  return str->getString();
}

// Returns a bitmask of CL_KERNEL_ARG_TYPE_* flags for argument `index`,
// parsed from the whitespace-separated tokens in the "kernel_arg_type_qual"
// metadata node. Returns -1 (as unsigned) when that node is missing or the
// entry for this argument is not a string.
unsigned int Kernel::getArgumentTypeQualifier(unsigned int index) const
{
  assert(index < getNumArguments());

  // Get metadata node
  const llvm::MDNode *node = getArgumentMetadata("kernel_arg_type_qual");
  if (!node)
  {
    return -1;
  }

  // Get qualifiers (operand 0 is the node's name)
  llvm::MDString *str =
    llvm::dyn_cast_or_null<llvm::MDString>(node->getOperand(index+1).get());
  if (!str)
  {
    return -1;
  }
  istringstream iss(str->getString().str());

  // Loop on successful extraction rather than on !eof(), so a failed final
  // read is never processed as a token
  unsigned int result = CL_KERNEL_ARG_TYPE_NONE;
  string tok;
  while (iss >> tok)
  {
    if (tok == "const")
    {
      result |= CL_KERNEL_ARG_TYPE_CONST;
    }
    else if (tok == "restrict")
    {
      result |= CL_KERNEL_ARG_TYPE_RESTRICT;
    }
    else if (tok == "volatile")
    {
      result |= CL_KERNEL_ARG_TYPE_VOLATILE;
    }
  }

  return result;
}

size_t Kernel::getArgumentSize(unsigned int index) const
{
  const llvm::Argument *argument = getArgument(index);
  const llvm::Type *type = argument->getType();

  // Check if pointer argument
  if (type->isPointerTy() && argument->hasByValAttr())
  {
    return getTypeSize(type->getPointerElementType());
  }

  return getTypeSize(type);
}

// Builds a string describing the kernel's reqd_work_group_size,
// work_group_size_hint and vec_type_hint attributes from its metadata.
// Returns an empty string when the kernel has no metadata node.
string Kernel::getAttributes() const
{
  ostringstream attributes("");

  // The constructor leaves m_metadata NULL when no "opencl.kernels" entry
  // matched this kernel
  if (!m_metadata)
  {
    return attributes.str();
  }

  for (unsigned i = 0; i < m_metadata->getNumOperands(); i++)
  {
    llvm::MDNode *val =
      llvm::dyn_cast_or_null<llvm::MDNode>(m_metadata->getOperand(i).get());
    if (!val || val->getNumOperands() == 0)
    {
      continue;
    }

    // Operand 0 of each attribute node is its name
    llvm::MDString *str =
      llvm::dyn_cast_or_null<llvm::MDString>(val->getOperand(0).get());
    if (!str)
    {
      continue;
    }
    string name = str->getString().str();

    if ((name == "reqd_work_group_size" ||
         name == "work_group_size_hint") && val->getNumOperands() >= 4)
    {
      attributes << name << "("
                 <<
        getMDOpAsConstInt(val->getOperand(1))->getZExtValue()
                 << "," <<
        getMDOpAsConstInt(val->getOperand(2))->getZExtValue()
                 << "," <<
        getMDOpAsConstInt(val->getOperand(3))->getZExtValue()
                 << ") ";
    }
    else if (name == "vec_type_hint" && val->getNumOperands() >= 2)
    {
      // Get type hint
      llvm::ValueAsMetadata *vam =
        llvm::dyn_cast_or_null<llvm::ValueAsMetadata>(val->getOperand(1).get());
      if (!vam)
      {
        continue;
      }

      size_t n = 1;
      const llvm::Type *type = vam->getType();
      if (type->isVectorTy())
      {
        n = type->getVectorNumElements();
        type = type->getVectorElementType();
      }

      // Generate attribute string; flush first so the LLVM stream wrapper
      // appends after the text already written
      attributes << name << "(" << flush;
      llvm::raw_os_ostream out(attributes);
      type->print(out);
      out.flush();
      attributes << n << ") ";
    }
  }
  return attributes.str();
}

// Returns the LLVM function that implements this kernel.
const llvm::Function* Kernel::getFunction() const
{
  return this->m_function;
}

// Returns the kernel's local-address-space memory pool (read-only view).
const Memory* Kernel::getLocalMemory() const
{
  return this->m_localMemory;
}

size_t Kernel::getLocalMemorySize() const
{
  return m_localMemory->getTotalAllocated();
}

const std::string& Kernel::getName() const
{
  return m_name;
}

unsigned int Kernel::getNumArguments() const
{
  return m_function->arg_size();
}

// Returns the kernel's private-address-space memory pool (read-only view).
const Memory* Kernel::getPrivateMemory() const
{
  return this->m_privateMemory;
}

// Returns the program this kernel was created from.
const Program* Kernel::getProgram() const
{
  return this->m_program;
}

// Fills reqdWorkGroupSize[0..2] from the kernel's "reqd_work_group_size"
// metadata node. All three entries are zero when the kernel has no metadata
// or no such node is present.
void Kernel::getRequiredWorkGroupSize(size_t reqdWorkGroupSize[3]) const
{
  memset(reqdWorkGroupSize, 0, 3*sizeof(size_t));

  // The constructor leaves m_metadata NULL when no "opencl.kernels" entry
  // matched this kernel
  if (!m_metadata)
  {
    return;
  }

  for (unsigned i = 0; i < m_metadata->getNumOperands(); i++)
  {
    llvm::MDNode *val =
      llvm::dyn_cast_or_null<llvm::MDNode>(m_metadata->getOperand(i).get());
    if (!val || val->getNumOperands() == 0)
    {
      continue;
    }

    // Operand 0 names the node; operands 1-3 hold the X/Y/Z sizes
    llvm::MDString *str =
      llvm::dyn_cast_or_null<llvm::MDString>(val->getOperand(0).get());
    if (!str || str->getString() != "reqd_work_group_size" ||
        val->getNumOperands() < 4)
    {
      continue;
    }

    for (int j = 0; j < 3; j++)
    {
      reqdWorkGroupSize[j] =
        getMDOpAsConstInt(val->getOperand(j+1))->getZExtValue();
    }
  }
}

void Kernel::setArgument(unsigned int index, TypedValue value)
{
  assert(index < m_function->arg_size());

  const llvm::Value *argument = getArgument(index);
  unsigned int type = getArgumentAddressQualifier(index);
  if (type == CL_KERNEL_ARG_ADDRESS_LOCAL)
  {
    // Deallocate existing argument
    if (m_arguments.count(argument))
    {
      m_localMemory->deallocateBuffer(m_arguments[argument].getPointer());
      delete[] m_arguments[argument].data;
    }

    // Allocate local memory buffer
    TypedValue v = {
      sizeof(size_t),
      1,
      new unsigned char[sizeof(size_t)]
    };
    v.setPointer(m_localMemory->allocateBuffer(value.size));
    m_arguments[argument] = v;
  }
  else
  {
    if (((const llvm::Argument*)argument)->hasByValAttr())
    {
      // Deallocate existing argument
      if (m_arguments.count(argument))
      {
        m_privateMemory->deallocateBuffer(m_arguments[argument].getPointer());
        delete[] m_arguments[argument].data;
      }

      TypedValue address =
      {
        sizeof(size_t),
        1,
        new unsigned char[sizeof(size_t)]
      };
      size_t size = value.size*value.num;
      address.setPointer(m_privateMemory->allocateBuffer(size));
      m_privateMemory->store(value.data, address.getPointer(), size);
      m_arguments[argument] = address;
    }
    else
    {
      // Deallocate existing argument
      if (m_arguments.count(argument))
      {
        delete[] m_arguments[argument].data;
      }

      const llvm::Type *type = argument->getType();
      if (type->isVectorTy())
      {
        value.num = type->getVectorNumElements();
        value.size = getTypeSize(type->getVectorElementType());
      }
      m_arguments[argument] = value.clone();
    }
  }
}

// Start of the stored argument/variable values.
TypedValueMap::const_iterator Kernel::args_begin() const
{
  return this->m_arguments.begin();
}

// One past the last stored argument/variable value.
TypedValueMap::const_iterator Kernel::args_end() const
{
  return this->m_arguments.end();
}
