///////////////////////////////////////////////////////////////////////////////
//
//  Copyright (2008) Alexander Stukowski
//
//  This file is part of OVITO (Open Visualization Tool).
//
//  OVITO is free software; you can redistribute it and/or modify
//  it under the terms of the GNU General Public License as published by
//  the Free Software Foundation; either version 2 of the License, or
//  (at your option) any later version.
//
//  OVITO is distributed in the hope that it will be useful,
//  but WITHOUT ANY WARRANTY; without even the implied warranty of
//  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
//  GNU General Public License for more details.
//
//  You should have received a copy of the GNU General Public License
//  along with this program.  If not, see <http://www.gnu.org/licenses/>.
//
///////////////////////////////////////////////////////////////////////////////

#include <core/Core.h>
#include <core/data/ObjectLoadStream.h>
#include <core/data/ObjectSaveStream.h>
#include <core/data/units/ParameterUnit.h>
#include <core/viewport/Viewport.h>
#include <core/scene/ObjectNode.h>
#include "AtomsObject.h"
#include "AtomsObjectEditor.h"
#include "datachannels/PositionDataChannel.h"
#include "datachannels/AtomTypeDataChannel.h"
#include "datachannels/DisplacementDataChannel.h"
#include "datachannels/OrientationDataChannel.h"
#include "datachannels/DeformationGradientDataChannel.h"

namespace AtomViz {

// Register AtomsObject with the plugin class system as a serializable subclass of SceneObject.
IMPLEMENT_SERIALIZABLE_PLUGIN_CLASS(AtomsObject, SceneObject)
// List of per-atom data channels referenced by this object (stored under the name "DataChannels").
DEFINE_VECTOR_REFERENCE_FIELD(AtomsObject, DataChannel, "DataChannels", _dataChannels)
// The simulation cell reference; declared with the PROPERTY_FIELD_ALWAYS_DEEP_COPY flag.
DEFINE_FLAGS_REFERENCE_FIELD(AtomsObject, SimulationCell, "SimulationCell", PROPERTY_FIELD_ALWAYS_DEEP_COPY, _simulationCell)
// Flag that decides whether per-atom data is written to the scene file
// (read by saveToStream() and forwarded to the channels by setSerializeAtoms()).
DEFINE_PROPERTY_FIELD(AtomsObject, "SerializeAtoms", _serializeAtoms)
// Human-readable labels for the fields defined above.
SET_PROPERTY_FIELD_LABEL(AtomsObject, _dataChannels, "Data Channels")
SET_PROPERTY_FIELD_LABEL(AtomsObject, _simulationCell, "Simulation Cell")
SET_PROPERTY_FIELD_LABEL(AtomsObject, _serializeAtoms, "Serialize atoms")

/******************************************************************************
* Constructor.
*
* Starts with zero atoms, an empty bounding-box cache and atom serialization
* enabled. When the object is not being loaded from a scene file, a fresh
* SimulationCell is created and attached. When it is being loaded, the cell
* reference is left alone (presumably restored from the stream).
******************************************************************************/
AtomsObject::AtomsObject(bool isLoading)
	: SceneObject(isLoading),
	  numAtoms(0),
	  sceneBoundingBoxValidity(TimeNever),
	  _serializeAtoms(true)
{
	// Register the reference and property fields of this class.
	INIT_PROPERTY_FIELD(AtomsObject, _dataChannels);
	INIT_PROPERTY_FIELD(AtomsObject, _simulationCell);
	INIT_PROPERTY_FIELD(AtomsObject, _serializeAtoms);

	// Objects restored from a file do not get a default cell.
	if(isLoading)
		return;

	// Every newly created atoms object starts out with its own simulation cell.
	setSimulationCell(new SimulationCell());
}


/******************************************************************************
* Reports the validity interval of the object at the given time.
* This implementation always returns TimeForever, regardless of the time.
******************************************************************************/
TimeInterval AtomsObject::objectValidity(TimeTicks /*time*/)
{
	return TimeForever;
}

/******************************************************************************
* Draws the object into an interactive viewport.
*
* The simulation cell border is drawn first (with depth testing enabled), then
* every visible data channel. While the viewport is in picking mode, only the
* cell is drawn; the data channels are skipped.
******************************************************************************/
void AtomsObject::renderObject(TimeTicks time, ObjectNode* contextNode, Viewport* vp)
{
	CHECK_OBJECT_POINTER(simulationCell());

	// Cell lines and atoms are both drawn with the z-buffer active.
	vp->setDepthTest(true);
	simulationCell()->render(time, vp, contextNode);

	// Atoms take no part in the picking test.
	if(vp->isPicking())
		return;

	Q_FOREACH(DataChannel* channel, dataChannels()) {
		if(!channel->isVisible())
			continue;
		channel->render(time, vp, this, contextNode);
	}
}

/******************************************************************************
* Renders the object in preview (OpenGL) rendering mode.
*
* Draws the simulation cell border, then every visible data channel, into the
* given GL context for an image of imageWidth x imageHeight pixels.
* Always returns true.
******************************************************************************/
bool AtomsObject::renderPreview(TimeTicks time, const CameraViewDescription& view, ObjectNode* contextNode, int imageWidth, int imageHeight, Window3D* glcontext)
{
	CHECK_OBJECT_POINTER(simulationCell());
	simulationCell()->renderHQ(time, view, contextNode, imageWidth, imageHeight, glcontext);

	Q_FOREACH(DataChannel* channel, dataChannels()) {
		if(!channel->isVisible())
			continue;
		channel->renderHQ(time, this, view, contextNode, imageWidth, imageHeight, glcontext);
	}

	return true;
}

/******************************************************************************
* Returns the bounding box of the object in local object coordinates.
*
* The result is cached in sceneBoundingBox together with the interval
* sceneBoundingBoxValidity. It is recomputed only when 'time' lies outside
* that interval (invalidate() empties the interval). The box covers the
* simulation cell, padded by half of the cell line width, plus the boxes
* reported by all visible data channels. The validity interval is handed to
* each channel by reference, presumably so a channel can restrict it.
******************************************************************************/
Box3 AtomsObject::boundingBox(TimeTicks time, ObjectNode* contextNode)
{
	if(!sceneBoundingBoxValidity.contains(time)) {
		sceneBoundingBoxValidity = TimeForever;

		// Reset to a default-constructed box (the same state the member has right
		// after construction). Without this, when there is no simulation cell the
		// previously cached box would survive and the channel boxes would be merged
		// into it, so the box could only ever grow.
		sceneBoundingBox = Box3();

		// Start with the bounding box of the simulation cell.
		if(simulationCell()) {
			sceneBoundingBox = simulationCell()->boundingBox();
			sceneBoundingBox = sceneBoundingBox.padBox(simulationCell()->simulationCellLineWidth() * 0.5);
		}

		// Add the bounding boxes of the visible data channels.
		Q_FOREACH(DataChannel* channel, dataChannels()) {
			if(channel->isVisible())
				sceneBoundingBox.addBox(channel->boundingBox(time, this, contextNode, sceneBoundingBoxValidity));
		}
	}
	return sceneBoundingBox;
}

/******************************************************************************
* Creates a plain DataChannel with the given channel type, per-component data
* size and component count.
*
* The channel is sized to the current atom count, inserted into this object
* via insertDataChannel(), and returned.
******************************************************************************/
DataChannel* AtomsObject::createCustomDataChannel(int dataChannelType, size_t dataTypeSize, size_t componentCount)
{
	DataChannel::SmartPtr newChannel(new DataChannel(dataChannelType, dataTypeSize, componentCount));

	// One element per atom currently stored in the object.
	newChannel->resize(atomsCount());

	insertDataChannel(newChannel.get());
	OVITO_ASSERT(newChannel->channelUsageCount() == 1);

	return newChannel.get();
}

/******************************************************************************
* Returns the standard data channel with the given identifier, creating it if
* the object does not contain one yet.
*
* A newly created channel is an instance of the DataChannel subclass that
* matches the identifier:
*   PositionChannel                 -> PositionDataChannel
*   AtomTypeChannel, CNATypeChannel -> AtomTypeDataChannel
*   DisplacementChannel             -> DisplacementDataChannel
*   OrientationChannel              -> OrientationDataChannel
*   DeformationGradientChannel      -> DeformationGradientDataChannel
*   anything else                   -> DataChannel
* It is then sized to the current atom count and inserted into this object.
* A newly created ColorChannel has every element set to Vector3(1.0).
******************************************************************************/
DataChannel* AtomsObject::createStandardDataChannel(DataChannel::DataChannelIdentifier which)
{
	DataChannel::SmartPtr channel = getStandardDataChannel(which);

	if(channel == NULL) {

		// Instantiate the class responsible for this kind of channel.
		switch(which) {
			case DataChannel::PositionChannel:
				channel = new PositionDataChannel(which);
				break;
			case DataChannel::AtomTypeChannel:
			case DataChannel::CNATypeChannel:
				channel = new AtomTypeDataChannel(which);
				break;
			case DataChannel::DisplacementChannel:
				channel = new DisplacementDataChannel(which);
				break;
			case DataChannel::OrientationChannel:
				channel = new OrientationDataChannel(which);
				break;
			case DataChannel::DeformationGradientChannel:
				channel = new DeformationGradientDataChannel(which);
				break;
			default:
				channel = new DataChannel(which);
				break;
		}

		// Take array sizes from the existing data channels.
		channel->resize(atomsCount());

		// Insert into array.
		insertDataChannel(channel);
		OVITO_ASSERT(channel->channelUsageCount() == 1);

		// Initialize color channel with default colors.
		if(which == DataChannel::ColorChannel) {
			Vector3* c = channel->dataVector3();
			for(size_t i = channel->size(); i != 0; i--)
				*c++ = Vector3(1.0);
		}
	}
	else {
		// A pre-existing channel must use the data type prescribed for its identifier.
		OVITO_ASSERT_MSG(channel->type() == DataChannel::standardChannelType(which), "AtomsObject::createStandardDataChannel()",
				"The data type of the existing standard data channel is invalid.");
	}

	return channel.get();
}

/******************************************************************************
* Inserts a data channel into this object.
*
* If the channel's identifier is not UserDataChannel and the object already
* holds a channel with that identifier, the existing channel is replaced
* (via replaceDataChannel()) instead of a second one being added. A channel
* whose channelUsageCount() is zero adopts this object's serialization
* setting before it is appended. A NULL pointer is ignored once it has passed
* the CHECK_OBJECT_POINTER check.
******************************************************************************/
void AtomsObject::insertDataChannel(DataChannel* newChannel)
{
	CHECK_OBJECT_POINTER(newChannel);
	if(!newChannel)
		return;

	// Make sure that the channel has the correct size. This check comes after the
	// NULL guard so that it never dereferences a NULL pointer.
	OVITO_ASSERT_MSG(newChannel->size() == atomsCount(), "AtomsObject::insertDataChannel()", "The new data channel must have the correct size, i.e. the same number of atoms as the AtomsObject.");

	// Make sure that the new channel is an instance of the correct DataChannel sub-class.
	OVITO_ASSERT_MSG(newChannel->id() != DataChannel::PositionChannel || dynamic_object_cast<PositionDataChannel>(newChannel) != NULL, "AtomsObject::insertDataChannel()", "The new data channel being inserted into the AtomsObject has a wrong class type.");
	OVITO_ASSERT_MSG(newChannel->id() != DataChannel::AtomTypeChannel || dynamic_object_cast<AtomTypeDataChannel>(newChannel) != NULL, "AtomsObject::insertDataChannel()", "The new data channel being inserted into the AtomsObject has a wrong class type.");
	OVITO_ASSERT_MSG(newChannel->id() != DataChannel::CNATypeChannel || dynamic_object_cast<AtomTypeDataChannel>(newChannel) != NULL, "AtomsObject::insertDataChannel()", "The new data channel being inserted into the AtomsObject has a wrong class type.");
	OVITO_ASSERT_MSG(newChannel->id() != DataChannel::DisplacementChannel || dynamic_object_cast<DisplacementDataChannel>(newChannel) != NULL, "AtomsObject::insertDataChannel()", "The new data channel being inserted into the AtomsObject has a wrong class type.");
	OVITO_ASSERT_MSG(newChannel->id() != DataChannel::OrientationChannel || dynamic_object_cast<OrientationDataChannel>(newChannel) != NULL, "AtomsObject::insertDataChannel()", "The new data channel being inserted into the AtomsObject has a wrong class type.");
	OVITO_ASSERT_MSG(newChannel->id() != DataChannel::DeformationGradientChannel || dynamic_object_cast<DeformationGradientDataChannel>(newChannel) != NULL, "AtomsObject::insertDataChannel()", "The new data channel being inserted into the AtomsObject has a wrong class type.");

	// A standard channel replaces an existing channel with the same identifier.
	if(newChannel->id() != DataChannel::UserDataChannel) {
		DataChannel* existingChannel = getStandardDataChannel(newChannel->id());
		if(existingChannel) {
			replaceDataChannel(existingChannel, newChannel);
			return;
		}
	}

	// Adopt the serialization flag.
	if(newChannel->channelUsageCount() == 0)
		newChannel->setSerializeData(serializeAtoms());

	// Insert new channel into array.
	_dataChannels.push_back(newChannel);
}

/******************************************************************************
* Searches the data channels in list order and returns the first one whose
* name equals 'name', or NULL if none matches.
******************************************************************************/
DataChannel* AtomsObject::findDataChannelByName(const QString& name) const
{
	for(DataChannelList::const_iterator it = dataChannels().constBegin(); it != dataChannels().constEnd(); ++it) {
		if(name == (*it)->name())
			return *it;
	}
	return NULL;
}

/******************************************************************************
* Resolves an offline DataChannelReference to a channel of this object.
* User-defined channels are found by name; all others by identifier.
* Returns NULL if no matching channel exists.
******************************************************************************/
DataChannel* AtomsObject::lookupDataChannel(const DataChannelReference& ref) const
{
	return (ref.id() == DataChannel::UserDataChannel)
		? findDataChannelByName(ref.name())
		: getStandardDataChannel(ref.id());
}

/******************************************************************************
* Returns the first data channel whose identifier equals 'identifier', or NULL
* if the object has no such channel.
******************************************************************************/
DataChannel* AtomsObject::getStandardDataChannel(DataChannel::DataChannelIdentifier identifier) const
{
	for(DataChannelList::const_iterator it = dataChannels().constBegin(); it != dataChannels().constEnd(); ++it) {
		if(identifier == (*it)->id())
			return *it;
	}
	return NULL;
}

/******************************************************************************
* Ensures that the given channel is not shared with other objects.
*
* If channelUsageCount() is greater than one, the channel is cloned (deep copy,
* without recording undo information), the clone replaces the original in this
* object, and the clone is returned. Otherwise the channel itself is returned.
* Returns NULL when given NULL.
******************************************************************************/
DataChannel* AtomsObject::copyShallowChannel(DataChannel* channel)
{
	if(channel == NULL)
		return NULL;

	OVITO_ASSERT_MSG(dataChannels().contains(channel), "AtomsObject::copyShallowChannel()", "The specified data channel is not referenced by the AtomsObject.");

	// Already used exclusively by this object: nothing to copy.
	if(channel->channelUsageCount() <= 1)
		return channel;

	DataChannel::SmartPtr privateCopy;
	{
		// The clone operation itself must not produce undo records.
		UndoSuspender suspendUndo;
		CloneHelper helper;
		privateCopy = helper.cloneObject(channel, true);
		CHECK_OBJECT_POINTER(privateCopy);
	}

	// Put the private copy in place of the shared channel.
	replaceDataChannel(channel, privateCopy.get());
	OVITO_ASSERT(privateCopy->channelUsageCount() == 1);

	return privateCopy.get();
}

/******************************************************************************
* Substitutes newChannel for oldChannel in this object via replaceReferencesTo().
*
* Does nothing if both are the same channel. A new channel whose
* channelUsageCount() is zero first adopts this object's serialization
* setting. The new channel must hold one element per atom.
******************************************************************************/
void AtomsObject::replaceDataChannel(DataChannel* oldChannel, const DataChannel::SmartPtr& newChannel)
{
	// Replacing a channel with itself is a no-op.
	if(newChannel.get() == oldChannel)
		return;

	CHECK_OBJECT_POINTER(oldChannel);
	CHECK_OBJECT_POINTER(newChannel);
	OVITO_ASSERT_MSG(newChannel->size() == atomsCount(), "AtomsObject::replaceDataChannel()", "The new data channel does not have the correct size.");

	const bool notYetUsed = (newChannel->channelUsageCount() == 0);
	if(notYetUsed)
		newChannel->setSerializeData(serializeAtoms());

	replaceReferencesTo(oldChannel, newChannel);
}

/******************************************************************************
* From RefMaker.
* Handles messages sent by referenced objects. A REFTARGET_CHANGED message
* invalidates the cached state of this object. Every message is then passed
* on to the SceneObject implementation, whose result is returned.
******************************************************************************/
bool AtomsObject::onRefTargetMessage(RefTarget* source, RefTargetMessage* msg)
{
	const bool targetChanged = (msg->type() == REFTARGET_CHANGED);
	if(targetChanged)
		invalidate();

	return SceneObject::onRefTargetMessage(source, msg);
}

/******************************************************************************
* From RefMaker.
******************************************************************************/
void AtomsObject::onRefTargetInserted(const PropertyFieldDescriptor& field, RefTarget* newTarget, int listIndex)
{
	invalidate();
	SceneObject::onRefTargetInserted(field, newTarget, listIndex);
}

/******************************************************************************
* From RefMaker.
******************************************************************************/
void AtomsObject::onRefTargetRemoved(const PropertyFieldDescriptor& field, RefTarget* oldTarget, int listIndex)
{
	invalidate();
	SceneObject::onRefTargetRemoved(field, oldTarget, listIndex);
}

/******************************************************************************
* Writes the object to a scene file.
*
* After the SceneObject data, chunk 0x10000000 stores the atom count. Zero is
* stored instead when atom serialization is switched off.
******************************************************************************/
void AtomsObject::saveToStream(ObjectSaveStream& stream)
{
	SceneObject::saveToStream(stream);

	const size_t storedAtomCount = serializeAtoms() ? numAtoms : 0;
	stream.beginChunk(0x10000000);
	stream.writeSizeT(storedAtomCount);
	stream.endChunk();
}

/******************************************************************************
* Reads the object from a scene file: the SceneObject data followed by
* chunk 0x10000000, which holds the stored atom count.
******************************************************************************/
void AtomsObject::loadFromStream(ObjectLoadStream& stream)
{
	SceneObject::loadFromStream(stream);

	// Chunk layout must match what saveToStream() writes.
	stream.expectChunk(0x10000000);
	stream.readSizeT(this->numAtoms);
	stream.closeChunk();
}

/******************************************************************************
* Resizes every data channel to hold newAtomCount elements and returns the
* previous atom count.
*
* Shared channels are first replaced by private copies (copyShallowChannel())
* so that resizing does not affect other objects. Throws Exception if a
* channel's size differs from the current atom count. That check now runs
* before anything is modified, so a failure leaves the object unchanged.
******************************************************************************/
size_t AtomsObject::setAtomsCount(size_t newAtomCount)
{
	const size_t oldSize = atomsCount();
	if(newAtomCount == oldSize)
		return oldSize;

	// Validate first: throwing after numAtoms or some channels were already
	// modified would leave the object in an inconsistent, half-resized state.
	Q_FOREACH(DataChannel* channel, dataChannels()) {
		OVITO_ASSERT_MSG(oldSize == channel->size(), "AtomsObject::setAtomsCount()", "Data channel sizes are out of sync.");
		if(oldSize != channel->size())
			throw Exception("Data channel sizes are out of sync.");
	}

	// Make a deep copy of all shared channels before they are resized.
	// (Q_FOREACH iterates over a copy of the list, so replacing entries is safe.)
	Q_FOREACH(DataChannel* channel, dataChannels()) {
		copyShallowChannel(channel);
	}

	// Resize data array in each data channel.
	this->numAtoms = newAtomCount;
	Q_FOREACH(DataChannel* channel, dataChannels()) {
		OVITO_ASSERT(channel->channelUsageCount() == 1);
		channel->resize(newAtomCount);
		OVITO_ASSERT(newAtomCount == channel->size());
	}

	invalidate();

	return oldSize;
}

/******************************************************************************
* Sets whether atomic coordinates are saved along with the scene.
******************************************************************************/
void AtomsObject::setSerializeAtoms(bool on)
{
	_serializeAtoms = on;
	Q_FOREACH(DataChannel* channel, dataChannels()) {
		channel->setSerializeData(on);
	}
}

/******************************************************************************
* Invalidates all temporary information stored with the AtomsObject.
* This method must be called each time the atoms have been changed in some
* way.
******************************************************************************/
void AtomsObject::invalidate()
{
	sceneBoundingBoxValidity.setEmpty();

	for(DataChannelList::const_iterator c = dataChannels().constBegin(); c != dataChannels().constEnd(); ++c)
		(*c)->clearCaches();

	notifyDependents(REFTARGET_CHANGED);
}

/**
 * This helper class is used by AtomsObject::deleteAtoms() to split up the data channel copying into smaller
 * operations that can be performed on multiple processors in parallel.
 */
class DeleteAtomsKernel {
public:
	// Constructor that takes references to the filter mask array.
	DeleteAtomsKernel(const dynamic_bitset<>& _mask) : mask(_mask) {}
	// The actual kernel function that is called by the Qt concurrent framework for each data cahnnel.
	void operator()(const QPair<DataChannel*, DataChannel*> channelPair) {
		channelPair.second->filterCopy(channelPair.first, mask);
	}
private:
	const dynamic_bitset<>& mask;
};

/******************************************************************************
* Deletes the atoms given by the bitmask from the object.
* Returns the number of remaining atoms.
******************************************************************************/
size_t AtomsObject::deleteAtoms(const dynamic_bitset<>& mask)
{
	OVITO_ASSERT(mask.size() == atomsCount());
	size_t oldAtomsCount = atomsCount();
	size_t newAtomsCount = oldAtomsCount - mask.count();
	if(newAtomsCount == oldAtomsCount)
		return oldAtomsCount;	// Nothing to delete.

	CloneHelper cloneHelper;
	numAtoms = newAtomsCount;

	// Contains for each data channel the old and the new version.
	QVector< QPair<DataChannel*, DataChannel*> > oldToNewMap;
	oldToNewMap.reserve(dataChannels().size());

	// Allocate new data channels
	Q_FOREACH(DataChannel* channel, dataChannels()) {
		CHECK_OBJECT_POINTER(channel);
		OVITO_ASSERT(channel->channelUsageCount() >= 1);

		// Create a new data channel that will replace the old one.
		DataChannel::SmartPtr newChannel = cloneHelper.cloneObject(channel, false);
		newChannel->resize(newAtomsCount);

		// Replace original channel with the filtered one.
		replaceDataChannel(channel, newChannel.get());

		oldToNewMap.push_back(qMakePair(channel, newChannel.get()));
	}

	// Transfer and filter per-atom data elements.
	QtConcurrent::blockingMap(oldToNewMap, DeleteAtomsKernel(mask));

	invalidate();
	return newAtomsCount;
}

/******************************************************************************
* Creates a copy of this object.
*
* SceneObject::clone() creates the new instance; the atom count, which is a
* plain data member, is then transferred explicitly. The assertion checks that
* the copy ended up with the same number of data channels.
******************************************************************************/
RefTarget::SmartPtr AtomsObject::clone(bool deepCopy, CloneHelper& cloneHelper)
{
	AtomsObject::SmartPtr copy = static_object_cast<AtomsObject>(SceneObject::clone(deepCopy, cloneHelper));

	copy->numAtoms = numAtoms;

	OVITO_ASSERT(copy->dataChannels().size() == this->dataChannels().size());

	return copy;
}

/******************************************************************************
* Returns an array with the display color of each atom.
*
* Precedence:
*  1. a visible color channel supplies the colors directly;
*  2. otherwise a visible atom type channel that defines at least one type
*     supplies each atom's type color. The color comes from the type's color
*     controller, or is white (1,1,1) if the type has no controller;
*  3. otherwise every atom is white.
* Controller evaluations may narrow validityInterval.
*
* Type identifiers are mapped onto the defined types with a non-negative
* modulo. The original code divided by zero when no types were defined and
* indexed out of bounds for negative type ids.
******************************************************************************/
QVector<Color> AtomsObject::getAtomColors(TimeTicks time, TimeInterval& validityInterval)
{
	QVector<Color> colorArray(atomsCount());

	DataChannel* colorChannel = getStandardDataChannel(DataChannel::ColorChannel);
	AtomTypeDataChannel* typeChannel = static_object_cast<AtomTypeDataChannel>(getStandardDataChannel(DataChannel::AtomTypeChannel));

	if(colorChannel && colorChannel->isVisible()) {
		const Vector3* colorIter = colorChannel->constDataVector3();
		for(QVector<Color>::iterator dest = colorArray.begin(); dest != colorArray.end(); ++dest)
			*dest = *colorIter++;
	}
	else if(typeChannel && typeChannel->isVisible() && typeChannel->atomTypes().size() != 0) {
		const int typeCount = typeChannel->atomTypes().size();

		// Get the atom type colors.
		QVector<Color> typeColors(typeCount, Color(1,1,1));
		for(int i = 0; i < typeCount; i++) {
			AtomType* atype = typeChannel->atomTypes()[i];
			if(atype && atype->colorController())
				atype->colorController()->getValue(time, typeColors[i], validityInterval);
		}

		const int* typeIter = typeChannel->constDataInt();
		for(QVector<Color>::iterator dest = colorArray.begin(); dest != colorArray.end(); ++dest) {
			// Wrap the type id into [0, typeCount). '%' keeps the sign of a negative
			// id, so shift negative remainders into range.
			int typeIndex = (*typeIter++) % typeCount;
			if(typeIndex < 0)
				typeIndex += typeCount;
			*dest = typeColors[typeIndex];
		}
	}
	else {
		colorArray.fill(Color(1,1,1));
	}
	return colorArray;
}

/******************************************************************************
* Returns an array with the display radius of each atom.
*
* The global scaling factor is read from the position channel's
* globalAtomRadiusScaleController() if present, and is 1 otherwise.
* Precedence:
*  1. a visible radius channel supplies the radii directly (not scaled);
*  2. otherwise a visible atom type channel that defines at least one type
*     supplies each atom's type radius times the global scaling factor. Types
*     without a radius controller get the global scaling factor itself;
*  3. otherwise every atom gets the global scaling factor.
* Controller evaluations may narrow validityInterval.
*
* Type identifiers are mapped onto the defined types with a non-negative
* modulo. The original code divided by zero when no types were defined and
* indexed out of bounds for negative type ids.
******************************************************************************/
QVector<FloatType> AtomsObject::getAtomRadii(TimeTicks time, TimeInterval& validityInterval)
{
	DataChannel* radiusChannel = getStandardDataChannel(DataChannel::RadiusChannel);
	AtomTypeDataChannel* typeChannel = static_object_cast<AtomTypeDataChannel>(getStandardDataChannel(DataChannel::AtomTypeChannel));
	PositionDataChannel* posChannel = static_object_cast<PositionDataChannel>(getStandardDataChannel(DataChannel::PositionChannel));

	QVector<FloatType> radiusArray(atomsCount());
	FloatType globalScaling = 1;
	if(posChannel && posChannel->globalAtomRadiusScaleController()) {
		posChannel->globalAtomRadiusScaleController()->getValue(time, globalScaling, validityInterval);
		OVITO_ASSERT(globalScaling >= 0);
	}

	if(radiusChannel && radiusChannel->isVisible()) {
		const FloatType* radiusIter = radiusChannel->constDataFloat();
		for(QVector<FloatType>::iterator dest = radiusArray.begin(); dest != radiusArray.end(); ++dest)
			*dest = *radiusIter++;
	}
	else if(typeChannel && typeChannel->isVisible() && typeChannel->atomTypes().size() != 0) {
		const int typeCount = typeChannel->atomTypes().size();

		// Get the atom type radii.
		QVector<FloatType> typeRadii(typeCount, globalScaling);
		for(int i = 0; i < typeCount; i++) {
			AtomType* atype = typeChannel->atomTypes()[i];
			if(atype && atype->radiusController()) {
				atype->radiusController()->getValue(time, typeRadii[i], validityInterval);
				typeRadii[i] *= globalScaling;
			}
		}

		const int* typeIter = typeChannel->constDataInt();
		for(QVector<FloatType>::iterator dest = radiusArray.begin(); dest != radiusArray.end(); ++dest) {
			// Wrap the type id into [0, typeCount). '%' keeps the sign of a negative
			// id, so shift negative remainders into range.
			int typeIndex = (*typeIter++) % typeCount;
			if(typeIndex < 0)
				typeIndex += typeCount;
			*dest = typeRadii[typeIndex];
		}
	}
	else {
		radiusArray.fill(globalScaling);
	}
	return radiusArray;
}

/******************************************************************************
* Performs a ray intersection calculation.
******************************************************************************/
bool AtomsObject::intersectRay(const Ray3& ray, TimeTicks time, ObjectNode* contextNode, FloatType& t, Vector3& normal)
{
	// Get the atomic positions and radii.
	DataChannel* posChannel = getStandardDataChannel(DataChannel::PositionChannel);
	if(!posChannel) return false;

	TimeInterval iv;
	QVector<FloatType> radii = getAtomRadii(time, iv);
	OVITO_ASSERT(posChannel->size() == radii.size());

	FloatType closestHitDistance = FLOATTYPE_MAX;

	// Hit test every atom.
	const Point3* p = posChannel->constDataPoint3();
	const FloatType* r = radii.constBegin();
	for(size_t index = 0; index < posChannel->size(); index++, ++p, ++r) {

		// Perform ray-sphere intersection test.
		Vector3 sphere_dir = *p - ray.base;
		FloatType b = DotProduct(ray.dir, sphere_dir);
		FloatType temp = DotProduct(sphere_dir, sphere_dir);
		FloatType disc = b*b + square(*r) - temp;

		// Only calculate the nearest intersection
		if(disc <= 0.0)
			continue; // Ray missed sphere entirely.

		// Calculate closest intersection.
		FloatType tnear = b - sqrt(disc);
		if(tnear <= 0.0)
			continue;

		if(tnear < closestHitDistance) {
			closestHitDistance = tnear;
			normal = Normalize(ray.point(tnear) - *p);
		}
	}

	if(closestHitDistance != FLOATTYPE_MAX) {
		t = closestHitDistance;
		return true;
	}
	else return false;
}

};	// End of namespace AtomViz
