# Copyright 2013-2016 Canonical Ltd.  This software is licensed under the
# GNU Affero General Public License version 3 (see the file LICENSE).

"""Test Zone objects."""

__all__ = []

from django.db.utils import IntegrityError
from maasserver.enum import NODE_TYPE
from maasserver.models.zone import (
    DEFAULT_ZONE_NAME,
    Zone,
)
from maasserver.testing.factory import factory
from maasserver.testing.testcase import MAASServerTestCase
from maasserver.utils.orm import reload_object
from maastesting.matchers import MockCalledOnce


class TestZoneManager(MAASServerTestCase):
    """Exercise the custom manager exposed as `Zone.objects`."""

    def test_get_default_zone_returns_default_zone(self):
        default = Zone.objects.get_default_zone()
        self.assertEqual(DEFAULT_ZONE_NAME, default.name)

    def test_get_default_zone_ignores_other_zones(self):
        # An unrelated zone must not be returned in place of the default.
        factory.make_Zone()
        default = Zone.objects.get_default_zone()
        self.assertEqual(DEFAULT_ZONE_NAME, default.name)

    def test_get_default_zone_handles_exception(self):
        # Fetch the real default zone before get_or_create is patched.
        expected = Zone.objects.get_default_zone()
        # Make get_or_create fail with a duplicate-name IntegrityError;
        # get_default_zone should still hand back the existing zone.
        conflict = IntegrityError(
            'duplicate key value violates unique constraint '
            '"maasserver_zone_name_key"')
        get_or_create = self.patch(Zone.objects, "get_or_create")
        get_or_create.side_effect = conflict
        observed = Zone.objects.get_default_zone()
        self.assertThat(get_or_create, MockCalledOnce())
        self.assertEqual(expected.id, observed.id)


class TestZone(MAASServerTestCase):
    """Tests for :class:`Zone`."""

    def _make_mixed_nodes(self, zone):
        """Create three machines, two devices and one rack controller.

        All nodes are placed in `zone`.

        :return: A tuple ``(machines, devices, rack_controller)`` where
            `machines` and `devices` are lists of the created nodes.
        """
        machines = [
            factory.make_Node(zone=zone, node_type=NODE_TYPE.MACHINE)
            for _ in range(3)
        ]
        devices = [
            factory.make_Node(zone=zone, node_type=NODE_TYPE.DEVICE)
            for _ in range(2)
        ]
        rack_controller = factory.make_Node(
            zone=zone, node_type=NODE_TYPE.RACK_CONTROLLER)
        return machines, devices, rack_controller

    def test_init(self):
        node1 = factory.make_Node()
        node2 = factory.make_Node()
        name = factory.make_name('name')
        description = factory.make_name('description')

        zone = Zone(name=name, description=description)
        zone.save()
        zone.node_set.add(node1)
        zone.node_set.add(node2)

        self.assertEqual(
            (
                set(zone.node_set.all()),
                zone.name,
                zone.description,
                node1.zone,
                node2.zone,
            ),
            ({node1, node2}, name, description, zone, zone))

    def test_delete_deletes_zone(self):
        zone = factory.make_Zone()
        zone.delete()
        self.assertIsNone(reload_object(zone))

    def test_delete_severs_link_to_nodes(self):
        zone = factory.make_Zone()
        node = factory.make_Node(zone=zone)
        zone.delete()
        self.assertIsNone(reload_object(zone))
        node = reload_object(node)
        # The node survives the zone's deletion and is moved to the
        # default zone.
        self.assertIsNotNone(node)
        self.assertEqual(Zone.objects.get_default_zone(), node.zone)

    def test_is_default_returns_True_for_default_zone(self):
        self.assertTrue(Zone.objects.get_default_zone().is_default())

    def test_is_default_returns_False_for_normal_zone(self):
        self.assertFalse(factory.make_Zone().is_default())

    def test_nodes_only_set(self):
        """zone.node_only_set holds only machine-type nodes."""
        zone = factory.make_Zone()
        machines, devices, rack_controller = self._make_mixed_nodes(zone)
        self.assertEqual(3, zone.node_only_set.count())
        for machine in machines:
            self.assertIn(machine, zone.node_only_set)
        for other in devices + [rack_controller]:
            self.assertNotIn(other, zone.node_only_set)

    def test_devices_only_set(self):
        """zone.device_only_set holds only device-type nodes."""
        zone = factory.make_Zone()
        machines, devices, rack_controller = self._make_mixed_nodes(zone)
        self.assertEqual(2, zone.device_only_set.count())
        for device in devices:
            self.assertIn(device, zone.device_only_set)
        # Previously this checked node_only_set by mistake, leaving the
        # exclusion of rack controllers from device_only_set untested.
        for other in machines + [rack_controller]:
            self.assertNotIn(other, zone.device_only_set)

    def test_rack_controllers_only_set(self):
        """zone.rack_controller_only_set holds only rack controllers."""
        zone = factory.make_Zone()
        machines, devices, rack_controller = self._make_mixed_nodes(zone)
        # Check the set under test; this previously asserted on
        # device_only_set by mistake.
        self.assertEqual(1, zone.rack_controller_only_set.count())
        self.assertIn(rack_controller, zone.rack_controller_only_set)
        for other in machines + devices:
            self.assertNotIn(other, zone.rack_controller_only_set)
