#include <iostream>
#include "Common.h"
#include "PrimitiveTests.h"
#include <random>

// Creates a histogram test case.
//
// _global    selects the "histogram_global" kernel (true) or "histogram_local" (false).
// _valueSet  number of histogram bins; input values are drawn from [0, _valueSet-1].
//            NOTE: must be >= 1, otherwise the distribution range is invalid.
// _data_size number of random input values generated into sourceData.
//
// gpuResult is pre-sized to one zero-initialized int per bin.
Histogram::Histogram(bool _global, int _valueSet, int _data_size)
{
	global = _global;
	valueSet = _valueSet;
	data_size = _data_size;

	// Non-deterministically seeded generator: input differs between runs.
	std::random_device rd;
	std::mt19937 gen(rd());
	std::uniform_int_distribution<int> distr(0, valueSet-1);

	// Clamp negative sizes to 0: comparing a size_t index against a negative int
	// would otherwise wrap to a huge bound and never terminate normally.
	const size_t count = data_size > 0 ? static_cast<size_t>(data_size) : 0;
	sourceData.reserve(count);
	for (size_t index = 0; index < count; ++index) {
		sourceData.push_back(distr(gen));
	}
	gpuResult.resize(valueSet, 0);
}

// Reads the device histogram (valueSet ints from clResultBuffer) back into gpuResult.
//
// queue  command queue used for the read; the read is blocking, so gpuResult
//        is filled when this returns.
void Histogram::collect_results(cl::CommandQueue* queue)
{
	// The device writes valueSet ints into gpuResult.data(); make sure they fit.
	gpuResult.resize(valueSet);
	cl_int err = queue->enqueueReadBuffer(clResultBuffer, true, 0, sizeof(int) * valueSet, gpuResult.data());
	CheckCLError(err);
}

// Uploads sourceData to the device and enqueues the histogram kernel.
//
// Picks "histogram_global" or "histogram_local" depending on `global`, creates
// the input and result buffers, binds the kernel arguments and enqueues one
// work-item per input element. Results are read back by collect_results().
//
// context  OpenCL context the buffers are allocated in.
// queue    command queue the input upload and the kernel are enqueued on.
// program  built program that contains the histogram kernels.
// Event    passed as the `event` argument of enqueueNDRangeKernel, so it
//          refers to the kernel command.
void Histogram::gpu_compute(cl::Context* context, cl::CommandQueue* queue, cl::Program* program, cl::Event* Event)
{
	cl_int err = CL_SUCCESS;

	// Get the kernel handle
	const char* kernelName = global ? "histogram_global" : "histogram_local";
	cl::Kernel kernel(*program, kernelName, &err);
	CheckCLError(err);

	// Input buffer: check creation before using it, then check the upload itself.
	clInputBuffer = cl::Buffer(*context, CL_MEM_READ_ONLY, sizeof(int) * data_size, NULL, &err);
	CheckCLError(err);
	err = queue->enqueueWriteBuffer(clInputBuffer,
		true, // Blocking!
		0, sizeof(int) * data_size, sourceData.data());
	CheckCLError(err);

	// Output buffer: OpenCL does not initialize buffer contents, so seed every bin
	// with 0 from the host (COPY_HOST_PTR copies at creation time). READ_WRITE
	// because the kernels are expected to accumulate into it.
	// NOTE(review): confirm the accumulate behaviour against the .cl source.
	gpuResult.assign(valueSet, 0);
	clResultBuffer = cl::Buffer(*context, CL_MEM_READ_WRITE | CL_MEM_COPY_HOST_PTR, sizeof(int) * valueSet, gpuResult.data(), &err);
	CheckCLError(err);

	// Set the kernel parameters, in the order the kernel function declares them.
	err = kernel.setArg(0, clInputBuffer);
	CheckCLError(err);
	err = kernel.setArg(1, clResultBuffer);
	CheckCLError(err);
	if (!global) {
		// Size-only argument with a NULL value: per clSetKernelArg this is how a
		// __local buffer (one int per bin) is allocated for each work-group.
		err = kernel.setArg(2, sizeof(int) * valueSet, NULL);
		CheckCLError(err);
		err = kernel.setArg(3, valueSet);
		CheckCLError(err);
	}

	// Enqueue the kernel
	err = queue->enqueueNDRangeKernel(kernel,
		cl::NullRange,				// No global offset for the indices
		cl::NDRange(data_size, 1),	// One work-item per input element
		cl::NullRange,				// Work-group size left to the runtime; if the launch fails,
									// try an explicit size (e.g. 1024) and reduce it until it runs
		NULL,						// No events to wait on
		Event);
	CheckCLError(err);
}

// Computes the reference histogram on the CPU into cpuResult.
//
// cpuResult is reset to valueSet zero bins, then each value in sourceData
// increments its bin. Values are expected to lie in [0, valueSet), which is the
// range the constructor draws them from.
void Histogram::cpu_compute()
{
	cpuResult.assign(valueSet, 0);
	for (const auto value : sourceData) {
		++cpuResult[value];
	}
}

// Compares the CPU reference histogram with the GPU result, bin by bin.
//
// Returns true when all valueSet bins match. On the first mismatch, or when
// either result vector holds fewer than valueSet bins (e.g. cpu_compute or
// collect_results was not run), prints a diagnostic to std::cout and
// returns false.
bool Histogram::validate_results()
{
	const size_t bins = valueSet > 0 ? static_cast<size_t>(valueSet) : 0;

	// Guard against indexing past the end of a result that was never computed.
	if (cpuResult.size() < bins || gpuResult.size() < bins) {
		std::cout << "Result size mismatch: expected " << bins << " bins, cpu=" << cpuResult.size()
			<< ", gpu=" << gpuResult.size() << std::endl;
		return false;
	}

	for (size_t index = 0; index < bins; ++index) {
		if (cpuResult[index] != gpuResult[index]) {
			std::cout << "Wrong result at [" << index << "]: " << gpuResult[index] << "!=" << cpuResult[index] << std::endl;
			return false;
		}
	}
	return true;
}

// Human-readable summary of this test case's parameters, e.g.
// "Histogram (type=local,data_size=1024,valueSet=16)".
std::string Histogram::description()
{
	// Fixes the former "gobal" typo in the reported kernel type.
	const std::string type = global ? "global" : "local";
	return "Histogram (type=" + type + ",data_size=" + std::to_string(data_size) + ",valueSet=" + std::to_string(valueSet) + ")";
}